Skip to content

Commit

Permalink
Use new style kwarg constructor for AutoReverseDiff (#2273)
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru authored Jun 25, 2024
1 parent 7b2869f commit a0db647
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 12 deletions.
2 changes: 1 addition & 1 deletion benchmarks/benchmarks_suite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,5 @@ BenchmarkSuite["mnormal"]["forwarddiff"] = @benchmarkable sample(

# ReverseDiff
BenchmarkSuite["mnormal"]["reversediff"] = @benchmarkable sample(
$(mdemo(d, 1)), $(HMC(0.1, 5; adtype=AutoReverseDiff(false))), 5000
$(mdemo(d, 1)), $(HMC(0.1, 5; adtype=AutoReverseDiff(; compile=false))), 5000
)
12 changes: 7 additions & 5 deletions test/essential/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,22 +154,22 @@ end
return theta ~ Dirichlet(1 ./ fill(4, 4))
end
sample(dir(), HMC(0.01, 1; adtype=AutoZygote()), 1000)
sample(dir(), HMC(0.01, 1; adtype=AutoReverseDiff(false)), 1000)
sample(dir(), HMC(0.01, 1; adtype=AutoReverseDiff(true)), 1000)
sample(dir(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=false)), 1000)
sample(dir(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=true)), 1000)
end
@testset "PDMatDistribution AD" begin
@model function wishart()
return theta ~ Wishart(4, Matrix{Float64}(I, 4, 4))
end

sample(wishart(), HMC(0.01, 1; adtype=AutoReverseDiff(false)), 1000)
sample(wishart(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=false)), 1000)
sample(wishart(), HMC(0.01, 1; adtype=AutoZygote()), 1000)

@model function invwishart()
return theta ~ InverseWishart(4, Matrix{Float64}(I, 4, 4))
end

sample(invwishart(), HMC(0.01, 1; adtype=AutoReverseDiff(false)), 1000)
sample(invwishart(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=false)), 1000)
sample(invwishart(), HMC(0.01, 1; adtype=AutoZygote()), 1000)
end
@testset "Hessian test" begin
Expand Down Expand Up @@ -231,7 +231,9 @@ end
for i in 1:5
d = Normal(0.0, i)
data = rand(d, N)
chn = sample(demo(data), NUTS(0.65; adtype=AutoReverseDiff(true)), 1000)
chn = sample(
demo(data), NUTS(0.65; adtype=AutoReverseDiff(; compile=true)), 1000
)
@test mean(Array(chn[:sigma])) std(data) atol = 0.5
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import ReverseDiff
using Test: @test, @test_throws, @testset
using Turing

@testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false))
@testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
# Only test threading if 1.3+.
if VERSION > v"1.2"
@testset "threaded sampling" begin
Expand Down
2 changes: 1 addition & 1 deletion test/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using Turing: Inference
using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess

@testset "Testing gibbs.jl with $adbackend" for adbackend in (
AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)
AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)
)
@testset "gibbs constructor" begin
N = 500
Expand Down
2 changes: 1 addition & 1 deletion test/mcmc/gibbs_conditional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ using Test: @test, @testset
using Turing

@testset "Testing gibbs conditionals.jl with $adbackend" for adbackend in (
AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)
AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)
)
Random.seed!(1000)
rng = StableRNG(123)
Expand Down
2 changes: 1 addition & 1 deletion test/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using StatsFuns: logistic
using Test: @test, @test_logs, @testset
using Turing

@testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false))
@testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
# Set a seed
rng = StableRNG(123)
@testset "constrained bounded" begin
Expand Down
4 changes: 2 additions & 2 deletions test/mcmc/sghmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using StableRNGs: StableRNG
using Test: @test, @testset
using Turing

@testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false))
@testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
@testset "sghmc constructor" begin
alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend)
@test alg isa SGHMC
Expand All @@ -36,7 +36,7 @@ using Turing
end
end

@testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false))
@testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
@testset "sgld constructor" begin
alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adbackend)
@test alg isa SGLD
Expand Down

0 comments on commit a0db647

Please sign in to comment.