Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Enzyme tests #254

Closed
wants to merge 18 commits into from
Closed
22 changes: 9 additions & 13 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
- x64
group:
- Others
- Enzyme
- ForwardDiff
- Tracker
- ReverseDiff
Expand All @@ -36,6 +37,10 @@ jobs:
os: macOS-latest
arch: x64
group: Others
- version: '1'
os: macOS-latest
arch: x64
group: Enzyme
- version: '1'
os: macOS-latest
arch: x64
Expand All @@ -53,23 +58,14 @@ jobs:
arch: x64
group: Zygote
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: actions/cache@v1
env:
cache-name: cache-artifacts
with:
path: ~/.julia/artifacts
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
restore-keys: |
${{ runner.os }}-test-${{ env.cache-name }}-
${{ runner.os }}-test-
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@latest
- uses: julia-actions/julia-runtest@latest
- uses: julia-actions/cache@v1
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
with:
coverage: false
env:
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
Expand All @@ -21,6 +22,7 @@ ChainRulesCore = "1"
ChainRulesTestUtils = "1.9.2"
Combinatorics = "1.0.2"
Distributions = "0.25.15"
Enzyme = "0.12"
FiniteDifferences = "0.11.3, 0.12"
ForwardDiff = "0.10.12"
LazyArrays = "1"
Expand Down
8 changes: 4 additions & 4 deletions test/ad/distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,10 @@
DistSpec(LogNormal, (1.0, 2.0), 0.5),

# Dispatch error caused by ccall
DistSpec(NoncentralBeta, (1.0, 2.0, 1.0), 0.5, broken=(:Tracker, :ForwardDiff, :Zygote, :ReverseDiff)),
DistSpec(NoncentralChisq, (1.0, 2.0), 0.5, broken=(:Tracker, :ForwardDiff, :Zygote, :ReverseDiff)),
DistSpec(NoncentralF, (1.0, 2.0, 1.0), 0.5, broken=(:Tracker, :ForwardDiff, :Zygote, :ReverseDiff)),
DistSpec(NoncentralT, (1.0, 2.0), 0.5, broken=(:Tracker, :ForwardDiff, :Zygote, :ReverseDiff)),
DistSpec(NoncentralBeta, (1.0, 2.0, 1.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff)),
DistSpec(NoncentralChisq, (1.0, 2.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff)),
DistSpec(NoncentralF, (1.0, 2.0, 1.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff)),
DistSpec(NoncentralT, (1.0, 2.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff)),

DistSpec(Normal, (), 0.5),
DistSpec(Normal, (1.0,), 0.5),
Expand Down
2 changes: 1 addition & 1 deletion test/ad/others.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
A = to_posdef(rand(3, 3))
B = to_posdef(rand(3, 3))

test_reverse_mode_ad(randn(3, 3), A, B) do A, B
test_reverse_mode_ad(randn(3, 3), A, B; broken = (:Enzyme,)) do A, B
return DistributionsAD.zygote_ldiv(A, B)
end
end
Expand Down
83 changes: 70 additions & 13 deletions test/ad/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ using FiniteDifferences
const FDM = FiniteDifferences

# Load AD backends
if GROUP == "All" || GROUP == "Enzyme"
@eval begin
using Enzyme
end
end
if GROUP == "All" || GROUP == "ForwardDiff"
@eval using ForwardDiff
end
Expand All @@ -18,13 +23,63 @@ if GROUP == "All" || GROUP == "Tracker"
@eval using Tracker
end

function test_reverse_mode_ad(f, ȳ, x...; rtol=1e-6, atol=1e-6)
function test_reverse_mode_ad(f, ȳ, x...; rtol=1e-6, atol=1e-6, broken=())
# Perform a regular forwards-pass.
y = f(x...)

# Use finite differencing to compute reverse-mode sensitivities.
x̄s_fdm = FDM.j′vp(central_fdm(5, 1), f, ȳ, x...)

