From c40fc88386cea52f80d77c4fe99f4ae852243a8d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 15 Jul 2023 22:29:21 +0200 Subject: [PATCH 01/18] Run Enzyme tests --- .github/workflows/CI.yml | 22 ++++++-------- test/Project.toml | 2 ++ test/ad/utils.jl | 63 ++++++++++++++++++++++++++++++++-------- test/runtests.jl | 2 +- 4 files changed, 63 insertions(+), 26 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 270363e..64009d3 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -27,6 +27,7 @@ jobs: - x64 group: - Others + - Enzyme - ForwardDiff - Tracker - ReverseDiff @@ -36,6 +37,10 @@ jobs: os: macOS-latest arch: x64 group: Others + - version: '1' + os: macOS-latest + arch: x64 + group: Enzyme - version: '1' os: macOS-latest arch: x64 @@ -53,23 +58,14 @@ jobs: arch: x64 group: Zygote steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/cache@v1 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- - - uses: julia-actions/julia-buildpkg@latest - - uses: julia-actions/julia-runtest@latest + - uses: julia-actions/cache@v1 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 with: coverage: false env: diff --git a/test/Project.toml b/test/Project.toml index 086a88d..87720f3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" @@ -21,6 +22,7 @@ ChainRulesCore = "1" ChainRulesTestUtils = "1.9.2" Combinatorics = "1.0.2" Distributions = "0.25.15" +Enzyme = "0.11.5" FiniteDifferences = "0.11.3, 0.12" ForwardDiff = "0.10.12" LazyArrays = "1" diff --git a/test/ad/utils.jl b/test/ad/utils.jl index a4dcd6e..154ec1b 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -5,6 +5,9 @@ using FiniteDifferences const FDM = FiniteDifferences # Load AD backends +if GROUP == "All" || GROUP == "Enzyme" + @eval using Enzyme +end if GROUP == "All" || GROUP == "ForwardDiff" @eval using ForwardDiff end @@ -25,6 +28,40 @@ function test_reverse_mode_ad(f, ȳ, x...; rtol=1e-6, atol=1e-6) # Use finite differencing to compute reverse-mode sensitivities. x̄s_fdm = FDM.j′vp(central_fdm(5, 1), f, ȳ, x...) + if GROUP == "All" || GROUP == "Enzyme" + # Use Enzyme to compute reverse-mode sensitivities. + x̄s_enzyme_init = map(x) do xi + xi isa Real ? nothing : zero(xi) + end + enzyme_autodiff_args = map(x, x̄s_enzyme_init) do xi, x̄si + return if x̄si === nothing + @assert xi isa Real + Active(xi) + else + @assert typeof(xi) === typeof(x̄si) + Duplicated(xi, x̄si) + end + end + x̄s_enzyme_autodiff, y_enzyme = Enzyme.autodiff(ReverseWithPrimal, f, Active, enzyme_autodiff_args...) + x̄s_enzyme = map(x̄s_enzyme_init, x̄s_enzyme_autodiff) do x̄s_init_i, x̄s_autodiff_i + return if x̄s_init_i === nothing + @assert x̄s_autodiff_i isa Real + x̄s_autodiff_i + else + @assert x̄s_autodiff_i === nothing + x̄s_init_i + end + end + + # Check that Enzyme primal is correct. + @test y ≈ y_enzyme atol=atol rtol=rtol + + # Check that Enzyme reverse-mode sensitivities are correct. + @test all(zip(x̄s_enzyme, x̄s_fdm)) do (x̄_enzyme, x̄_fdm) + return isapprox(x̄_enzyme, x̄_fdm; atol=atol, rtol=rtol) + end + end + if GROUP == "All" || GROUP == "Zygote" # Use Zygote to compute reverse-mode sensitivities. y_zygote, back_zygote = Zygote.pullback(f, x...) @@ -350,6 +387,19 @@ end function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) finitediff = FDM.grad(central_fdm(5, 1), f, x)[1] + if GROUP == "All" || GROUP == "Enzyme" + if (:Enzyme in broken) || (:EnzymeForward in broken) + @test_broken collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol=rtol atol=atol + else + @test collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol=rtol atol=atol + end + if (:Enzyme in broken) || (:EnzymeReverse in broken) + @test_broken Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol=rtol atol=atol + else + @test Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol=rtol atol=atol + end + end + if GROUP == "All" || GROUP == "Tracker" if :Tracker in broken @test_broken Tracker.data(Tracker.gradient(f, x)[1]) ≈ finitediff rtol=rtol atol=atol @@ -398,7 +448,7 @@ end function testset_zygote_broken(distspec, args...; kwargs...) # don't show test errors - tests are known to be broken :) - testset = suppress_stdout() do + testset = redirect_stdout(devnull) do testset_zygote(distspec, args...; kwargs...) end @@ -417,17 +467,6 @@ function testset_zygote_broken(distspec, args...; kwargs...) return testset end -# `redirect_stdout(f, devnull)` is only available in Julia >= 1.6 -function suppress_stdout(f) - @static if VERSION < v"1.6" - open((@static Sys.iswindows() ? "NUL" : "/dev/null"), "w") do devnull - redirect_stdout(f, devnull) - end - else - redirect_stdout(f, devnull) - end -end - # change test errors and failures to broken results function errors_to_broken!(ts::Test.DefaultTestSet) results = ts.results diff --git a/test/runtests.jl b/test/runtests.jl index baee8ad..4f12390 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,7 +26,7 @@ if GROUP == "All" || GROUP == "Others" include("others.jl") end -if GROUP == "All" || GROUP in ("ForwardDiff", "Zygote", "ReverseDiff", "Tracker") +if GROUP == "All" || GROUP in ("Enzyme", "ForwardDiff", "Zygote", "ReverseDiff", "Tracker") include("ad/utils.jl") include("ad/others.jl") include("ad/distributions.jl") From b8b69d85516654d89cfe549cdfe8e06337930987 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 3 Aug 2023 20:49:26 +0200 Subject: [PATCH 02/18] Test Enzyme#main --- test/runtests.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 4f12390..f115438 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,6 @@ +using Pkg +Pkg.add(Pkg.PackageSpec(; url="https://github.com/EnzymeAD/Enzyme.jl.git", rev="main")) + using DistributionsAD using Combinatorics From 784930283a971d5e30a160745cffdc1f0a6e8cf3 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 15 Aug 2023 23:42:04 +0200 Subject: [PATCH 03/18] Mark some tests as broken --- test/ad/others.jl | 14 +++++----- test/ad/utils.jl | 71 +++++++++++++++++++++++++++++------------------ 2 files changed, 51 insertions(+), 34 deletions(-) diff --git a/test/ad/others.jl b/test/ad/others.jl index 1c70558..75eae1d 100644 --- a/test/ad/others.jl +++ b/test/ad/others.jl @@ -7,14 +7,14 @@ @testset "logsumexp" begin x = rand(3) - test_reverse_mode_ad(logsumexp, randn(), x; rtol=1e-8, atol=1e-6) + test_reverse_mode_ad(logsumexp, randn(), x; rtol=1e-8, atol=1e-6, broken = (:Enzyme,)) end @testset "zygote_ldiv" begin A = to_posdef(rand(3, 3)) B = to_posdef(rand(3, 3)) - test_reverse_mode_ad(randn(3, 3), A, B) do A, B + test_reverse_mode_ad(randn(3, 3), A, B; broken = (:Enzyme,)) do A, B return DistributionsAD.zygote_ldiv(A, B) end end @@ -23,10 +23,10 @@ N = 7 B = randn(N, N) - test_reverse_mode_ad(randn(), B; rtol=1e-8, atol=1e-6) do B + test_reverse_mode_ad(randn(), B; rtol=1e-8, atol=1e-6, broken = (:Enzyme,)) do B return logdet(cholesky(to_posdef(B))) end - test_reverse_mode_ad(randn(), B; rtol=1e-8, atol=1e-6) do B + test_reverse_mode_ad(randn(), B; rtol=1e-8, atol=1e-6, broken = (:Enzyme,)) do B return logdet(cholesky(Symmetric(to_posdef(B)))) end end @@ -47,13 +47,13 @@ B = randn(N, N) x = rand(TuringDenseMvNormal(m, to_posdef(B))) - test_reverse_mode_ad(randn(), m, B, x) do m, B, x + test_reverse_mode_ad(randn(), m, B, x; broken = (:Enzyme,)) do m, B, x return logpdf(MvNormal(m, to_posdef(B)), x) end - test_reverse_mode_ad(randn(), m, B, x) do m, B, x + test_reverse_mode_ad(randn(), m, B, x; broken = (:Enzyme,)) do m, B, x return logpdf(TuringMvNormal(m, to_posdef(B)), x) end - test_reverse_mode_ad(randn(), m, B, x) do m, B, x + test_reverse_mode_ad(randn(), m, B, x; broken = (:Enzyme,)) do m, B, x return logpdf(TuringMvNormal(m, Symmetric(to_posdef(B))), x) end end diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 154ec1b..8356fe8 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -21,7 +21,7 @@ if GROUP == "All" || GROUP == "Tracker" @eval using Tracker end -function test_reverse_mode_ad(f, ȳ, x...; rtol=1e-6, atol=1e-6) +function test_reverse_mode_ad(f, ȳ, x...; rtol=1e-6, atol=1e-6, broken=()) # Perform a regular forwards-pass. y = f(x...) @@ -29,36 +29,53 @@ function test_reverse_mode_ad(f, ȳ, x...; rtol=1e-6, atol=1e-6) x̄s_fdm = FDM.j′vp(central_fdm(5, 1), f, ȳ, x...) if GROUP == "All" || GROUP == "Enzyme" - # Use Enzyme to compute reverse-mode sensitivities. - x̄s_enzyme_init = map(x) do xi - xi isa Real ? nothing : zero(xi) - end - enzyme_autodiff_args = map(x, x̄s_enzyme_init) do xi, x̄si - return if x̄si === nothing - @assert xi isa Real - Active(xi) - else - @assert typeof(xi) === typeof(x̄si) - Duplicated(xi, x̄si) - end - end - x̄s_enzyme_autodiff, y_enzyme = Enzyme.autodiff(ReverseWithPrimal, f, Active, enzyme_autodiff_args...) - x̄s_enzyme = map(x̄s_enzyme_init, x̄s_enzyme_autodiff) do x̄s_init_i, x̄s_autodiff_i - return if x̄s_init_i === nothing - @assert x̄s_autodiff_i isa Real - x̄s_autodiff_i - else - @assert x̄s_autodiff_i === nothing - x̄s_init_i + enzyme_broken = :Enzyme in broken + io = enzyme_broken ? devnull : stdout + testset = redirect_stdout(io) do + # Use Enzyme to compute reverse-mode sensitivities. + @testset "Enzyme: Reverse-mode AD of $f" begin + x̄s_enzyme_init = map(x) do xi + xi isa Real ? nothing : zero(xi) + end + enzyme_autodiff_args = map(x, x̄s_enzyme_init) do xi, x̄si + return if x̄si === nothing + @assert xi isa Real + Active(xi) + else + @assert typeof(xi) === typeof(x̄si) + Duplicated(xi, x̄si) + end + end + x̄s_enzyme_autodiff, y_dot_ȳ_enzyme = + Enzyme.autodiff(ReverseWithPrimal, Active, enzyme_autodiff_args...) do args... + return dot(f(args...), ȳ) + end + x̄s_enzyme = map(x̄s_enzyme_init, x̄s_enzyme_autodiff) do x̄s_init_i, x̄s_autodiff_i + return if x̄s_init_i === nothing + @assert x̄s_autodiff_i isa Real + x̄s_autodiff_i + else + @assert x̄s_autodiff_i === nothing + x̄s_init_i + end + end + + # Check that Enzyme primal is correct. + @test dot(y, ȳ) ≈ y_dot_ȳ_enzyme atol=atol rtol=rtol + + # Check that Enzyme reverse-mode sensitivities are correct. + @test all(zip(x̄s_enzyme, x̄s_fdm)) do (x̄_enzyme, x̄_fdm) + return isapprox(x̄_enzyme, x̄_fdm; atol=atol, rtol=rtol) + end end end - # Check that Enzyme primal is correct. - @test y ≈ y_enzyme atol=atol rtol=rtol + # change errors and fails to broken results, and count number of errors and fails + efs = errors_to_broken!(testset) - # Check that Enzyme reverse-mode sensitivities are correct. - @test all(zip(x̄s_enzyme, x̄s_fdm)) do (x̄_enzyme, x̄_fdm) - return isapprox(x̄_enzyme, x̄_fdm; atol=atol, rtol=rtol) + # ensure that passing tests are not marked as broken + if iszero(efs) && enzyme_broken + error("Enzyme tests of $f passed unexpectedly, please mark not as broken") end end From 86657fd2577f6baabb4f787b0340b39ffb1da40a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 15 Aug 2023 23:59:45 +0200 Subject: [PATCH 04/18] Fix order of arguments --- test/ad/utils.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 8356fe8..10ad346 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -46,10 +46,9 @@ function test_reverse_mode_ad(f, ȳ, x...; rtol=1e-6, atol=1e-6, broken=()) Duplicated(xi, x̄si) end end + dot_f_ȳ(args...) = dot(f(args...), ȳ) x̄s_enzyme_autodiff, y_dot_ȳ_enzyme = - Enzyme.autodiff(ReverseWithPrimal, Active, enzyme_autodiff_args...) do args... - return dot(f(args...), ȳ) - end + Enzyme.autodiff(ReverseWithPrimal, dot_f_ȳ, Active, enzyme_autodiff_args...) x̄s_enzyme = map(x̄s_enzyme_init, x̄s_enzyme_autodiff) do x̄s_init_i, x̄s_autodiff_i return if x̄s_init_i === nothing @assert x̄s_autodiff_i isa Real From 5fb9422d73e00ec07a1c08755b56bfe2c07c9b04 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 16 Aug 2023 00:19:48 +0200 Subject: [PATCH 05/18] logsumexp tests pass? --- test/ad/others.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ad/others.jl b/test/ad/others.jl index 75eae1d..0c4e1f2 100644 --- a/test/ad/others.jl +++ b/test/ad/others.jl @@ -7,7 +7,7 @@ @testset "logsumexp" begin x = rand(3) - test_reverse_mode_ad(logsumexp, randn(), x; rtol=1e-8, atol=1e-6, broken = (:Enzyme,)) + test_reverse_mode_ad(logsumexp, randn(), x; rtol=1e-8, atol=1e-6) end @testset "zygote_ldiv" begin From 6da01ea430e0216d46911d98d9ce6ed9c4da947e Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 21 Aug 2023 23:39:42 +0200 Subject: [PATCH 06/18] Improve logs and use workaround for runtime activity --- test/ad/utils.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 10ad346..cb5d0ba 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -6,7 +6,13 @@ const FDM = FiniteDifferences # Load AD backends if GROUP == "All" || GROUP == "Enzyme" - @eval using Enzyme + @eval begin + using Enzyme + # Disable Enzyme warnings + Enzyme.API.typeWarning!(false) + # Enable runtime activity (workaround) + Enzyme.API.runtimeActivity!(true) + end end if GROUP == "All" || GROUP == "ForwardDiff" @eval using ForwardDiff From 0c87cb03b0603dbac7d54c089618b5b3a77af424 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Sat, 25 May 2024 21:49:19 +0100 Subject: [PATCH 07/18] Update test/Project.toml --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 87720f3..3f4bf7e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -22,7 +22,7 @@ ChainRulesCore = "1" ChainRulesTestUtils = "1.9.2" Combinatorics = "1.0.2" Distributions = "0.25.15" -Enzyme = "0.11.5" +Enzyme = "0.12" FiniteDifferences = "0.11.3, 0.12" ForwardDiff = "0.10.12" LazyArrays = "1" From 41326af19de2120b178f89801b91e40a5a9d2f54 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Tue, 9 Jul 2024 13:19:35 +0100 Subject: [PATCH 08/18] Update distributions.jl --- test/ad/distributions.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index 8baa50c..9a7f5d2 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -159,10 +159,10 @@ DistSpec(LogNormal, (1.0, 2.0), 0.5), # Dispatch error caused by ccall - DistSpec(NoncentralBeta, (1.0, 2.0, 1.0), 0.5, broken=(:Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), - DistSpec(NoncentralChisq, (1.0, 2.0), 0.5, broken=(:Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), - DistSpec(NoncentralF, (1.0, 2.0, 1.0), 0.5, broken=(:Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), - DistSpec(NoncentralT, (1.0, 2.0), 0.5, broken=(:Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), + DistSpec(NoncentralBeta, (1.0, 2.0, 1.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), + DistSpec(NoncentralChisq, (1.0, 2.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), + DistSpec(NoncentralF, (1.0, 2.0, 1.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), + DistSpec(NoncentralT, (1.0, 2.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), DistSpec(Normal, (), 0.5), DistSpec(Normal, (1.0,), 0.5), From 8b9fa77cc121f23295a01b62926945437fda7bee Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Tue, 9 Jul 2024 17:38:42 +0100 Subject: [PATCH 09/18] Apply suggestions from code review --- test/ad/others.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/ad/others.jl b/test/ad/others.jl index 0c4e1f2..94d1492 100644 --- a/test/ad/others.jl +++ b/test/ad/others.jl @@ -23,10 +23,10 @@ N = 7 B = randn(N, N) - test_reverse_mode_ad(randn(), B; rtol=1e-8, atol=1e-6, broken = (:Enzyme,)) do B + test_reverse_mode_ad(randn(), B; rtol=1e-8, atol=1e-6) do B return logdet(cholesky(to_posdef(B))) end - test_reverse_mode_ad(randn(), B; rtol=1e-8, atol=1e-6, broken = (:Enzyme,)) do B + test_reverse_mode_ad(randn(), B; rtol=1e-8, atol=1e-6) do B return logdet(cholesky(Symmetric(to_posdef(B)))) end end @@ -47,13 +47,13 @@ B = randn(N, N) x = rand(TuringDenseMvNormal(m, to_posdef(B))) - test_reverse_mode_ad(randn(), m, B, x; broken = (:Enzyme,)) do m, B, x + test_reverse_mode_ad(randn(), m, B, x) do m, B, x return logpdf(MvNormal(m, to_posdef(B)), x) end - test_reverse_mode_ad(randn(), m, B, x; broken = (:Enzyme,)) do m, B, x + test_reverse_mode_ad(randn(), m, B, x) do m, B, x return logpdf(TuringMvNormal(m, to_posdef(B)), x) end - test_reverse_mode_ad(randn(), m, B, x; broken = (:Enzyme,)) do m, B, x + test_reverse_mode_ad(randn(), m, B, x) do m, B, x return logpdf(TuringMvNormal(m, Symmetric(to_posdef(B))), x) end end From c95342e878f22170aba071cc52fd41732e62dff7 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Tue, 9 Jul 2024 17:39:49 +0100 Subject: [PATCH 10/18] Update test/ad/utils.jl --- test/ad/utils.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index cb5d0ba..d3bd5c6 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -8,10 +8,6 @@ const FDM = FiniteDifferences if GROUP == "All" || GROUP == "Enzyme" @eval begin using Enzyme - # Disable Enzyme warnings - Enzyme.API.typeWarning!(false) - # Enable runtime activity (workaround) - Enzyme.API.runtimeActivity!(true) end end if GROUP == "All" || GROUP == "ForwardDiff" From 8647c817e50fca2405e68771682a1c8a67ea6073 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Tue, 9 Jul 2024 22:12:51 +0100 Subject: [PATCH 11/18] Update runtests.jl --- test/runtests.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index f115438..9e4c6e3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,3 @@ -using Pkg -Pkg.add(Pkg.PackageSpec(; url="https://github.com/EnzymeAD/Enzyme.jl.git", rev="main")) - using DistributionsAD using Combinatorics @@ -15,6 +12,10 @@ using DistributionsAD: TuringMvNormal, TuringMvLogNormal, TuringPoissonBinomial, TuringDirichlet using StatsFuns: StatsFuns, logsumexp, logistic +import Enzyme +# Enable runtime activity (workaround) +Enzyme.API.runtimeActivity!(true) + @static if VERSION >= v"1.8" using Pkg; Pkg.status(outdated=true) # show reasons why packages are held back end From 043a4d0721c1c004424eb46609b95ce1df8b697c Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Tue, 9 Jul 2024 23:04:53 +0100 Subject: [PATCH 12/18] Update runtests.jl --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 9e4c6e3..118fb20 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,7 @@ using DistributionsAD: TuringMvNormal, TuringMvLogNormal, using StatsFuns: StatsFuns, logsumexp, logistic import Enzyme +Enzyme.API.typeWarning!(false) # Enable runtime activity (workaround) Enzyme.API.runtimeActivity!(true) From 670c5752b10ff34cce788ce1f8a4b091b2c76f88 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Wed, 21 Aug 2024 18:18:36 +0100 Subject: [PATCH 13/18] Update test/ad/utils.jl --- test/ad/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index d3bd5c6..9c106f3 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -50,7 +50,7 @@ function test_reverse_mode_ad(f, ȳ, x...; rtol=1e-6, atol=1e-6, broken=()) end dot_f_ȳ(args...) = dot(f(args...), ȳ) x̄s_enzyme_autodiff, y_dot_ȳ_enzyme = - Enzyme.autodiff(ReverseWithPrimal, dot_f_ȳ, Active, enzyme_autodiff_args...) + Enzyme.autodiff(ReverseWithPrimal, Const(dot_f_ȳ), Active, enzyme_autodiff_args...) x̄s_enzyme = map(x̄s_enzyme_init, x̄s_enzyme_autodiff) do x̄s_init_i, x̄s_autodiff_i return if x̄s_init_i === nothing @assert x̄s_autodiff_i isa Real From 6a1b64132cc821ad9d5aed596011a66b16bae46c Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Wed, 21 Aug 2024 18:58:33 +0100 Subject: [PATCH 14/18] Update utils.jl --- test/ad/utils.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 9c106f3..275b19c 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -407,14 +407,14 @@ function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) if GROUP == "All" || GROUP == "Enzyme" if (:Enzyme in broken) || (:EnzymeForward in broken) - @test_broken collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol=rtol atol=atol + @test_broken collect(Enzyme.gradient(Enzyme.Forward, Const(f), x)) ≈ finitediff rtol=rtol atol=atol else - @test collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol=rtol atol=atol + @test collect(Enzyme.gradient(Enzyme.Forward, Const(f), x)) ≈ finitediff rtol=rtol atol=atol end if (:Enzyme in broken) || (:EnzymeReverse in broken) - @test_broken Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol=rtol atol=atol + @test_broken Enzyme.gradient(Enzyme.Reverse, Const(f), x) ≈ finitediff rtol=rtol atol=atol else - @test Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol=rtol atol=atol + @test Enzyme.gradient(Enzyme.Reverse, Const(f), x) ≈ finitediff rtol=rtol atol=atol end end From 5d3f0c574016d092c86314c27c1bf5879812cc72 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Wed, 21 Aug 2024 20:17:44 +0100 Subject: [PATCH 15/18] Update distributions.jl --- test/ad/distributions.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index 9a7f5d2..c639c60 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -159,10 +159,10 @@ DistSpec(LogNormal, (1.0, 2.0), 0.5), # Dispatch error caused by ccall - DistSpec(NoncentralBeta, (1.0, 2.0, 1.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), - DistSpec(NoncentralChisq, (1.0, 2.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), - DistSpec(NoncentralF, (1.0, 2.0, 1.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), - DistSpec(NoncentralT, (1.0, 2.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), + DistSpec(NoncentralBeta, (1.0, 2.0, 1.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff, :Enzyme)), + DistSpec(NoncentralChisq, (1.0, 2.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff, :Enzyme)), + DistSpec(NoncentralF, (1.0, 2.0, 1.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff, :Enzyme)), + DistSpec(NoncentralT, (1.0, 2.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff, :Enzyme)), DistSpec(Normal, (), 0.5), DistSpec(Normal, (1.0,), 0.5), From 48c9afb8726e85876aee3b55a9e3f86e27b7d025 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Wed, 21 Aug 2024 20:40:04 +0100 Subject: [PATCH 16/18] Update distributions.jl --- test/ad/distributions.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index c639c60..66c910b 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -61,8 +61,9 @@ DistSpec(Poisson, (0.5,), 1), DistSpec(Poisson, (0.5,), [1, 1]), - DistSpec(Skellam, (1.0, 2.0), -2), - DistSpec(Skellam, (1.0, 2.0), [-2, -2]), + # Enzyme: no forward rule for ccall + DistSpec(Skellam, (1.0, 2.0), -2, broken=(:Zygote,)), + DistSpec(Skellam, (1.0, 2.0), [-2, -2], broken=(:Zygote,)), DistSpec(PoissonBinomial, ([0.5, 0.5],), 0), From 443c731a4e0c0707215bfc6aa12db8bac56f45fe Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Wed, 21 Aug 2024 20:43:10 +0100 Subject: [PATCH 17/18] Update test/ad/distributions.jl --- test/ad/distributions.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index 66c910b..699bc7a 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -62,8 +62,8 @@ DistSpec(Poisson, (0.5,), [1, 1]), # Enzyme: no forward rule for ccall - DistSpec(Skellam, (1.0, 2.0), -2, broken=(:Zygote,)), - DistSpec(Skellam, (1.0, 2.0), [-2, -2], broken=(:Zygote,)), + DistSpec(Skellam, (1.0, 2.0), -2, broken=(:Enzyme,)), + DistSpec(Skellam, (1.0, 2.0), [-2, -2], broken=(:Enzyme,)), DistSpec(PoissonBinomial, ([0.5, 0.5],), 0), From f42571d7a9cf64f99a618da866d649d1a1bd4d2b Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Thu, 22 Aug 2024 00:20:50 +0100 Subject: [PATCH 18/18] Update test/runtests.jl --- test/runtests.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 118fb20..4075503 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,7 +16,9 @@ import Enzyme Enzyme.API.typeWarning!(false) # Enable runtime activity (workaround) Enzyme.API.runtimeActivity!(true) - +# Supress excessive type deduce failures may result in incorrect gradients. +# https://enzyme.mit.edu/julia/stable/api/#Enzyme.API.looseTypeAnalysis!-Tuple{Any} +Enzyme.API.looseTypeAnalysis!(true) @static if VERSION >= v"1.8" using Pkg; Pkg.status(outdated=true) # show reasons why packages are held back end