diff --git a/src/compiler.jl b/src/compiler.jl index f79e6b1d17..657f026703 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -438,15 +438,25 @@ end throw(AssertionError("Type $T is not concrete type or concrete tuple")) end - if Val(T) ∈ seen + @static if VERSION < v"1.7.0" + nT = T + else + nT = if is_concrete_tuple(T) && any(T2 isa Core.TypeofVararg for T2 in T.parameters) + Tuple{((T2 isa Core.TypeofVararg ? Any : T2) for T2 in T.parameters)...,} + else + T + end + end + + if Val(nT) ∈ seen return MixedState end - seen = (Val(T), seen...) + seen = (Val(nT), seen...) fty = Merger{seen,typeof(world),justActive, UnionSret}(world) - ty = forcefold(Val(AnyState), ntuple(fty, Val(fieldcount(T)))...) + ty = forcefold(Val(AnyState), ntuple(fty, Val(fieldcount(nT)))...) return ty end diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 50f72509e4..a3689470a8 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -742,16 +742,17 @@ function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, width = get_width(gutils) - if v && v2 && isiter == Base.iterate && istup == Base.tuple && length(operands(orig)) == offset+4 + if v && v2 && isiter == Base.iterate && istup == Base.tuple && length(operands(orig)) >= offset+4 origops = collect(operands(orig)[1:end-1]) - shadowin = invert_pointer(gutils, origops[offset + 3], B) + shadowins = [ invert_pointer(gutils, origops[i], B) for i in (offset+3):length(origops) ] shadowres = if width == 1 - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active), shadowin]) newops = LLVM.Value[] newvals = API.CValueType[] for (i, v) in enumerate(origops) - if i == offset + 3 - push!(newops, shadowin) + if i >= offset + 3 + shadowin2 = shadowins[i-offset-3+1] + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active), shadowin2]) + push!(newops, shadowin2) push!(newvals, API.VT_Shadow) else push!(newops, new_from_original(gutils, origops[i])) @@ -765,13 +766,12 @@ function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) for j in 1:width - shadowin2 = extract_value!(B, shadowin, j-1) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active), shadowin2]) - newops = LLVM.Value[] newvals = API.CValueType[] for (i, v) in enumerate(origops) - if i == offset + 3 + if i >= offset + 3 + shadowin2 = extract_value!(B, shadowins[i-offset-3+1], j-1) + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active), shadowin2]) push!(newops, shadowin2) push!(newvals, API.VT_Shadow) else @@ -787,12 +787,12 @@ function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, end unsafe_store!(shadowR, shadowres.ref) - return false end - emit_error(B, orig, "Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate") + emit_error(B, orig, "Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate "*string((v, v2, isiter, istup, length(operands(orig)), offset+4))) + unsafe_store!(shadowR,UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))).ref) return false end diff --git a/test/runtests.jl b/test/runtests.jl index f2e7864f9f..7d3d2f361e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1835,6 +1835,17 @@ end ddata = [[0.0], nothing, 0.0] @test_throws AssertionError Enzyme.autodiff(Reverse, mktup2, Duplicated(data, ddata)) + + function mktup3(v) + tup = tuple(v..., v...) + return tup[1][1] * tup[1][1] + end + + data = [[3.0]] + ddata = [[0.0]] + + Enzyme.autodiff(Reverse, mktup3, Duplicated(data, ddata)) + @test ddata[1][1] ≈ 6.0 end @testset "BLAS" begin