diff --git a/func.go b/func.go index 5fa1cbc8..5f3c2702 100644 --- a/func.go +++ b/func.go @@ -39,3 +39,23 @@ func Partial5[T1, T2, T3, T4, T5, T6, R any](f func(T1, T2, T3, T4, T5, T6) R, a return f(arg1, t2, t3, t4, t5, t6) } } + +// Compose returns new function that, when called, returns the result of calling g and then f with the result from g +func Compose[T, U, V any](f func(U) V, g func(T) U) func (T) V { + return func(t T) V { + return f(g(t)) + } +} + +func Compose3[T1, T2, T3, R any](f func(T3) R, g func(T2) T3, h func(T1) T2) func (T1) R { + return Compose(f, Compose(g, h)) +} + +// Pipe returns new function that, when called, returns the result of calling f and then g with the result from f +func Pipe[T, U, V any](f func(T) U, g func(U) V) func (T) V { + return Compose(g, f) +} + +func Pipe3[T1, T2, T3, R any](f func(T1) T2, g func(T2) T3, h func(T3) R) func (T1) R { + return Pipe(f, Pipe(g, h)) +} diff --git a/func_test.go b/func_test.go index 284d24a4..fb0299c8 100644 --- a/func_test.go +++ b/func_test.go @@ -78,3 +78,55 @@ func TestPartial5(t *testing.T) { is.Equal("26", f(10, 9, -3, 0, 5)) is.Equal("21", f(-5, 8, 7, -1, 7)) } + +func sumBy2(x int) int { return x + 2 } +func mulBy3(x int) int { return x * 3 } + +func TestCompose(t *testing.T) { + t.Parallel() + is := assert.New(t) + + sumBy2AndMulBy3 := Compose(mulBy3, sumBy2) + mulBy3AndSumBy2 := Compose(sumBy2, mulBy3) + + val := 1 + is.Equal(9, sumBy2AndMulBy3(val)) + is.Equal(5, mulBy3AndSumBy2(val)) +} + +func TestCompose3(t *testing.T) { + t.Parallel() + + sumBy2MulBy3AndSumBy2 := Compose3(sumBy2, mulBy3, sumBy2) + mulBy3SumBy2AndMulBy3 := Compose3(mulBy3, sumBy2, mulBy3) + + is := assert.New(t) + + val := 1 + is.Equal(11, sumBy2MulBy3AndSumBy2(val)) + is.Equal(15, mulBy3SumBy2AndMulBy3(val)) +} + +func TestPipe(t *testing.T) { + t.Parallel() + is := assert.New(t) + + sumBy2AndMulBy3 := Pipe(sumBy2, mulBy3) + mulBy3AndSumBy2 := Pipe(mulBy3, sumBy2) + + val := 1 + is.Equal(9, sumBy2AndMulBy3(val)) + is.Equal(5, mulBy3AndSumBy2(val)) +} + +func TestPipe3(t *testing.T) { + t.Parallel() + is := assert.New(t) + + sumBy2MulBy3AndSumBy2 := Pipe3(sumBy2, mulBy3, sumBy2) + mulBy3SumBy2AndMulBy3 := Pipe3(mulBy3, sumBy2, mulBy3) + + val := 1 + is.Equal(11, sumBy2MulBy3AndSumBy2(val)) + is.Equal(15, mulBy3SumBy2AndMulBy3(val)) +}