if GROUP == "All" || GROUP == "Enzyme"
enzyme_broken = :Enzyme in broken
io = enzyme_broken ? devnull : stdout
testset = redirect_stdout(io) do
# Use Enzyme to compute reverse-mode sensitivities.
@testset "Enzyme: Reverse-mode AD of $f" begin
x̄s_enzyme_init = map(x) do xi
xi isa Real ? nothing : zero(xi)
end
enzyme_autodiff_args = map(x, x̄s_enzyme_init) do xi, x̄si
return if x̄si === nothing
@assert xi isa Real
Active(xi)
else
@assert typeof(xi) === typeof(x̄si)
Duplicated(xi, x̄si)
end
end
dot_f_ȳ(args...) = dot(f(args...), ȳ)
x̄s_enzyme_autodiff, y_dot_ȳ_enzyme =
Enzyme.autodiff(ReverseWithPrimal, dot_f_ȳ, Active, enzyme_autodiff_args...)
yebai marked this conversation as resolved.
Show resolved Hide resolved
x̄s_enzyme = map(x̄s_enzyme_init, x̄s_enzyme_autodiff) do x̄s_init_i, x̄s_autodiff_i
return if x̄s_init_i === nothing
@assert x̄s_autodiff_i isa Real
x̄s_autodiff_i
else
@assert x̄s_autodiff_i === nothing
x̄s_init_i
end
end

# Check that Enzyme primal is correct.
@test dot(y, ȳ) ≈ y_dot_ȳ_enzyme atol=atol rtol=rtol

# Check that Enzyme reverse-mode sensitivities are correct.
@test all(zip(x̄s_enzyme, x̄s_fdm)) do (x̄_enzyme, x̄_fdm)
return isapprox(x̄_enzyme, x̄_fdm; atol=atol, rtol=rtol)
end
end
end

# change errors and fails to broken results, and count number of errors and fails
efs = errors_to_broken!(testset)

# ensure that passing tests are not marked as broken
if iszero(efs) && enzyme_broken
error("Enzyme tests of $f passed unexpectedly, please mark not as broken")
end
end

if GROUP == "All" || GROUP == "Zygote"
# Use Zygote to compute reverse-mode sensitivities.
y_zygote, back_zygote = Zygote.pullback(f, x...)
Expand Down Expand Up @@ -350,6 +405,19 @@ end
function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6)
finitediff = FDM.grad(central_fdm(5, 1), f, x)[1]

if GROUP == "All" || GROUP == "Enzyme"
if (:Enzyme in broken) || (:EnzymeForward in broken)
@test_broken collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol=rtol atol=atol
else
@test collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol=rtol atol=atol
end
if (:Enzyme in broken) || (:EnzymeReverse in broken)
@test_broken Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol=rtol atol=atol
else
@test Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol=rtol atol=atol
end
end

if GROUP == "All" || GROUP == "Tracker"
if :Tracker in broken
@test_broken Tracker.data(Tracker.gradient(f, x)[1]) ≈ finitediff rtol=rtol atol=atol
Expand Down Expand Up @@ -398,7 +466,7 @@ end

function testset_zygote_broken(distspec, args...; kwargs...)
# don't show test errors - tests are known to be broken :)
testset = suppress_stdout() do
testset = redirect_stdout(devnull) do
testset_zygote(distspec, args...; kwargs...)
end

Expand All @@ -417,17 +485,6 @@ function testset_zygote_broken(distspec, args...; kwargs...)
return testset
end

# `redirect_stdout(f, devnull)` is only available in Julia >= 1.6
function suppress_stdout(f)
@static if VERSION < v"1.6"
open((@static Sys.iswindows() ? "NUL" : "/dev/null"), "w") do devnull
redirect_stdout(f, devnull)
end
else
redirect_stdout(f, devnull)
end
end

# change test errors and failures to broken results
function errors_to_broken!(ts::Test.DefaultTestSet)
results = ts.results
Expand Down
7 changes: 6 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ using DistributionsAD: TuringMvNormal, TuringMvLogNormal,
TuringPoissonBinomial, TuringDirichlet
using StatsFuns: StatsFuns, logsumexp, logistic

import Enzyme
Enzyme.API.typeWarning!(false)
# Enable runtime activity (workaround)
Enzyme.API.runtimeActivity!(true)

yebai marked this conversation as resolved.
Show resolved Hide resolved
@static if VERSION >= v"1.8"
using Pkg; Pkg.status(outdated=true) # show reasons why packages are held back
end
Expand All @@ -26,7 +31,7 @@ if GROUP == "All" || GROUP == "Others"
include("others.jl")
end

if GROUP == "All" || GROUP in ("ForwardDiff", "Zygote", "ReverseDiff", "Tracker")
if GROUP == "All" || GROUP in ("Enzyme", "ForwardDiff", "Zygote", "ReverseDiff", "Tracker")
include("ad/utils.jl")
include("ad/others.jl")
include("ad/distributions.jl")
Expand Down
Loading