Skip to content

Commit

Permalink
Merge ab7cee5 into 835b6d5
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jun 14, 2024
2 parents 835b6d5 + ab7cee5 commit 6a663bb
Show file tree
Hide file tree
Showing 8 changed files with 651 additions and 174 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Enzyme"
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.12.13"
version = "0.12.14"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand All @@ -20,7 +20,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.7.5"
Enzyme_jll = "0.0.121"
Enzyme_jll = "0.0.122"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26"
LLVM = "6.1, 7"
ObjectFile = "0.4"
Expand Down
14 changes: 14 additions & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ end
arg = @inbounds args[i]
if arg isa Active
return true
elseif arg isa MixedDuplicated
return true
elseif arg isa BatchMixedDuplicated
return true
else
return false
end
Expand Down Expand Up @@ -95,6 +99,10 @@ end
end

# Recursion terminator: with no annotations left to inspect, the accumulated
# value is returned unchanged.
@inline function same_or_one_rec(current)
    return current
end

# A `BatchMixedDuplicated{T, N}` instance: fold its second type parameter `N`
# into the running value via `same_or_one_helper`, then recurse on the rest.
@inline function same_or_one_rec(
    current, arg::BatchMixedDuplicated{T, N}, args...
) where {T,N}
    combined = same_or_one_helper(current, N)
    return same_or_one_rec(combined, args...)
end

# Same as above, but for the annotation *type* rather than an instance of it.
@inline function same_or_one_rec(
    current, arg::Type{BatchMixedDuplicated{T, N}}, args...
) where {T,N}
    combined = same_or_one_helper(current, N)
    return same_or_one_rec(combined, args...)
end

# A `BatchDuplicatedFunc{T, N}` instance is handled identically: `N` is folded
# into the running value before continuing with the remaining arguments.
@inline function same_or_one_rec(
    current, arg::BatchDuplicatedFunc{T, N}, args...
) where {T,N}
    combined = same_or_one_helper(current, N)
    return same_or_one_rec(combined, args...)
end
@inline same_or_one_rec(current, arg::Type{BatchDuplicatedFunc{T, N}}, args...) where {T,N} =
Expand Down Expand Up @@ -844,6 +852,12 @@ result, ∂v, ∂A
else
BatchDuplicatedNoNeed{eltype(A2), width}
end
elseif A2 <: MixedDuplicated && width != 1
if A2 isa UnionAll
BatchMixedDuplicated{T, width} where T
else
BatchMixedDuplicated{eltype(A2), width}
end
else
A2
end
Expand Down
137 changes: 102 additions & 35 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,13 @@ end
return res
end

# Check whether a value of type `T` is guaranteed not to contain
# active (register) data, i.e. it is neither mixed nor active.
#
# The activity state of `T` is queried through `active_reg_nothrow`; the result
# is `true` only when that state is `AnyState` or `DupState`.
@inline function guaranteed_nonactive(::Type{T}) where T
    state = Enzyme.Compiler.active_reg_nothrow(T, Val(nothing))
    if state == Enzyme.Compiler.AnyState
        return true
    end
    return state == Enzyme.Compiler.DupState
end

# Convenience overload taking an `Enzyme.Mode`: the mode is converted to its
# `API.CDerivativeMode` representation and the call is forwarded to the
# `guess_activity` method that accepts that representation.
@inline function Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T}
    cmode = convert(API.CDerivativeMode, mode)
    return guess_activity(T, cmode)
end

@inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T}
Expand All @@ -555,6 +562,8 @@ end
else
if ActReg == ActiveState
return Active{T}
elseif ActReg == MixedState
return MixedDuplicated{T}
else
return Duplicated{T}
end
Expand Down Expand Up @@ -2494,7 +2503,7 @@ function store_nonjl_types!(B, startval, p)
return
end

