Skip to content

Commit

Permalink
[Nodecayed phis] handle repeated phi entry (#1127)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Nov 4, 2023
1 parent 9a8d2cb commit e4774df
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9294,7 +9294,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};

if params.run_enzyme
# Generate the adjoint
jlrules = String[]
jlrules = String["enzyme_custom"]
for (fname, (ftyp, mi)) in foundTys
haskey(functions(mod), fname) || continue
push!(jlrules, fname)
Expand Down
58 changes: 58 additions & 0 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,61 @@ function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache,

return (nothing,nothing)
end

# Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float)
function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple}
primal = if EnzymeRules.needs_primal(config)
out.val
else
nothing
end
shadow = if EnzymeRules.needs_shadow(config)
out.dval
else
nothing
end
func.val(out.val, inp.val)

if EnzymeRules.width(config) == 1
out.dval .= 0
else
for i in 1:EnzymeRules.width(config)
out.dval[i] .= 0
end
end

return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end

function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, _, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple}
nr, nc = size(out.val,1), size(out.val,2)
for b in 1:EnzymeRules.width(config)
da = if EnzymeRules.width(config) == 1
out.dval
else
out.dval[b]
end
i = 1
j = 1
dinp = if (typeof(inp) <: Const)
nothing
else
ntuple(Val(length(inp.val))) do k
Base.@_inline_meta
res = da[i, j]
da[i, j] = 0
j += 1
if j == nc
i += 1
j = 1
end
if typeof(inp.val[k]) <: AbstractFloat
typeof(inp.val[k])(res)
else
typeof(inp.val[k])(0)
end
end
end
return (nothing, dinp)
end
end
15 changes: 15 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2543,3 +2543,18 @@ end
y = A \ b
@test dA (-z * transpose(y))
end

@testset "hvcat_fill" begin
ar = Matrix{Float64}(undef, 2, 3)
dar = [1.0 2.0 3.0; 4.0 5.0 6.0]

res = Enzyme.autodiff(Reverse, Base.hvcat_fill!, Const, Duplicated(ar, dar), Active((1, 2.2, 3, 4.4, 5, 6.6)))

@test res[2][1] == 0
@test res[2][2] 2.0
@test res[2][3] 0
@test res[2][4] 4.0
@test res[2][5] 0
@test res[2][6] 6.0
end

0 comments on commit e4774df

Please sign in to comment.