diff --git a/src/compiler.jl b/src/compiler.jl index 936b30c37e..8ad1511ec5 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1106,7 +1106,8 @@ end function runtime_newtask_fwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, ::Val{width}) where {FT1, FT2, World, width} FT = Core.Typeof(fn) - forward = thunk(world, (guaranteed_const(FT) ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI) + ghos = guaranteed_const(FT) + forward = thunk(world, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI) ft = ghos ? Const(fn) : Duplicated(fn, dfn) function fclosure() res = forward(ft) @@ -1119,7 +1120,8 @@ end function runtime_newtask_augfwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, ::Val{width}, ::Val{ModifiedBetween}) where {FT1, FT2, World, width, ModifiedBetween} # TODO make this AD subcall type stable FT = Core.Typeof(fn) - forward, adjoint = thunk(world, (guaranteed_const(FT) ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI) + ghos = guaranteed_const(FT) + forward, adjoint = thunk(world, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI) ft = ghos ? Const(fn) : Duplicated(fn, dfn) taperef = Ref{Any}() @@ -3283,7 +3285,7 @@ end elseif mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient if dupClosure - ty = active_reg_nothrow(funcTy) + ty = active_reg_nothrow(funcT) has_active = ty == MixedState || ty == ActiveState if has_active refed = true