function get_julia_inner_types(B, p, startvals...; added=[])
function get_julia_inner_types(B, p, startvals...; added=LLVM.API.LLVMValueRef[])
T_jlvalue = LLVM.StructType(LLVMType[])
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
vals = LLVM.Value[]
Expand Down Expand Up @@ -2547,8 +2556,20 @@ function get_julia_inner_types(B, p, startvals...; added=[])
end
continue
end
GPUCompiler.@safe_warn "Enzyme illegal subtype", ty, cur, SI, p, v
@assert false
if isa(ty, LLVM.IntegerType)
continue
end
if isa(ty, LLVM.FloatingPointType)
continue
end
msg = sprint() do io
println(io, "Enzyme illegal subtype")
println(io, "ty=", ty)
println(io, "cur=", cur)
println(io, "p=", p)
println(io, "startvals=", startvals)
end
throw(AssertionError(msg))
end
return vals
end
Expand Down Expand Up @@ -3474,7 +3495,11 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr
# If requested, the shadow return value of the function
# For each active (non duplicated) argument
# The adjoint of that argument
retType = convert(API.CDIFFE_TYPE, rt)
retType = if rt <: MixedDuplicated || rt <: BatchMixedDuplicated
API.DFT_OUT_DIFF
else
convert(API.CDIFFE_TYPE, rt)
end

rules = Dict{String, API.CustomRuleType}(
"jl_array_copy" => @cfunction(inout_rule,
Expand Down Expand Up @@ -3513,7 +3538,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr

if mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient
returnUsed = !(isghostty(actualRetType) || Core.Compiler.isconstType(actualRetType))
shadowReturnUsed = returnUsed && (retType == API.DFT_DUP_ARG || retType == API.DFT_DUP_NONEED)
shadowReturnUsed = returnUsed && (retType == API.DFT_DUP_ARG || retType == API.DFT_DUP_NONEED || rt <: MixedDuplicated || rt <: BatchMixedDuplicated)
returnUsed &= returnPrimal
augmented = API.EnzymeCreateAugmentedPrimal(
logic, primalf, retType, args_activity, TA, #=returnUsed=# returnUsed,
Expand Down Expand Up @@ -3679,16 +3704,20 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
end

# API.DFT_OUT_DIFF
if is_adjoint && rettype <: Active
@assert !sret_union
if allocatedinline(actualRetType) != allocatedinline(literal_rt)
throw(AssertionError("Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype)"))
end
if !allocatedinline(actualRetType)
throw(AssertionError("Base.allocatedinline(actualRetType) returns false: actualRetType = $(actualRetType), rettype = $(rettype)"))
if is_adjoint
if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated
@assert !sret_union
if allocatedinline(actualRetType) != allocatedinline(literal_rt)
throw(AssertionError("Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype)"))
end
if rettype <: Active
if !allocatedinline(actualRetType)
throw(AssertionError("Base.allocatedinline(actualRetType) returns false: actualRetType = $(actualRetType), rettype = $(rettype)"))
end
end
dretTy = LLVM.LLVMType(API.EnzymeGetShadowType(width, convert(LLVMType, actualRetType; allow_boxed=!(rettype <: Active))))
push!(T_wrapperargs, dretTy)
end
dretTy = LLVM.LLVMType(API.EnzymeGetShadowType(width, convert(LLVMType, actualRetType)))
push!(T_wrapperargs, dretTy)
end

data = Array{Int64}(undef, 3)
Expand Down Expand Up @@ -3730,6 +3759,12 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
else
push!(sret_types, AnonymousStruct(NTuple{width, literal_rt}))
end
elseif rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated
if width == 1
push!(sret_types, Base.RefValue{literal_rt})
else
push!(sret_types, AnonymousStruct(NTuple{width, Base.RefValue{literal_rt}}))
end
end
else
@assert rettype <: Const || rettype <: Active
Expand Down Expand Up @@ -3953,7 +3988,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
end
end

if is_adjoint && rettype <: Active
if is_adjoint && (rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated)
push!(realparms, params[i])
i += 1
end
Expand Down Expand Up @@ -3999,12 +4034,26 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
if data[i] != -1
eval = extract_value!(builder, val, data[i])
end
if i == 3
if rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated
ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue)))
for idx in 1:width
pv = (width == 1) ? eval : extract_value!(builder, eval, idx-1)
al0 = al = emit_allocobj!(builder, Base.RefValue{eltype(rettype)})
llty = value_type(pv)
al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al))))
store!(builder, pv, al)
emit_writebarrier!(builder, get_julia_inner_types(builder, al0, pv))
ival = (width == 1 ) ? al0 : insert_value!(builder, ival, al0, idx-1)
end
eval = ival
end
end
eval = fixup_abi(i, eval)
ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)])
ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)))
si = store!(builder, eval, ptr)
returnNum+=1

