Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MixedDuplicated for custom rules #1534

Merged
merged 7 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,32 @@ end
@inline batch_size(::Type{BatchDuplicatedNoNeed{T,N}}) where {T,N} = N


"""
    MixedDuplicated(x, ∂f_∂x)

Like [`Duplicated`](@ref), except `x` may contain both active [immutable] and duplicated
[mutable] data which is differentiable. The shadow `∂f_∂x` is passed as a `Base.RefValue`
wrapping a value of the same type as `x`. Only used within custom rules.
"""
struct MixedDuplicated{T} <: Annotation{T}
    # The value `x` handed to the constructor.
    val::T
    # Derivative storage `∂f_∂x`: a `Base.RefValue` wrapping the same type as `val`.
    dval::Base.RefValue{T}

    # The only constructor: `x` and the contents of `dx` must share one type.
    # `check` is accepted but is not consulted by this constructor.
    @inline function MixedDuplicated(
        x::V, dx::Base.RefValue{V}, check::Bool=true
    ) where {V}
        return new{V}(x, dx)
    end
end

"""
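    BatchMixedDuplicated(x, ∂f_∂xs)

Like [`MixedDuplicated`](@ref), except it contains several shadows so that derivatives
can be computed for all of them at once. The shadows `∂f_∂xs` are passed as a tuple of
`Base.RefValue`s, each wrapping a value of the same type as `x`. Only used within custom
rules.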
"""
struct BatchMixedDuplicated{T,N} <: Annotation{T}
    # The value `x` handed to the constructor.
    val::T
    # The `N` shadows `∂f_∂xs`, each a `Base.RefValue` wrapping the same type as `val`.
    dval::NTuple{N,Base.RefValue{T}}

    # The only constructor: the batch width is taken from the length of the shadow tuple.
    # `check` is accepted but is not consulted by this constructor.
    @inline function BatchMixedDuplicated(
        x::V, dx::NTuple{W,Base.RefValue{V}}, check::Bool=true
    ) where {V,W}
        return new{V,W}(x, dx)
    end
end
# Batch width of a `BatchMixedDuplicated` instance: its second type parameter.
@inline function batch_size(::BatchMixedDuplicated{T,N}) where {T,N}
    return N
end

# Batch width read directly from a fully-parameterised `BatchMixedDuplicated` type.
@inline function batch_size(::Type{BatchMixedDuplicated{T,N}}) where {T,N}
    return N
end

"""
abstract type ABI

Expand Down
3 changes: 3 additions & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated,
import EnzymeCore: BatchDuplicatedFunc
export BatchDuplicatedFunc

import EnzymeCore: MixedDuplicated, BatchMixedDuplicated
export MixedDuplicated, BatchMixedDuplicated

import EnzymeCore: batch_size, get_func
export batch_size, get_func

Expand Down
177 changes: 164 additions & 13 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2450,6 +2450,50 @@ else
end
end

# store_nonjl_types!(B, startval, p)
#
# Walk the LLVM value `startval` and, using builder `B`, store each leaf that is reached
# into the matching location behind `p`.
#
# Traversal is breadth-first over a FIFO work list of `(path, value)` pairs, where `path`
# is the tuple of zero-based element indices leading from `startval` to `value`:
#   * a pointer-typed value for which `any_jltypes` is true is skipped (nothing stored);
#   * an array- or struct-typed value for which `any_jltypes` is true is split with
#     `extract_value!` and each element is queued with its index appended to the path;
#   * every other value is stored through a GEP into `p`. The GEP uses
#     `value_type(startval)` as its base type and the indices `[0, path...]`, so `p` is
#     indexed as a pointer to a value of that type.
#
# Returns `nothing`.
function store_nonjl_types!(B, startval, p)
    # Both are loop-invariant: the GEP base type and the leading "dereference p" index.
    basety = value_type(startval)
    zeroidx = LLVM.ConstantInt(LLVM.IntType(64), 0)

    todo = Tuple{Tuple, LLVM.Value}[((), startval)]
    while !isempty(todo)
        path, cur = popfirst!(todo)
        ty = value_type(cur)

        # Pointers that carry Julia types are not stored by this routine.
        if isa(ty, LLVM.PointerType) && any_jltypes(ty)
            continue
        end

        # Aggregates that contain Julia types somewhere are taken apart so that only
        # the leaves reached below get stored.
        if isa(ty, LLVM.ArrayType) && any_jltypes(ty)
            for i in 0:(length(ty) - 1)
                push!(todo, ((path..., i), extract_value!(B, cur, i)))
            end
            continue
        end
        if isa(ty, LLVM.StructType) && any_jltypes(ty)
            for i in 0:(length(LLVM.elements(ty)) - 1)
                push!(todo, ((path..., i), extract_value!(B, cur, i)))
            end
            continue
        end

        # Leaf: address the field at `path` inside `p` and store the value there.
        indices = LLVM.Value[zeroidx]
        for v in path
            push!(indices, LLVM.ConstantInt(LLVM.IntType(32), v))
        end
        gptr = gep!(B, basety, p, indices)
        store!(B, cur, gptr)
    end
    return nothing
end

function get_julia_inner_types(B, p, startvals...; added=[])
T_jlvalue = LLVM.StructType(LLVMType[])
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
Expand Down Expand Up @@ -3385,7 +3429,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr
else
push!(args_activity, API.DFT_OUT_DIFF)
end
elseif T <: Duplicated || T<: BatchDuplicated || T<: BatchDuplicatedFunc
elseif T <: Duplicated || T<: BatchDuplicated || T<: BatchDuplicatedFunc || T <: MixedDuplicated || T <: BatchMixedDuplicated
push!(args_activity, API.DFT_DUP_ARG)
elseif T <: DuplicatedNoNeed || T<: BatchDuplicatedNoNeed
push!(args_activity, API.DFT_DUP_NONEED)
Expand Down Expand Up @@ -3569,7 +3613,6 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,

isboxed = GPUCompiler.deserves_argbox(source_typ)
llvmT = isboxed ? T_prjlvalue : convert(LLVMType, source_typ)

push!(T_wrapperargs, llvmT)

if T <: Const || T <: BatchDuplicatedFunc
Expand Down Expand Up @@ -3598,6 +3641,11 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
if is_adjoint && i != 1
push!(ActiveRetTypes, Nothing)
end
elseif T <: MixedDuplicated || T <: BatchMixedDuplicated
push!(T_wrapperargs, LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue)))
if is_adjoint && i != 1
push!(ActiveRetTypes, Nothing)
end
else
error("calling convention should be annotated, got $T")
end
Expand Down Expand Up @@ -3780,7 +3828,23 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
if isghostty(T′) || Core.Compiler.isconstType(T′)
continue
end
push!(realparms, params[i])

isboxed = GPUCompiler.deserves_argbox(T′)

llty = value_type(params[i])

convty = convert(LLVMType, T′; allow_boxed=true)

if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType))
al = emit_allocobj!(builder, Base.RefValue{T′})
al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al))))
store!(builder, params[i], al)
al = addrspacecast!(builder, al, LLVM.PointerType(llty, Derived))
push!(realparms, al)
else
push!(realparms, params[i])
end

i += 1
if T <: Const
elseif T <: Active
Expand Down Expand Up @@ -3808,6 +3872,34 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
elseif T <: Duplicated || T <: DuplicatedNoNeed
push!(realparms, params[i])
i += 1
elseif T <: MixedDuplicated || T <: BatchMixedDuplicated
parmsi = params[i]

if T <: BatchMixedDuplicated
if GPUCompiler.deserves_argbox(NTuple{width, Base.RefValue{T′}})
njlvalue = LLVM.ArrayType(Int(width), T_prjlvalue)
parmsi = bitcast!(builder, parmsi, LLVM.PointerType(njlvalue, addrspace(value_type(parmsi))))
parmsi = load!(builder, njlvalue, parmsi)
end
end

isboxed = GPUCompiler.deserves_argbox(T′)

resty = isboxed ? llty : LLVM.PointerType(llty, Derived)

ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, resty)))
for idx in 1:width
pv = (width == 1) ? parmsi : extract_value!(builder, parmsi, idx-1)
pv = bitcast!(builder, pv, LLVM.PointerType(llty, addrspace(value_type(pv))))
pv = addrspacecast!(builder, pv, LLVM.PointerType(llty, Derived))
if isboxed
pv = load!(builder, llty, pv, "mixedboxload")
end
ival = (width == 1 ) ? pv : insert_value!(builder, ival, pv, idx-1)
end

push!(realparms, ival)
i += 1
elseif T <: BatchDuplicated || T <: BatchDuplicatedNoNeed
isboxed = GPUCompiler.deserves_argbox(NTuple{width, T′})
val = params[i]
Expand Down Expand Up @@ -4338,6 +4430,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function

# generate the wrapper function type & definition
wrapper_types = LLVM.LLVMType[]
wrapper_attrs = Vector{LLVM.Attribute}[]
_, sret, returnRoots = get_return_info(actualRetType)
sret_union = is_sret_union(actualRetType)

Expand Down Expand Up @@ -4372,31 +4465,44 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function

if swiftself
push!(wrapper_types, value_type(parameters(entry_f)[1+sret+returnRoots]))
push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("swiftself")])
end

boxedArgs = Set{Int}()
loweredArgs = Set{Int}()
raisedArgs = Set{Int}()

for arg in args
typ = arg.codegen.typ
if GPUCompiler.deserves_argbox(arg.typ)
push!(boxedArgs, arg.arg_i)
push!(wrapper_types, typ)
push!(wrapper_attrs, LLVM.Attribute[])
elseif arg.cc != GPUCompiler.BITS_REF
push!(wrapper_types, typ)
if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated)
push!(boxedArgs, arg.arg_i)
push!(raisedArgs, arg.arg_i)
push!(wrapper_types, LLVM.PointerType(typ, Derived))
push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")])
else
push!(wrapper_types, typ)
push!(wrapper_attrs, LLVM.Attribute[])
end
else
# bits ref, and not boxed
# if TT.parameters[arg.arg_i] <: Const
# push!(boxedArgs, arg.arg_i)
# push!(wrapper_types, typ)
# else
if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated)
push!(boxedArgs, arg.arg_i)
push!(wrapper_types, typ)
push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")])
else
push!(wrapper_types, eltype(typ))
push!(wrapper_attrs, LLVM.Attribute[])
push!(loweredArgs, arg.arg_i)
# end
end
end
end

if length(loweredArgs) == 0 && !sret && !sret_union
if length(loweredArgs) == 0 && length(raisedArgs) == 0 && !sret && !sret_union
return entry_f, returnRoots, boxedArgs, loweredArgs
end

Expand All @@ -4417,8 +4523,10 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
end
push!(function_attributes(wrapper_f), EnumAttribute("returns_twice"))
push!(function_attributes(entry_f), EnumAttribute("returns_twice"))
if swiftself
push!(parameter_attributes(wrapper_f, 1), EnumAttribute("swiftself"))
for (i, v) in enumerate(wrapper_attrs)
for attr in v
push!(parameter_attributes(wrapper_f, i), attr)
end
end

seen = TypeTreeTable()
Expand All @@ -4444,6 +4552,12 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
parm = ops[arg.codegen.i]
if arg.arg_i in loweredArgs
push!(nops, load!(builder, convert(LLVMType, arg.typ), parm))
elseif arg.arg_i in raisedArgs
obj = emit_allocobj!(builder, arg.typ)
bc = bitcast!(builder, obj, LLVM.PointerType(value_type(parm), addrspace(value_type(obj))))
store!(builder, parm, bc)
addr = addrspacecast!(builder, bc, LLVM.PointerType(value_type(parm), Derived))
push!(nops, addr)
else
push!(nops, parm)
end
Expand Down Expand Up @@ -4528,6 +4642,13 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzyme_type", string(typetree(arg.typ, ctx, dl, seen))))
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ)))))
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF))))
elseif arg.arg_i in raisedArgs
wrapparm = load!(builder, convert(LLVMType, arg.typ), wrapparm)
ctx = LLVM.context(wrapparm)
push!(wrapper_args, wrapparm)
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzyme_type", string(typetree(Base.RefValue{arg.typ}, ctx, dl, seen))))
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ)))))
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF))))
else
push!(wrapper_args, wrapparm)
for attr in collect(parameter_attributes(entry_f, arg.codegen.i))
Expand Down Expand Up @@ -4607,6 +4728,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
elseif LLVM.return_type(entry_ft) == LLVM.VoidType()
ret!(builder)
else
ctx = LLVM.context(wrapper_f)
push!(return_attributes(wrapper_f), StringAttribute("enzyme_type", string(typetree(actualRetType, ctx, dl, seen))))
push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(actualRetType)))))
push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF))))
Expand Down Expand Up @@ -4668,7 +4790,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
if LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMReturnStatusAction) != 0
msg = sprint() do io
println(io, string(mod))
println(io, LVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction))
println(io, LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction))
println(io, string(wrapper_f))
println(io, "parmsRemoved=", parmsRemoved, " retRemoved=", retRemoved, " prargs=", prargs)
println(io, "Broken function")
Expand Down Expand Up @@ -5947,6 +6069,35 @@ end
push!(ActiveRetTypes, Nothing)
end
push!(ccexprs, argexpr)
elseif T <: MixedDuplicated
if RawCall
argexpr = argexprs[i]
i+=1
else
argexpr = Expr(:., expr, QuoteNode(:dval))
end
push!(types, Any)
if is_adjoint
push!(ActiveRetTypes, Nothing)
end
push!(ccexprs, argexpr)
elseif T <: BatchMixedDuplicated
if RawCall
argexpr = argexprs[i]
i+=1
else
argexpr = Expr(:., expr, QuoteNode(:dval))
end
isboxedvec = GPUCompiler.deserves_argbox(NTuple{width, Base.RefValue{source_typ}})
if isboxedvec
push!(types, Any)
else
push!(types, NTuple{width, Base.RefValue{source_typ}})
end
if is_adjoint
push!(ActiveRetTypes, Nothing)
end
push!(ccexprs, argexpr)
else
error("calling convention should be annotated, got $T")
end
Expand Down
Loading
Loading