Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 24, 2023
1 parent 5d2af29 commit a3fc6ef
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1106,7 +1106,8 @@ end

function runtime_newtask_fwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, ::Val{width}) where {FT1, FT2, World, width}
FT = Core.Typeof(fn)
forward = thunk(world, (guaranteed_const(FT) ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI)
ghos = guaranteed_const(FT)
forward = thunk(world, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI)
ft = ghos ? Const(fn) : Duplicated(fn, dfn)
function fclosure()
res = forward(ft)
Expand All @@ -1119,7 +1120,8 @@ end
function runtime_newtask_augfwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, ::Val{width}, ::Val{ModifiedBetween}) where {FT1, FT2, World, width, ModifiedBetween}
# TODO make this AD subcall type stable
FT = Core.Typeof(fn)
forward, adjoint = thunk(world, (guaranteed_const(FT) ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI)
ghos = guaranteed_const(FT)
forward, adjoint = thunk(world, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI)
ft = ghos ? Const(fn) : Duplicated(fn, dfn)
taperef = Ref{Any}()

Expand Down Expand Up @@ -3283,7 +3285,7 @@ end

elseif mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient
if dupClosure
ty = active_reg_nothrow(funcTy)
ty = active_reg_nothrow(funcT)
has_active = ty == MixedState || ty == ActiveState
if has_active
refed = true
Expand Down

0 comments on commit a3fc6ef

Please sign in to comment.