diff --git a/slice.go b/slice.go index 54d1b004..74973d7c 100644 --- a/slice.go +++ b/slice.go @@ -8,8 +8,8 @@ import ( // Filter iterates over elements of collection, returning an array of all elements predicate returns truthy for. // Play: https://go.dev/play/p/Apjg3WeSi7K -func Filter[V any](collection []V, predicate func(item V, index int) bool) []V { - result := make([]V, 0, len(collection)) +func Filter[V any, VC ~[]V](collection VC, predicate func(item V, index int) bool) VC { + result := make(VC, 0, len(collection)) for i := range collection { if predicate(collection[i], i) { diff --git a/slice_test.go b/slice_test.go index cd3ee829..1d0611ee 100644 --- a/slice_test.go +++ b/slice_test.go @@ -26,6 +26,13 @@ func TestFilter(t *testing.T) { }) is.Equal(r2, []string{"foo", "bar"}) + + type myStrings []string + allStrings := myStrings{"", "foo", "bar"} + nonempty := Filter(allStrings, func(x string, _ int) bool { + return len(x) > 0 + }) + is.IsType(nonempty, allStrings, "type preserved") } func TestMap(t *testing.T) {