Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Nov 5, 2023
1 parent ef0157c commit 9d34980
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 11 deletions.
3 changes: 2 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4740,7 +4740,8 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,

idx = 0
dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(orig)))))
for (v, Ty) in zip(actives, Tys)
Tys2 = (eltype(A) for A in activity[2+isKWCall:end] if A <: Active)
for (v, Ty) in zip(actives, Tys2)
TT = typetree(Ty, ctx, dl)
Typ = C_NULL
ext = extract_value!(B, res, idx)
Expand Down
11 changes: 2 additions & 9 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache,
return (nothing,nothing)
end

@static if VERSION >= v"1.7-"
# Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float)
function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple}
primal = if EnzymeRules.needs_primal(config)
Expand All @@ -442,15 +443,6 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill
nothing
end
func.val(out.val, inp.val)

if EnzymeRules.width(config) == 1
out.dval .= 0
else
for i in 1:EnzymeRules.width(config)
out.dval[i] .= 0
end
end

return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end

Expand Down Expand Up @@ -486,3 +478,4 @@ function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Ty
end
return (nothing, nothing)
end
end
4 changes: 3 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2544,11 +2544,12 @@ end
@test dA (-z * transpose(y))
end

@static if VERSION >= v"1.7-"
@testset "hvcat_fill" begin
ar = Matrix{Float64}(undef, 2, 3)
dar = [1.0 2.0 3.0; 4.0 5.0 6.0]

res = Enzyme.autodiff(Reverse, Base.hvcat_fill!, Const, Duplicated(ar, dar), Active((1, 2.2, 3, 4.4, 5, 6.6)))
res = first(Enzyme.autodiff(Reverse, Base.hvcat_fill!, Const, Duplicated(ar, dar), Active((1, 2.2, 3, 4.4, 5, 6.6))))

@test res[2][1] == 0
@test res[2][2] 2.0
Expand All @@ -2557,4 +2558,5 @@ end
@test res[2][5] 0
@test res[2][6] 6.0
end
end

0 comments on commit 9d34980

Please sign in to comment.