Skip to content

Commit

Permalink
MixedDuplicated for custom rules (#1534)
Browse files Browse the repository at this point in the history
* MixedDuplicated for custom rules

* more mixed duplicated

* Handle mixed custom rule arg

* starting batching

* fix

* fix tests

* simplify mixed activity use
  • Loading branch information
wsmoses authored Jun 13, 2024
1 parent fb6f959 commit 15f9bb1
Show file tree
Hide file tree
Showing 9 changed files with 508 additions and 205 deletions.
26 changes: 26 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,32 @@ end
@inline batch_size(::Type{BatchDuplicatedNoNeed{T,N}}) where {T,N} = N


"""
MixedDuplicated(x, ∂f_∂x)
Like [`Duplicated`](@ref), except x may contain both active [immutable] and duplicated [mutable]
data which is differentiable. Only used within custom rules.
"""
struct MixedDuplicated{T} <: Annotation{T}
val::T
dval::Base.RefValue{T}
@inline MixedDuplicated(x::T1, dx::Base.RefValue{T1}, check::Bool=true) where {T1} = new{T1}(x, dx)
end

"""
BatchMixedDuplicated(x, ∂f_∂xs)
Like [`MixedDuplicated`](@ref), except contains several shadows to compute derivatives
for all at once. Only used within custom rules.
"""
struct BatchMixedDuplicated{T,N} <: Annotation{T}
val::T
dval::NTuple{N,Base.RefValue{T}}
@inline BatchMixedDuplicated(x::T1, dx::NTuple{N,Base.RefValue{T1}}, check::Bool=true) where {T1, N} = new{T1, N}(x, dx)
end
@inline batch_size(::BatchMixedDuplicated{T,N}) where {T,N} = N
@inline batch_size(::Type{BatchMixedDuplicated{T,N}}) where {T,N} = N

"""
abstract type ABI
Expand Down
3 changes: 3 additions & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated,
import EnzymeCore: BatchDuplicatedFunc
export BatchDuplicatedFunc

import EnzymeCore: MixedDuplicated, BatchMixedDuplicated
export MixedDuplicated, BatchMixedDuplicated

import EnzymeCore: batch_size, get_func
export batch_size, get_func

Expand Down
177 changes: 164 additions & 13 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2450,6 +2450,50 @@ else
end
end

function store_nonjl_types!(B, startval, p)
T_jlvalue = LLVM.StructType(LLVMType[])
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
vals = LLVM.Value[]
if p != nothing
push!(vals, p)
end
todo = Tuple{Tuple, LLVM.Value}[((), startval)]
while length(todo) != 0
path, cur = popfirst!(todo)
ty = value_type(cur)
if isa(ty, LLVM.PointerType)
if any_jltypes(ty)
continue
end
end
if isa(ty, LLVM.ArrayType)
if any_jltypes(ty)
for i=1:length(ty)
ev = extract_value!(B, cur, i-1)
push!(todo, ((path..., i-1), ev))
end
continue
end
end
if isa(ty, LLVM.StructType)
if any_jltypes(ty)
for (i, t) in enumerate(LLVM.elements(ty))
ev = extract_value!(B, cur, i-1)
push!(todo, ((path..., i-1), ev))
end
continue
end
end
parray = LLVM.Value[LLVM.ConstantInt(LLVM.IntType(64), 0)]
for v in path
push!(parray, LLVM.ConstantInt(LLVM.IntType(32), v))
end
gptr = gep!(B, value_type(startval), p, parray)
st = store!(B, cur, gptr)
end
return
end

function get_julia_inner_types(B, p, startvals...; added=[])
T_jlvalue = LLVM.StructType(LLVMType[])
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
Expand Down Expand Up @@ -3404,7 +3448,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr
else
push!(args_activity, API.DFT_OUT_DIFF)
end
elseif T <: Duplicated || T<: BatchDuplicated || T<: BatchDuplicatedFunc
elseif T <: Duplicated || T<: BatchDuplicated || T<: BatchDuplicatedFunc || T <: MixedDuplicated || T <: BatchMixedDuplicated
push!(args_activity, API.DFT_DUP_ARG)
elseif T <: DuplicatedNoNeed || T<: BatchDuplicatedNoNeed
push!(args_activity, API.DFT_DUP_NONEED)
Expand Down Expand Up @@ -3588,7 +3632,6 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,

isboxed = GPUCompiler.deserves_argbox(source_typ)
llvmT = isboxed ? T_prjlvalue : convert(LLVMType, source_typ)

push!(T_wrapperargs, llvmT)

if T <: Const || T <: BatchDuplicatedFunc
Expand Down Expand Up @@ -3617,6 +3660,11 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
if is_adjoint && i != 1
push!(ActiveRetTypes, Nothing)
end
elseif T <: MixedDuplicated || T <: BatchMixedDuplicated
push!(T_wrapperargs, LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue)))
if is_adjoint && i != 1
push!(ActiveRetTypes, Nothing)
end
else
error("calling convention should be annotated, got $T")
end
Expand Down Expand Up @@ -3799,7 +3847,23 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
if isghostty(T′) || Core.Compiler.isconstType(T′)
continue
end
push!(realparms, params[i])

isboxed = GPUCompiler.deserves_argbox(T′)

llty = value_type(params[i])

convty = convert(LLVMType, T′; allow_boxed=true)

if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType))
al = emit_allocobj!(builder, Base.RefValue{T′})
al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al))))
store!(builder, params[i], al)
al = addrspacecast!(builder, al, LLVM.PointerType(llty, Derived))
push!(realparms, al)
else
push!(realparms, params[i])
end

i += 1
if T <: Const
elseif T <: Active
Expand Down Expand Up @@ -3827,6 +3891,34 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
elseif T <: Duplicated || T <: DuplicatedNoNeed
push!(realparms, params[i])
i += 1
elseif T <: MixedDuplicated || T <: BatchMixedDuplicated
parmsi = params[i]

if T <: BatchMixedDuplicated
if GPUCompiler.deserves_argbox(NTuple{width, Base.RefValue{T′}})
njlvalue = LLVM.ArrayType(Int(width), T_prjlvalue)
parmsi = bitcast!(builder, parmsi, LLVM.PointerType(njlvalue, addrspace(value_type(parmsi))))
parmsi = load!(builder, njlvalue, parmsi)
end
end

isboxed = GPUCompiler.deserves_argbox(T′)

resty = isboxed ? llty : LLVM.PointerType(llty, Derived)

ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, resty)))
for idx in 1:width
pv = (width == 1) ? parmsi : extract_value!(builder, parmsi, idx-1)
pv = bitcast!(builder, pv, LLVM.PointerType(llty, addrspace(value_type(pv))))
pv = addrspacecast!(builder, pv, LLVM.PointerType(llty, Derived))
if isboxed
pv = load!(builder, llty, pv, "mixedboxload")
end
ival = (width == 1 ) ? pv : insert_value!(builder, ival, pv, idx-1)
end

push!(realparms, ival)
i += 1
elseif T <: BatchDuplicated || T <: BatchDuplicatedNoNeed
isboxed = GPUCompiler.deserves_argbox(NTuple{width, T′})
val = params[i]
Expand Down Expand Up @@ -4357,6 +4449,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function

# generate the wrapper function type & definition
wrapper_types = LLVM.LLVMType[]
wrapper_attrs = Vector{LLVM.Attribute}[]
_, sret, returnRoots = get_return_info(actualRetType)
sret_union = is_sret_union(actualRetType)

Expand Down Expand Up @@ -4391,31 +4484,44 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function

if swiftself
push!(wrapper_types, value_type(parameters(entry_f)[1+sret+returnRoots]))
push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("swiftself")])
end

boxedArgs = Set{Int}()
loweredArgs = Set{Int}()
raisedArgs = Set{Int}()

for arg in args
typ = arg.codegen.typ
if GPUCompiler.deserves_argbox(arg.typ)
push!(boxedArgs, arg.arg_i)
push!(wrapper_types, typ)
push!(wrapper_attrs, LLVM.Attribute[])
elseif arg.cc != GPUCompiler.BITS_REF
push!(wrapper_types, typ)
if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated)
push!(boxedArgs, arg.arg_i)
push!(raisedArgs, arg.arg_i)
push!(wrapper_types, LLVM.PointerType(typ, Derived))
push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")])
else
push!(wrapper_types, typ)
push!(wrapper_attrs, LLVM.Attribute[])
end
else
# bits ref, and not boxed
# if TT.parameters[arg.arg_i] <: Const
# push!(boxedArgs, arg.arg_i)
# push!(wrapper_types, typ)
# else
if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated)
push!(boxedArgs, arg.arg_i)
push!(wrapper_types, typ)
push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")])
else
push!(wrapper_types, eltype(typ))
push!(wrapper_attrs, LLVM.Attribute[])
push!(loweredArgs, arg.arg_i)
# end
end
end
end

if length(loweredArgs) == 0 && !sret && !sret_union
if length(loweredArgs) == 0 && length(raisedArgs) == 0 && !sret && !sret_union
return entry_f, returnRoots, boxedArgs, loweredArgs
end

Expand All @@ -4436,8 +4542,10 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
end
push!(function_attributes(wrapper_f), EnumAttribute("returns_twice"))
push!(function_attributes(entry_f), EnumAttribute("returns_twice"))
if swiftself
push!(parameter_attributes(wrapper_f, 1), EnumAttribute("swiftself"))
for (i, v) in enumerate(wrapper_attrs)
for attr in v
push!(parameter_attributes(wrapper_f, i), attr)
end
end

seen = TypeTreeTable()
Expand All @@ -4463,6 +4571,12 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
parm = ops[arg.codegen.i]
if arg.arg_i in loweredArgs
push!(nops, load!(builder, convert(LLVMType, arg.typ), parm))
elseif arg.arg_i in raisedArgs
obj = emit_allocobj!(builder, arg.typ)
bc = bitcast!(builder, obj, LLVM.PointerType(value_type(parm), addrspace(value_type(obj))))
store!(builder, parm, bc)
addr = addrspacecast!(builder, bc, LLVM.PointerType(value_type(parm), Derived))
push!(nops, addr)
else
push!(nops, parm)
end
Expand Down Expand Up @@ -4547,6 +4661,13 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzyme_type", string(typetree(arg.typ, ctx, dl, seen))))
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ)))))
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF))))
elseif arg.arg_i in raisedArgs
wrapparm = load!(builder, convert(LLVMType, arg.typ), wrapparm)
ctx = LLVM.context(wrapparm)
push!(wrapper_args, wrapparm)
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzyme_type", string(typetree(Base.RefValue{arg.typ}, ctx, dl, seen))))
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ)))))
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF))))
else
push!(wrapper_args, wrapparm)
for attr in collect(parameter_attributes(entry_f, arg.codegen.i))
Expand Down Expand Up @@ -4626,6 +4747,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
elseif LLVM.return_type(entry_ft) == LLVM.VoidType()
ret!(builder)
else
ctx = LLVM.context(wrapper_f)
push!(return_attributes(wrapper_f), StringAttribute("enzyme_type", string(typetree(actualRetType, ctx, dl, seen))))
push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(actualRetType)))))
push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF))))
Expand Down Expand Up @@ -4687,7 +4809,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
if LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMReturnStatusAction) != 0
msg = sprint() do io
println(io, string(mod))
println(io, LVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction))
println(io, LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction))
println(io, string(wrapper_f))
println(io, "parmsRemoved=", parmsRemoved, " retRemoved=", retRemoved, " prargs=", prargs)
println(io, "Broken function")
Expand Down Expand Up @@ -5966,6 +6088,35 @@ end
push!(ActiveRetTypes, Nothing)
end
push!(ccexprs, argexpr)
elseif T <: MixedDuplicated
if RawCall
argexpr = argexprs[i]
i+=1
else
argexpr = Expr(:., expr, QuoteNode(:dval))
end
push!(types, Any)
if is_adjoint
push!(ActiveRetTypes, Nothing)
end
push!(ccexprs, argexpr)
elseif T <: BatchMixedDuplicated
if RawCall
argexpr = argexprs[i]
i+=1
else
argexpr = Expr(:., expr, QuoteNode(:dval))
end
isboxedvec = GPUCompiler.deserves_argbox(NTuple{width, Base.RefValue{source_typ}})
if isboxedvec
push!(types, Any)
else
push!(types, NTuple{width, Base.RefValue{source_typ}})
end
if is_adjoint
push!(ActiveRetTypes, Nothing)
end
push!(ccexprs, argexpr)
else
error("calling convention should be annotated, got $T")
end
Expand Down
Loading

2 comments on commit 15f9bb1

@wsmoses
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register subdir="lib/EnzymeCore"

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: Version 0.7.4 already exists

Please sign in to comment.