if i == 3 && shadow_init
shadows = LLVM.Value[]
if width == 1
Expand Down Expand Up @@ -5943,34 +5992,35 @@ end
end

if !RawCall && !(CC <: PrimalErrorThunk)
if rettype <: Active
if rettype <: Active
if length(argtypes) + is_adjoint + needs_tape != length(argexprs)
return quote
throw(MethodError($CC(fptr), $args))
throw(MethodError($CC(fptr), (fn, args...)))
end
end
elseif rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated
if length(argtypes) + is_adjoint * width + needs_tape != length(argexprs)
return quote
throw(MethodError($CC(fptr), (fn, args...)))
end
end
elseif rettype <: Const
if length(argtypes) + needs_tape != length(argexprs)
return quote
throw(MethodError($CC(fptr), $args))
throw(MethodError($CC(fptr), (fn, args...)))
end
end
else
if length(argtypes) + needs_tape != length(argexprs)
return quote
throw(MethodError($CC(fptr), $args))
throw(MethodError($CC(fptr), (fn, args...)))
end
end
end
end

types = DataType[]

if eltype(rettype) === Union{} && false
return quote
error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up")
end
end
if !(rettype <: Const) && (isghostty(eltype(rettype)) || Core.Compiler.isconstType(eltype(rettype)) || eltype(rettype) === DataType)
rrt = eltype(rettype)
error("Return type `$rrt` not marked Const, but is ghost or const type.")
Expand Down Expand Up @@ -6133,17 +6183,28 @@ end
end

# API.DFT_OUT_DIFF
if is_adjoint && rettype <: Active
# TODO handle batch width
@assert allocatedinline(jlRT)
j_drT = if width == 1
jlRT
else
NTuple{width, jlRT}
if is_adjoint
if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated
# TODO handle batch width
if rettype <: Active
@assert allocatedinline(jlRT)
end
j_drT = if width == 1
jlRT
else
NTuple{width, jlRT}
end
push!(types, j_drT)
if width == 1 || rettype <: Active
push!(ccexprs, argexprs[i])
i+=1
else
push!(ccexprs, quote
($(argexprs[i:i+width-1]...),)
end)
i+=width
end
end
push!(types, j_drT)
push!(ccexprs, argexprs[i])
i+=1
end

if needs_tape
Expand Down Expand Up @@ -6181,8 +6242,12 @@ end
end
if rettype <: Duplicated || rettype <: DuplicatedNoNeed
push!(sret_types, jlRT)
elseif rettype <: MixedDuplicated
push!(sret_types, Base.RefValue{jlRT})
elseif rettype <: BatchDuplicated || rettype <: BatchDuplicatedNoNeed
push!(sret_types, AnonymousStruct(NTuple{width, jlRT}))
elseif rettype <: BatchMixedDuplicated
push!(sret_types, AnonymousStruct(NTuple{width, Base.RefValue{jlRT}}))
elseif CC <: AugmentedForwardThunk
push!(sret_types, Nothing)
elseif rettype <: Const
Expand Down Expand Up @@ -6406,6 +6471,8 @@ end
# `remove_innerty` maps an annotation type to its wrapper without the inner
# type parameter. Note that the batched variants below map to the *non-batched*
# wrapper of the same family.
@inline function remove_innerty(::Type{<:DuplicatedNoNeed})
    return DuplicatedNoNeed
end

# Batched duplicated annotations collapse to plain `Duplicated`.
@inline function remove_innerty(::Type{<:BatchDuplicated})
    return Duplicated
end

# Batched no-need annotations collapse to plain `DuplicatedNoNeed`.
@inline function remove_innerty(::Type{<:BatchDuplicatedNoNeed})
    return DuplicatedNoNeed
end

# Mixed annotations keep their own wrapper.
@inline function remove_innerty(::Type{<:MixedDuplicated})
    return MixedDuplicated
end

# Batched mixed annotations collapse to plain `MixedDuplicated`.
@inline function remove_innerty(::Type{<:BatchMixedDuplicated})
    return MixedDuplicated
end

@inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI}
JuliaContext() do ctx
Expand Down
Loading

0 comments on commit 6a663bb

Please sign in to comment.