Enzyme doesn't work for AdvancedVI #1548

Closed
Red-Portal opened this issue Jun 20, 2024 · 6 comments

@Red-Portal

Red-Portal commented Jun 20, 2024

The following snippet, which is the key element of AdvancedVI, doesn't work:

using Enzyme, Functors, Optimisers, Distributions, LinearAlgebra, Random

struct MvLocationScale{
    S, D <: ContinuousDistribution, L
} <: ContinuousMultivariateDistribution
    location ::L
    scale    ::S
    dist     ::D
end

@functor MvLocationScale (location, scale)

function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{S, D, L}, num_samples::Int
)  where {S, D, L}
    (; location, scale, dist) = q
    n_dims = length(location)
    scale*rand(rng, dist, n_dims, num_samples) .+ location
end

function f(params, aux)
    (; n_samples, rng, restructure) = aux
    q = restructure(params)
    samples = rand(rng, q, n_samples)
    sum(samples)
end

function main()
	d = 10
	m = zeros(d)
	C = LowerTriangular(Matrix{Float64}(I, d, d))
	q = MvLocationScale(m, C, Normal())
	params, re = Optimisers.destructure(q)
	aux = (
		rng         = Random.default_rng(),
		restructure = re,
		n_samples   = 4,
	)
	
	println(f(params, aux))
	
	x = ones(length(params))
	∇x = zeros(length(params))
	Enzyme.API.runtimeActivity!(true)
	Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(x, ∇x), Enzyme.Const(aux))
	∇x
end

This yields an error message (quite a mouthful!) that can be found on this pastebin.
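
For reference, the gradient this reproducer should return can be worked out by hand, because the objective is linear in both the location and the scale. The following is a minimal sketch that does not use Enzyme or Optimisers. The helper `objective` and the fixed noise matrix `z` are hypothetical stand-ins for `f` and the draw from `rand`, introduced only so that repeated evaluations are deterministic.

```julia
using LinearAlgebra, Random

# Hypothetical helper: the same computation as `f`, but with the noise `z`
# passed in explicitly so that every evaluation sees identical samples.
objective(m, C, z) = sum(C * z .+ m)

d, n = 3, 4
z = randn(MersenneTwister(1), d, n)
m = zeros(d)
C = LowerTriangular(Matrix{Float64}(I, d, d))

# Analytic gradients: each location entry is added once per sample, and
# C[i, j] multiplies row j of z in every sample.
grad_m = fill(float(n), d)
grad_C = LowerTriangular(repeat(sum(z; dims=2)', d, 1))

# Central finite differences for comparison.
h = 1e-6
fd_m = map(1:d) do i
    e = zeros(d)
    e[i] = h
    (objective(m + e, C, z) - objective(m - e, C, z)) / (2h)
end

i, j = 3, 1
E = zeros(d, d)
E[i, j] = h
fd_C = (objective(m, LowerTriangular(C + E), z) -
        objective(m, LowerTriangular(C - E), z)) / (2h)

@assert fd_m ≈ grad_m
@assert isapprox(fd_C, grad_C[i, j]; rtol=1e-6, atol=1e-6)
```

How these entries map onto the flat vector produced by `Optimisers.destructure` depends on how it flattens the `LowerTriangular` field, which this sketch does not attempt to reproduce.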

Skipping the rand, which means we are only calling restructure as follows:

function f(params, aux)
    (; n_samples, rng, restructure) = aux
    q = restructure(params)
    sum(q.location)
end

results in a different error that can be found here.

@Red-Portal Red-Portal changed the title from "Reparameterizationg gradient doesn't work?" to "Enzyme doesn't work for AdvancedVI" on Jun 20, 2024
@wsmoses
Member

wsmoses commented Jun 21, 2024

Looks like you hit our just-added (and accidentally released) direct trmm support. EnzymeAD/Enzyme#1933 should fix the issue after a jll bump.

That fixes the first issue (which incidentally doesn't hit the second once fixed).

The second issue should be fixed by #1550
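
For context, trmm is the BLAS routine for multiplying a triangular matrix by a dense matrix, which is the shape of the `scale*rand(...)` product in the reproducer's `rand` method. The sketch below assumes `Float64` dense storage and only checks that the high-level product agrees with a direct `BLAS.trmm` call. It does not involve Enzyme, and whether a given Julia version actually routes the product through trmm is an assumption here.

```julia
using LinearAlgebra, Random

d, n = 3, 4
rng = MersenneTwister(0)
A = rand(rng, d, d)              # dense storage; only the lower triangle is used
scale = LowerTriangular(A)
z = randn(rng, d, n)
location = zeros(d)

# The product as written in the reproducer's `rand` method.
samples = scale * z .+ location

# The same product through the BLAS routine: A on the left ('L'),
# lower triangular ('L'), not transposed ('N'), non-unit diagonal ('N').
via_trmm = BLAS.trmm('L', 'L', 'N', 'N', 1.0, A, z) .+ location

@assert samples ≈ via_trmm
```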

@wsmoses
Member

wsmoses commented Jun 21, 2024

Merged on main, closing. Will make a release once things run.

@wsmoses wsmoses closed this as completed Jun 21, 2024
@Red-Portal
Author

Thanks! I'll give it a go after the new release and re-open if there are issues still looming.

@wsmoses
Member

wsmoses commented Jun 21, 2024

Started the registration; it'll probably be out in ~1 hr or so (but you may need to set the Julia pkg server registry preference to eager if your client has already checked with the package servers today).
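
One way to request the eager registry flavor is the `JULIA_PKG_SERVER_REGISTRY_PREFERENCE` environment variable documented by Pkg. The snippet below is a sketch, not a verified procedure: it assumes a Julia version whose Pkg supports this variable, and it assumes that setting it inside a running session takes effect before the next registry update. Exporting it in the shell before starting Julia is the more conservative option.

```julia
# Ask the package server for the most recent registry state instead of
# the default "conservative" flavor.
ENV["JULIA_PKG_SERVER_REGISTRY_PREFERENCE"] = "eager"

using Pkg
Pkg.Registry.update()   # refresh the registry under the new preference
Pkg.update("Enzyme")    # then pull the newly registered release
```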

@Red-Portal
Author

Red-Portal commented Jun 22, 2024

Got a new case, so this issue will have to be re-opened:

using Enzyme, Functors, Optimisers, Distributions, LinearAlgebra, Random, StatsBase

struct MvLocationScale{
    S, D <: ContinuousDistribution, L
} <: ContinuousMultivariateDistribution
    location ::L
    scale    ::S
    dist     ::D
end

@functor MvLocationScale (location, scale)

function StatsBase.entropy(q::MvLocationScale)
    (; location, scale, dist) = q
    n_dims = length(location)
    # `convert` is necessary because `entropy` is not type stable upstream
    n_dims*convert(eltype(location), entropy(dist)) + logdet(scale)
end

function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{S, D, L}, num_samples::Int
)  where {S, D, L}
    (; location, scale, dist) = q
    n_dims = length(location)
    scale*rand(rng, dist, n_dims, num_samples) .+ location
end

function reparam_with_entropy(
    rng      ::Random.AbstractRNG,
    q,
    n_samples::Int,
)
    samples = rand(rng, q, n_samples)
    samples, entropy(q)
end

function f(params, aux)
    (; n_samples, rng, restructure) = aux
    q = restructure(params)
    samples, ent = reparam_with_entropy(rng, q, 10)
    sum(samples) + ent
end

function main()
	d = 10
	m = zeros(d)
	C = LowerTriangular(Matrix{Float64}(I, d, d))
	q = MvLocationScale(m, C, Normal())
	params, re = Optimisers.destructure(q)
	aux = (
		rng         = Random.default_rng(),
		restructure = re,
		n_samples   = 4,
	)
	
	println(f(params, aux))
	
	x = ones(length(params))
	∇x = zeros(length(params))
	Enzyme.API.runtimeActivity!(true)
	Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(x, ∇x), Enzyme.Const(aux))
	∇x
end

This spits out:

ERROR: Enzyme execution failed.
Enzyme could not find shadow for value

 Inverted pointers: 
available inversion for { {} addrspace(10)*, [1 x {} addrspace(10)*], [2 x double] } addrspace(11)* %0 of { {} addrspace(10)*, [1 x {} addrspace(10)*], [2 x double] } addrspace(11)* %"'"
available inversion for   %6 = getelementptr inbounds { {} addrspace(10)*, [1 x {} addrspace(10)*], [2 x double] }, { {} addrspace(10)*, [1 x {} addrspace(10)*], [2 x double] } addrspace(11)* %0, i64 0, i32 1, !dbg !55 of   %"'ipg" = getelementptr inbounds { {} addrspace(10)*, [1 x {} addrspace(10)*], [2 x double] }, { {} addrspace(10)*, [1 x {} addrspace(10)*], [2 x double] } addrspace(11)* %"'", i64 0, i32 1, !dbg !68

cannot find shadow for   %5 = call fastcc nonnull {} addrspace(10)* @julia_rand_8346({ {} addrspace(10)*, [1 x {} addrspace(10)*], [2 x double] } addrspace(11)* noalias nocapture nofree noundef nonnull readonly align 8 dereferenceable(32) %0, i64 signext %1) #51, !dbg !54, !noalias !42

Caused by:
Stacktrace:
 [1] reparam_with_entropy
   @ ./REPL[65]:6
 [2] reparam_with_entropy
   @ ./REPL[65]:0


Stacktrace:
  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/RsJic/src/compiler.jl:1612
  [2] reparam_with_entropy
    @ ./REPL[65]:0 [inlined]
  [3] augmented_julia_reparam_with_entropy_8321_inner_1wrap
    @ ./REPL[65]:0
  [4] macro expansion
    @ ~/.julia/packages/Enzyme/RsJic/src/compiler.jl:6587 [inlined]
  [5] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::MixedDuplicated{…}, ::Const{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/RsJic/src/compiler.jl:6188
  [6] (::Enzyme.Compiler.AugmentedForwardThunk{…})(::Const{…}, ::Const{…}, ::Vararg{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/RsJic/src/compiler.jl:6076
  [7] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(reparam_with_entropy), df::Nothing, primal_1::TaskLocalRNG, shadow_1_1::Nothing, primal_2::MvLocationScale{…}, shadow_2_1::Base.RefValue{…}, primal_3::Int64, shadow_3_1::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/RsJic/src/rules/jitrules.jl:311
  [8] f
    @ ./REPL[66]:4 [inlined]
  [9] f
    @ ./REPL[66]:0 [inlined]
 [10] augmented_julia_f_8166_inner_1wrap
    @ ./REPL[66]:0
 [11] macro expansion
    @ ~/.julia/packages/Enzyme/RsJic/src/compiler.jl:6587 [inlined]
 [12] enzyme_call
    @ ~/.julia/packages/Enzyme/RsJic/src/compiler.jl:6188 [inlined]
 [13] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/RsJic/src/compiler.jl:6076 [inlined]
 [14] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Const{…})
    @ Enzyme ~/.julia/packages/Enzyme/RsJic/src/Enzyme.jl:253
 [15] autodiff
    @ ~/.julia/packages/Enzyme/RsJic/src/Enzyme.jl:321 [inlined]
 [16] main()
    @ Main ./REPL[67]:18
 [17] top-level scope
    @ REPL[68]:1
Some type information was truncated. Use `show(err)` to see complete types.

@wsmoses
Member

wsmoses commented Jun 22, 2024

Can you open that as a separate issue so we can track it?

Will work on that next
