From 6caf936e1ed633d76a36a3df33d6da2072ed159d Mon Sep 17 00:00:00 2001 From: Joe Greener Date: Sun, 30 Jul 2023 18:19:00 +0100 Subject: [PATCH 01/13] Julia function tests --- test/runtests.jl | 117 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index fd31a13ea0..a0fd73f8ba 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -707,6 +707,123 @@ end Enzyme.API.strictAliasing!(true) f10(x) = hypot(x, 2x) @test autodiff(Reverse, f10, Active, Active(2.0))[1][1] == sqrt(5) + @test autodiff(Forward, f10, Duplicated(2.0, 1.0))[1] == sqrt(5) + + f11(x) = x * sum(LinRange(x, 10.0, 6)) + @test autodiff(Reverse, f11, Active, Active(2.0))[1][1] == 42 + @test autodiff(Forward, f11, Duplicated(2.0, 1.0))[1] == 42 + + f12(x, k) = get(Dict(1 => 1.0, 2 => x, 3 => 3.0), k, 1.0) + @test autodiff(Reverse, f12, Active, Active(2.0), Const(2))[1] == (1.0, nothing) + @test autodiff(Forward, f12, Duplicated(2.0, 1.0), Const(2)) == (1.0,) + @test autodiff(Reverse, f12, Active, Active(2.0), Const(3))[1] == (0.0, nothing) + @test autodiff(Forward, f12, Duplicated(2.0, 1.0), Const(3)) == (0.0,) + @test autodiff(Reverse, f12, Active, Active(2.0), Const(4))[1] == (0.0, nothing) + @test autodiff(Forward, f12, Duplicated(2.0, 1.0), Const(4)) == (0.0,) + + f13(x) = muladd(x, 3, x) + @test autodiff(Reverse, f13, Active, Active(2.0))[1][1] == 4 + @test autodiff(Forward, f13, Duplicated(2.0, 1.0))[1] == 4 + + f14(x) = x * cmp(x, 3) + @test autodiff(Reverse, f14, Active, Active(2.0))[1][1] == -1 + @test autodiff(Forward, f14, Duplicated(2.0, 1.0))[1] == -1 + + f15(x) = x * argmax([1.0, 3.0, 2.0]) + @test autodiff(Reverse, f15, Active, Active(3.0))[1][1] == 2 + @test autodiff(Forward, f15, Duplicated(3.0, 1.0))[1] == 2 + + f16(x) = evalpoly(2, (1, 2, x)) + @test autodiff(Reverse, f16, Active, Active(3.0))[1][1] == 4 + @test autodiff(Forward, f16, Duplicated(3.0, 1.0))[1] == 4 + + f17(x) = @evalpoly(2, 1, 2, x) + @test autodiff(Reverse, f17, Active, Active(3.0))[1][1] == 4 + @test autodiff(Forward, f17, Duplicated(3.0, 1.0))[1] == 4 + + f18(x) = widemul(x, 5.0f0) + @test autodiff(Reverse, f18, Active, Active(2.0f0))[1][1] == 5 + @test autodiff(Forward, f18, Duplicated(2.0f0, 1.0f0))[1] == 5 + + f19(x) = copysign(x, -x) + @test autodiff(Reverse, f19, Active, Active(2.0))[1][1] == -1 + @test autodiff(Forward, f19, Duplicated(2.0, 1.0))[1] == -1 + + f20(x) = sum([ifelse(i > 5, i, zero(i)) for i in [x, 2x, 3x, 4x]]) + @test autodiff(Reverse, f20, Active, Active(2.0))[1][1] == 7 + @test autodiff(Forward, f20, Duplicated(2.0, 1.0))[1] == 7 + + function f21(x) + nt = (a=x, b=2x, c=3x) + return nt.c + end + @test autodiff(Reverse, f21, Active, Active(2.0))[1][1] == 3 + @test autodiff(Forward, f21, Duplicated(2.0, 1.0))[1] == 3 + + f22(x) = sum(fill(x, (3, 3))) + @test autodiff(Reverse, f22, Active, Active(2.0))[1][1] == 9 + @test autodiff(Forward, f22, Duplicated(2.0, 1.0))[1] == 9 + + function f23(x) + a = similar(rand(3, 3)) + fill!(a, x) + return sum(a) + end + @test autodiff(Reverse, f23, Active, Active(2.0))[1][1] == 9 + @test autodiff(Forward, f23, Duplicated(2.0, 1.0))[1] == 9 + + function f24(x) + try + return 3x + catch + return 2x + end + end + @test autodiff(Reverse, f24, Active, Active(2.0))[1][1] == 3 + @test autodiff(Forward, f24, Duplicated(2.0, 1.0))[1] == 3 + + function f25(x) + try + sqrt(-1.0) + return 3x + catch + return 2x + end + end + @test autodiff(Reverse, f25, Active, Active(2.0))[1][1] == 2 + @test autodiff(Forward, f25, Duplicated(2.0, 1.0))[1] == 2 + + f26(x) = circshift([1.0, 2x, 3.0], 1)[end] + @test autodiff(Reverse, f26, Active, Active(2.0))[1][1] == 2 + @test autodiff(Forward, f26, Duplicated(2.0, 1.0))[1] == 2 + + f27(x) = sum(diff([0.0 x; 1.0 2x]; dims=2)) + @test autodiff(Reverse, f27, Active, Active(2.0))[1][1] == 3 + @test autodiff(Forward, f27, Duplicated(2.0, 1.0))[1] == 3 + + f28(x) = repeat([x 3x], 3)[2, 2] + @test autodiff(Reverse, f28, Active, Active(2.0))[1][1] == 3 + @test autodiff(Forward, f28, Duplicated(2.0, 1.0))[1] == 3 + + f29(x) = rot180([x 2x; 3x 4x], 3)[1, 1] + @test autodiff(Reverse, f29, Active, Active(2.0))[1][1] == 4 + @test autodiff(Forward, f29, Duplicated(2.0, 1.0))[1] == 4 + + f30(x) = x * sum(trues(4, 3)) + @test autodiff(Reverse, f30, Active, Active(2.0))[1][1] == 12 + @test autodiff(Forward, f30, Duplicated(2.0, 1.0))[1] == 12 + + f31(x) = sum(Set([1.0, x, 2x, x])) + @test autodiff(Reverse, f31, Active, Active(2.0))[1][1] == 3 + @test autodiff(Forward, f31, Duplicated(2.0, 1.0))[1] == 3 + + f32(x) = reverse([x 2.0 3x])[1] + @test autodiff(Reverse, f32, Active, Active(2.0))[1][1] == 3 + @test autodiff(Forward, f32, Duplicated(2.0, 1.0))[1] == 3 + + f33(x) = sum(skipmissing([x, missing, 3.0, 3x])) + @test autodiff(Reverse, f33, Active, Active(2.0))[1][1] == 4 + @test autodiff(Forward, f33, Duplicated(2.0, 1.0))[1] == 4 end function deadarg_pow(z::T, i) where {T<:Real} From 41dca7858f6b5b2c8fb725f3f2755ab497788093 Mon Sep 17 00:00:00 2001 From: Joe Greener Date: Sun, 30 Jul 2023 18:19:10 +0100 Subject: [PATCH 02/13] Nested reverse test --- test/runtests.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index a0fd73f8ba..a599a0f659 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -887,6 +887,14 @@ end tonest(x,y) = (x + y)^2 @test autodiff(Forward, (x,y) -> autodiff(Forward, Const(tonest), Duplicated(x, 1.0), Const(y))[1], Const(1.0), Duplicated(2.0, 1.0))[1] ≈ 2.0 + + f_nest(x) = 2 * x^4 + deriv(f, x) = first(first(autodiff_deferred(Reverse, f, Active(x)))) + f′(x) = deriv(f_nest, x) + f′′(x) = deriv(f′, x) + + @test f′(2.0) == 64 + @test f′′(2.0) == 96 end @testset "Hessian" begin From ece2df3cb2eff7f0ea5b21c238b2c75fa00126b1 Mon Sep 17 00:00:00 2001 From: Joe Greener Date: Sun, 30 Jul 2023 23:58:00 +0100 Subject: [PATCH 03/13] Disable failing tests on earlier Julia versions --- test/runtests.jl | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index a599a0f659..56193a4923 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -782,16 +782,19 @@ end @test autodiff(Reverse, f24, Active, Active(2.0))[1][1] == 3 @test autodiff(Forward, f24, Duplicated(2.0, 1.0))[1] == 3 - function f25(x) - try - sqrt(-1.0) - return 3x - catch - return 2x + # See #971 + @static if VERSION ≥ v"1.9-" + function f25(x) + try + sqrt(-1.0) + return 3x + catch + return 2x + end end + @test autodiff(Reverse, f25, Active, Active(2.0))[1][1] == 2 + @test autodiff(Forward, f25, Duplicated(2.0, 1.0))[1] == 2 end - @test autodiff(Reverse, f25, Active, Active(2.0))[1][1] == 2 - @test autodiff(Forward, f25, Duplicated(2.0, 1.0))[1] == 2 f26(x) = circshift([1.0, 2x, 3.0], 1)[end] @test autodiff(Reverse, f26, Active, Active(2.0))[1][1] == 2 @@ -821,9 +824,12 @@ end @test autodiff(Reverse, f32, Active, Active(2.0))[1][1] == 3 @test autodiff(Forward, f32, Duplicated(2.0, 1.0))[1] == 3 - f33(x) = sum(skipmissing([x, missing, 3.0, 3x])) - @test autodiff(Reverse, f33, Active, Active(2.0))[1][1] == 4 - @test autodiff(Forward, f33, Duplicated(2.0, 1.0))[1] == 4 + # See #970 + @static if VERSION ≥ v"1.9-" + f33(x) = sum(skipmissing([x, missing, 3.0, 3x])) + @test autodiff(Reverse, f33, Active, Active(2.0))[1][1] == 4 + @test autodiff(Forward, f33, Duplicated(2.0, 1.0))[1] == 4 + end end function deadarg_pow(z::T, i) where {T<:Real} From c51e81fba8880de3b9f18ec638788976ece6a954 Mon Sep 17 00:00:00 2001 From: Joe Greener Date: Tue, 1 Aug 2023 14:02:40 +0100 Subject: [PATCH 04/13] Enable try/catch test --- test/runtests.jl | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 56193a4923..f8ea8fa585 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -782,19 +782,16 @@ end @test autodiff(Reverse, f24, Active, Active(2.0))[1][1] == 3 @test autodiff(Forward, f24, Duplicated(2.0, 1.0))[1] == 3 - # See #971 - @static if VERSION ≥ v"1.9-" - function f25(x) - try - sqrt(-1.0) - return 3x - catch - return 2x - end + function f25(x) + try + sqrt(-1.0) + return 3x + catch + return 2x end - @test autodiff(Reverse, f25, Active, Active(2.0))[1][1] == 2 - @test autodiff(Forward, f25, Duplicated(2.0, 1.0))[1] == 2 end + @test autodiff(Reverse, f25, Active, Active(2.0))[1][1] == 2 + @test autodiff(Forward, f25, Duplicated(2.0, 1.0))[1] == 2 f26(x) = circshift([1.0, 2x, 3.0], 1)[end] @test autodiff(Reverse, f26, Active, Active(2.0))[1][1] == 2 From 8f15f16bc6035ba3d531ad5143cdedec79359773 Mon Sep 17 00:00:00 2001 From: Joe Greener Date: Tue, 1 Aug 2023 14:03:03 +0100 Subject: [PATCH 05/13] Printing to find CI error --- test/runtests.jl | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index f8ea8fa585..7148ce0c87 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -708,10 +708,12 @@ end f10(x) = hypot(x, 2x) @test autodiff(Reverse, f10, Active, Active(2.0))[1][1] == sqrt(5) @test autodiff(Forward, f10, Duplicated(2.0, 1.0))[1] == sqrt(5) + println("Done 10") f11(x) = x * sum(LinRange(x, 10.0, 6)) @test autodiff(Reverse, f11, Active, Active(2.0))[1][1] == 42 @test autodiff(Forward, f11, Duplicated(2.0, 1.0))[1] == 42 + println("Done 11") f12(x, k) = get(Dict(1 => 1.0, 2 => x, 3 => 3.0), k, 1.0) @test autodiff(Reverse, f12, Active, Active(2.0), Const(2))[1] == (1.0, nothing) @@ -720,38 +722,47 @@ end @test autodiff(Forward, f12, Duplicated(2.0, 1.0), Const(3)) == (0.0,) @test autodiff(Reverse, f12, Active, Active(2.0), Const(4))[1] == (0.0, nothing) @test autodiff(Forward, f12, Duplicated(2.0, 1.0), Const(4)) == (0.0,) + println("Done 12") f13(x) = muladd(x, 3, x) @test autodiff(Reverse, f13, Active, Active(2.0))[1][1] == 4 @test autodiff(Forward, f13, Duplicated(2.0, 1.0))[1] == 4 + println("Done 13") f14(x) = x * cmp(x, 3) @test autodiff(Reverse, f14, Active, Active(2.0))[1][1] == -1 @test autodiff(Forward, f14, Duplicated(2.0, 1.0))[1] == -1 + println("Done 14") f15(x) = x * argmax([1.0, 3.0, 2.0]) @test autodiff(Reverse, f15, Active, Active(3.0))[1][1] == 2 @test autodiff(Forward, f15, Duplicated(3.0, 1.0))[1] == 2 + println("Done 15") f16(x) = evalpoly(2, (1, 2, x)) @test autodiff(Reverse, f16, Active, Active(3.0))[1][1] == 4 @test autodiff(Forward, f16, Duplicated(3.0, 1.0))[1] == 4 + println("Done 16") f17(x) = @evalpoly(2, 1, 2, x) @test autodiff(Reverse, f17, Active, Active(3.0))[1][1] == 4 @test autodiff(Forward, f17, Duplicated(3.0, 1.0))[1] == 4 + println("Done 17") f18(x) = widemul(x, 5.0f0) @test autodiff(Reverse, f18, Active, Active(2.0f0))[1][1] == 5 @test autodiff(Forward, f18, Duplicated(2.0f0, 1.0f0))[1] == 5 + println("Done 18") f19(x) = copysign(x, -x) @test autodiff(Reverse, f19, Active, Active(2.0))[1][1] == -1 @test autodiff(Forward, f19, Duplicated(2.0, 1.0))[1] == -1 + println("Done 19") f20(x) = sum([ifelse(i > 5, i, zero(i)) for i in [x, 2x, 3x, 4x]]) @test autodiff(Reverse, f20, Active, Active(2.0))[1][1] == 7 @test autodiff(Forward, f20, Duplicated(2.0, 1.0))[1] == 7 + println("Done 20") function f21(x) nt = (a=x, b=2x, c=3x) @@ -759,10 +770,12 @@ end end @test autodiff(Reverse, f21, Active, Active(2.0))[1][1] == 3 @test autodiff(Forward, f21, Duplicated(2.0, 1.0))[1] == 3 + println("Done 21") f22(x) = sum(fill(x, (3, 3))) @test autodiff(Reverse, f22, Active, Active(2.0))[1][1] == 9 @test autodiff(Forward, f22, Duplicated(2.0, 1.0))[1] == 9 + println("Done 22") function f23(x) a = similar(rand(3, 3)) @@ -771,6 +784,7 @@ end end @test autodiff(Reverse, f23, Active, Active(2.0))[1][1] == 9 @test autodiff(Forward, f23, Duplicated(2.0, 1.0))[1] == 9 + println("Done 23") function f24(x) try @@ -781,6 +795,7 @@ end end @test autodiff(Reverse, f24, Active, Active(2.0))[1][1] == 3 @test autodiff(Forward, f24, Duplicated(2.0, 1.0))[1] == 3 + println("Done 24") function f25(x) try @@ -792,34 +807,42 @@ end end @test autodiff(Reverse, f25, Active, Active(2.0))[1][1] == 2 @test autodiff(Forward, f25, Duplicated(2.0, 1.0))[1] == 2 + println("Done 25") f26(x) = circshift([1.0, 2x, 3.0], 1)[end] @test autodiff(Reverse, f26, Active, Active(2.0))[1][1] == 2 @test autodiff(Forward, f26, Duplicated(2.0, 1.0))[1] == 2 + println("Done 26") f27(x) = sum(diff([0.0 x; 1.0 2x]; dims=2)) @test autodiff(Reverse, f27, Active, Active(2.0))[1][1] == 3 @test autodiff(Forward, f27, Duplicated(2.0, 1.0))[1] == 3 + println("Done 27") f28(x) = repeat([x 3x], 3)[2, 2] @test autodiff(Reverse, f28, Active, Active(2.0))[1][1] == 3 @test autodiff(Forward, f28, Duplicated(2.0, 1.0))[1] == 3 + println("Done 28") f29(x) = rot180([x 2x; 3x 4x], 3)[1, 1] @test autodiff(Reverse, f29, Active, Active(2.0))[1][1] == 4 @test autodiff(Forward, f29, Duplicated(2.0, 1.0))[1] == 4 + println("Done 29") f30(x) = x * sum(trues(4, 3)) @test autodiff(Reverse, f30, Active, Active(2.0))[1][1] == 12 @test autodiff(Forward, f30, Duplicated(2.0, 1.0))[1] == 12 + println("Done 30") f31(x) = sum(Set([1.0, x, 2x, x])) @test autodiff(Reverse, f31, Active, Active(2.0))[1][1] == 3 @test autodiff(Forward, f31, Duplicated(2.0, 1.0))[1] == 3 + println("Done 31") f32(x) = reverse([x 2.0 3x])[1] @test autodiff(Reverse, f32, Active, Active(2.0))[1][1] == 3 @test autodiff(Forward, f32, Duplicated(2.0, 1.0))[1] == 3 + println("Done 32") # See #970 @static if VERSION ≥ v"1.9-" @@ -827,6 +850,7 @@ end @test autodiff(Reverse, f33, Active, Active(2.0))[1][1] == 4 @test autodiff(Forward, f33, Duplicated(2.0, 1.0))[1] == 4 end + println("Done 33") end function deadarg_pow(z::T, i) where {T<:Real} From 30ec609f1596cb01d50334ea77a594d38c1d48a5 Mon Sep 17 00:00:00 2001 From: Joe Greener Date: Tue, 8 Aug 2023 23:50:30 +0100 Subject: [PATCH 06/13] Mark try test as broken on Julia 1.6 --- test/runtests.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 7148ce0c87..cd1c06fb9c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -805,8 +805,14 @@ end return 2x end end - @test autodiff(Reverse, f25, Active, Active(2.0))[1][1] == 2 - @test autodiff(Forward, f25, Duplicated(2.0, 1.0))[1] == 2 + # Gives 0.0 on Julia 1.6, see #971 + @static if VERSION ≥ v"1.7-" + @test autodiff(Reverse, f25, Active, Active(2.0))[1][1] == 2 + @test autodiff(Forward, f25, Duplicated(2.0, 1.0))[1] == 2 + else + @test_broken autodiff(Reverse, f25, Active, Active(2.0))[1][1] == 2 + @test_broken autodiff(Forward, f25, Duplicated(2.0, 1.0))[1] == 2 + end println("Done 25") f26(x) = circshift([1.0, 2x, 3.0], 1)[end] From 83ccd217ac63e842ab86565332d67f8385952e3c Mon Sep 17 00:00:00 2001 From: Joe Greener Date: Wed, 9 Aug 2023 14:12:10 +0100 Subject: [PATCH 07/13] Mark try test as broken on Julia 1.7 --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index cd1c06fb9c..132e916fcb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -806,7 +806,7 @@ end end end # Gives 0.0 on Julia 1.6, see #971 - @static if VERSION ≥ v"1.7-" + @static if VERSION ≥ v"1.8-" @test autodiff(Reverse, f25, Active, Active(2.0))[1][1] == 2 @test autodiff(Forward, f25, Duplicated(2.0, 1.0))[1] == 2 else From 1e94ad17fd499634f4802658ba16f0ec20d02df3 Mon Sep 17 00:00:00 2001 From: Joe Greener Date: Sun, 26 Nov 2023 19:47:51 +0000 Subject: [PATCH 08/13] Remove skipmissing test --- test/runtests.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 132e916fcb..26b05bb4fc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -849,14 +849,6 @@ end @test autodiff(Reverse, f32, Active, Active(2.0))[1][1] == 3 @test autodiff(Forward, f32, Duplicated(2.0, 1.0))[1] == 3 println("Done 32") - - # See #970 - @static if VERSION ≥ v"1.9-" - f33(x) = sum(skipmissing([x, missing, 3.0, 3x])) - @test autodiff(Reverse, f33, Active, Active(2.0))[1][1] == 4 - @test autodiff(Forward, f33, Duplicated(2.0, 1.0))[1] == 4 - end - println("Done 33") end function deadarg_pow(z::T, i) where {T<:Real} From ba040d20806e9cec0ccd30dd54e165b413c27b13 Mon Sep 17 00:00:00 2001 From: Joe Greener Date: Mon, 27 Nov 2023 23:56:48 +0000 Subject: [PATCH 09/13] Remove print debugging --- test/runtests.jl | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 26b05bb4fc..6a31d9b2a3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -708,12 +708,10 @@ end f10(x) = hypot(x, 2x) @test autodiff(Reverse, f10, Active, Active(2.0))[1][1] == sqrt(5) @test autodiff(Forward, f10, Duplicated(2.0, 1.0))[1] == sqrt(5) - println("Done 10") f11(x) = x * sum(LinRange(x, 10.0, 6)) @test autodiff(Reverse, f11, Active, Active(2.0))[1][1] == 42 @test autodiff(Forward, f11, Duplicated(2.0, 1.0))[1] == 42 - println("Done 11") f12(x, k) = get(Dict(1 => 1.0, 2 => x, 3 => 3.0), k, 1.0) @test autodiff(Reverse, f12, Active, Active(2.0), Const(2))[1] == (1.0, nothing) @@ -722,47 +720,38 @@ end @test autodiff(Forward, f12, Duplicated(2.0, 1.0), Const(3)) == (0.0,) @test autodiff(Reverse, f12, Active, Active(2.0), Const(4))[1] == (0.0, nothing) @test autodiff(Forward, f12, Duplicated(2.0, 1.0), Const(4)) == (0.0,) - println("Done 12") f13(x) = muladd(x, 3, x) @test autodiff(Reverse, f13, Active, Active(2.0))[1][1] == 4 @test autodiff(Forward, f13, Duplicated(2.0, 1.0))[1] == 4 - println("Done 13") f14(x) = x * cmp(x, 3) @test autodiff(Reverse, f14, Active, Active(2.0))[1][1] == -1 @test autodiff(Forward, f14, Duplicated(2.0, 1.0))[1] == -1 - println("Done 14") f15(x) = x * argmax([1.0, 3.0, 2.0]) @test autodiff(Reverse, f15, Active, Active(3.0))[1][1] == 2 @test autodiff(Forward, f15, Duplicated(3.0, 1.0))[1] == 2 - println("Done 15") f16(x) = evalpoly(2, (1, 2, x)) @test autodiff(Reverse, f16, Active, Active(3.0))[1][1] == 4 @test autodiff(Forward, f16, Duplicated(3.0, 1.0))[1] == 4 - println("Done 16") f17(x) = @evalpoly(2, 1, 2, x) @test autodiff(Reverse, f17, Active, Active(3.0))[1][1] == 4 @test autodiff(Forward, f17, Duplicated(3.0, 1.0))[1] == 4 - println("Done 17") f18(x) = widemul(x, 5.0f0) @test autodiff(Reverse, f18, Active, Active(2.0f0))[1][1] == 5 @test autodiff(Forward, f18, Duplicated(2.0f0, 1.0f0))[1] == 5 - println("Done 18") f19(x) = copysign(x, -x) @test autodiff(Reverse, f19, Active, Active(2.0))[1][1] == -1 @test autodiff(Forward, f19, Duplicated(2.0, 1.0))[1] == -1 - println("Done 19") f20(x) = sum([ifelse(i > 5, i, zero(i)) for i in [x, 2x, 3x, 4x]]) @test autodiff(Reverse, f20, Active, Active(2.0))[1][1] == 7 @test autodiff(Forward, f20, Duplicated(2.0, 1.0))[1] == 7 - println("Done 20") function f21(x) nt = (a=x, b=2x, c=3x) @@ -770,12 +759,10 @@ end end @test autodiff(Reverse, f21, Active, Active(2.0))[1][1] == 3 @test autodiff(Forward, f21, Duplicated(2.0, 1.0))[1] == 3 - println("Done 21") f22(x) = sum(fill(x, (3, 3))) @test autodiff(Reverse, f22, Active, Active(2.0))[1][1] == 9 @test autodiff(Forward, f22, Duplicated(2.0, 1.0))[1] == 9 - println("Done 22") function f23(x) a = similar(rand(3, 3)) @@ -784,7 +771,6 @@ end end @test autodiff(Reverse, f23, Active, Active(2.0))[1][1] == 9 @test autodiff(Forward, f23, Duplicated(2.0, 1.0))[1] == 9 - println("Done 23") function f24(x) try @@ -795,7 +781,6 @@ end end @test autodiff(Reverse, f24, Active, Active(2.0))[1][1] == 3 @test autodiff(Forward, f24, Duplicated(2.0, 1.0))[1] == 3 - println("Done 24") function f25(x) try @@ -813,42 +798,34 @@ end @test_broken autodiff(Reverse, f25, Active, Active(2.0))[1][1] == 2 @test_broken autodiff(Forward, f25, Duplicated(2.0, 1.0))[1] == 2 end - println("Done 25") f26(x) = circshift([1.0, 2x, 3.0], 1)[end] @test autodiff(Reverse, f26, Active, Active(2.0))[1][1] == 2 @test autodiff(Forward, f26, Duplicated(2.0, 1.0))[1] == 2 - println("Done 26") f27(x) = sum(diff([0.0 x; 1.0 2x]; dims=2)) @test autodiff(Reverse, f27, Active, Active(2.0))[1][1] == 3 @test autodiff(Forward, f27, Duplicated(2.0, 1.0))[1] == 3 - println("Done 27") f28(x) = repeat([x 3x], 3)[2, 2] @test autodiff(Reverse, f28, Active, Active(2.0))[1][1] == 3 @test autodiff(Forward, f28, Duplicated(2.0, 1.0))[1] == 3 - println("Done 28") f29(x) = rot180([x 2x; 3x 4x], 3)[1, 1] @test autodiff(Reverse, f29, Active, Active(2.0))[1][1] == 4 @test autodiff(Forward, f29, Duplicated(2.0, 1.0))[1] == 4 - println("Done 29") f30(x) = x * sum(trues(4, 3)) @test autodiff(Reverse, f30, Active, Active(2.0))[1][1] == 12 @test autodiff(Forward, f30, Duplicated(2.0, 1.0))[1] == 12 - println("Done 30") f31(x) = sum(Set([1.0, x, 2x, x])) @test autodiff(Reverse, f31, Active, Active(2.0))[1][1] == 3 @test autodiff(Forward, f31, Duplicated(2.0, 1.0))[1] == 3 - println("Done 31") f32(x) = reverse([x 2.0 3x])[1] @test autodiff(Reverse, f32, Active, Active(2.0))[1][1] == 3 @test autodiff(Forward, f32, Duplicated(2.0, 1.0))[1] == 3 - println("Done 32") end function deadarg_pow(z::T, i) where {T<:Real} From 603b4590def998d13a4f46e8662b71f27b61d4cb Mon Sep 17 00:00:00 2001 From: Joe Greener Date: Sun, 21 Apr 2024 19:33:16 +0100 Subject: [PATCH 10/13] Revert changes --- test/runtests.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 6a31d9b2a3..c4304d609a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3763,9 +3763,12 @@ end @test autodiff(Reverse, f8, Active, Active(1.5))[1][1] == 0 @test autodiff(Forward, f8, Duplicated(1.5, 1.0))[1] == 0 - f9(x) = sum(quantile([1.0, x], [0.5, 0.7])) - @test autodiff(Reverse, f9, Active, Active(2.0))[1][1] == 1.2 - @test autodiff(Forward, f9, Duplicated(2.0, 1.0))[1] == 1.2 + # On Julia 1.6 the gradients are wrong (0.7 not 1.2) and on 1.7 it errors + @static if VERSION ≥ v"1.8-" + f9(x) = sum(quantile([1.0, x], [0.5, 0.7])) + @test autodiff(Reverse, f9, Active, Active(2.0))[1][1] == 1.2 + @test autodiff(Forward, f9, Duplicated(2.0, 1.0))[1] == 1.2 + end end @testset "hvcat_fill" begin From d3dd70106589b69cfba3adfb2d61b200a1d71933 Mon Sep 17 00:00:00 2001 From: Joe Greener Date: Sun, 6 Oct 2024 23:40:15 +0100 Subject: [PATCH 11/13] Remove version check --- test/runtests.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index c4304d609a..51b83d549d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -790,14 +790,8 @@ end return 2x end end - # Gives 0.0 on Julia 1.6, see #971 - @static if VERSION ≥ v"1.8-" - @test autodiff(Reverse, f25, Active, Active(2.0))[1][1] == 2 - @test autodiff(Forward, f25, Duplicated(2.0, 1.0))[1] == 2 - else - @test_broken autodiff(Reverse, f25, Active, Active(2.0))[1][1] == 2 - @test_broken autodiff(Forward, f25, Duplicated(2.0, 1.0))[1] == 2 - end + @test autodiff(Reverse, f25, Active, Active(2.0))[1][1] == 2 + @test autodiff(Forward, f25, Duplicated(2.0, 1.0))[1] == 2 f26(x) = circshift([1.0, 2x, 3.0], 1)[end] @test autodiff(Reverse, f26, Active, Active(2.0))[1][1] == 2 From 34e06cd0052ce59807859bb121f8da7a5c1bd5a7 Mon Sep 17 00:00:00 2001 From: Joe Greener Date: Sun, 17 Nov 2024 19:34:57 +0000 Subject: [PATCH 12/13] updates for Enzyme changes --- test/runtests.jl | 39 ++++++++++++++------------------------- 1 file changed, 14 insertions(+), 25 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 51b83d549d..0c9df8c873 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -797,29 +797,21 @@ end @test autodiff(Reverse, f26, Active, Active(2.0))[1][1] == 2 @test autodiff(Forward, f26, Duplicated(2.0, 1.0))[1] == 2 - f27(x) = sum(diff([0.0 x; 1.0 2x]; dims=2)) + f27(x) = repeat([x 3x], 3)[2, 2] @test autodiff(Reverse, f27, Active, Active(2.0))[1][1] == 3 @test autodiff(Forward, f27, Duplicated(2.0, 1.0))[1] == 3 - f28(x) = repeat([x 3x], 3)[2, 2] - @test autodiff(Reverse, f28, Active, Active(2.0))[1][1] == 3 - @test autodiff(Forward, f28, Duplicated(2.0, 1.0))[1] == 3 + f28(x) = x * sum(trues(4, 3)) + @test autodiff(Reverse, f28, Active, Active(2.0))[1][1] == 12 + @test autodiff(Forward, f28, Duplicated(2.0, 1.0))[1] == 12 - f29(x) = rot180([x 2x; 3x 4x], 3)[1, 1] - @test autodiff(Reverse, f29, Active, Active(2.0))[1][1] == 4 - @test autodiff(Forward, f29, Duplicated(2.0, 1.0))[1] == 4 + f29(x) = sum(Set([1.0, x, 2x, x])) + @test autodiff(Reverse, f29, Active, Active(2.0))[1][1] == 3 + @test autodiff(Forward, f29, Duplicated(2.0, 1.0))[1] == 3 - f30(x) = x * sum(trues(4, 3)) - @test autodiff(Reverse, f30, Active, Active(2.0))[1][1] == 12 - @test autodiff(Forward, f30, Duplicated(2.0, 1.0))[1] == 12 - - f31(x) = sum(Set([1.0, x, 2x, x])) - @test autodiff(Reverse, f31, Active, Active(2.0))[1][1] == 3 - @test autodiff(Forward, f31, Duplicated(2.0, 1.0))[1] == 3 - - f32(x) = reverse([x 2.0 3x])[1] - @test autodiff(Reverse, f32, Active, Active(2.0))[1][1] == 3 - @test autodiff(Forward, f32, Duplicated(2.0, 1.0))[1] == 3 + f30(x) = reverse([x 2.0 3x])[1] + @test autodiff(Reverse, f30, Active, Active(2.0))[1][1] == 3 + @test autodiff(Forward, f30, Duplicated(2.0, 1.0))[1] == 3 end function deadarg_pow(z::T, i) where {T<:Real} @@ -885,7 +877,7 @@ end @test autodiff(Forward, (x,y) -> autodiff(Forward, Const(tonest), Duplicated(x, 1.0), Const(y))[1], Const(1.0), Duplicated(2.0, 1.0))[1] ≈ 2.0 f_nest(x) = 2 * x^4 - deriv(f, x) = first(first(autodiff_deferred(Reverse, f, Active(x)))) + deriv(f, x) = first(first(autodiff(Reverse, f, Active(x)))) f′(x) = deriv(f_nest, x) f′′(x) = deriv(f′, x) @@ -3757,12 +3749,9 @@ end @test autodiff(Reverse, f8, Active, Active(1.5))[1][1] == 0 @test autodiff(Forward, f8, Duplicated(1.5, 1.0))[1] == 0 - # On Julia 1.6 the gradients are wrong (0.7 not 1.2) and on 1.7 it errors - @static if VERSION ≥ v"1.8-" - f9(x) = sum(quantile([1.0, x], [0.5, 0.7])) - @test autodiff(Reverse, f9, Active, Active(2.0))[1][1] == 1.2 - @test autodiff(Forward, f9, Duplicated(2.0, 1.0))[1] == 1.2 - end + f9(x) = sum(quantile([1.0, x], [0.5, 0.7])) + @test autodiff(Reverse, f9, Active, Active(2.0))[1][1] == 1.2 + @test autodiff(Forward, f9, Duplicated(2.0, 1.0))[1] == 1.2 end @testset "hvcat_fill" begin From 31d9fc78a199300f7c71a4365dcf35af9990aeae Mon Sep 17 00:00:00 2001 From: Joe Greener Date: Mon, 2 Dec 2024 19:14:04 +0000 Subject: [PATCH 13/13] remove higher order test --- test/runtests.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 0c9df8c873..c9cde02c28 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -875,14 +875,6 @@ end tonest(x,y) = (x + y)^2 @test autodiff(Forward, (x,y) -> autodiff(Forward, Const(tonest), Duplicated(x, 1.0), Const(y))[1], Const(1.0), Duplicated(2.0, 1.0))[1] ≈ 2.0 - - f_nest(x) = 2 * x^4 - deriv(f, x) = first(first(autodiff(Reverse, f, Active(x)))) - f′(x) = deriv(f_nest, x) - f′′(x) = deriv(f′, x) - - @test f′(2.0) == 64 - @test f′′(2.0) == 96 end @testset "Hessian" begin