Skip to content

Commit

Permalink
Merge branch 'main' into mhauru/distributions-integration-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Dec 7, 2024
2 parents c14e27f + 8e10a0a commit 81df5ec
Show file tree
Hide file tree
Showing 11 changed files with 450 additions and 259 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Enzyme"
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.13.19"
version = "0.13.21"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand Down Expand Up @@ -37,7 +37,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.8.8"
Enzyme_jll = "0.0.167"
Enzyme_jll = "0.0.168"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1"
LLVM = "6.1, 7, 8, 9"
LogExpFunctions = "0.3"
Expand Down
3 changes: 3 additions & 0 deletions ext/EnzymeSpecialFunctionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ using Enzyme

function __init__()
Enzyme.Compiler.known_ops[typeof(SpecialFunctions._logabsgamma)] = (:logabsgamma, 1, (:digamma, typeof(SpecialFunctions.digamma)))
Enzyme.Compiler.cmplx_known_ops[typeof(SpecialFunctions.bessely)] = (:cmplx_jn, 2, nothing)
Enzyme.Compiler.cmplx_known_ops[typeof(SpecialFunctions.besselj)] = (:cmplx_jn, 2, nothing)
Enzyme.Compiler.cmplx_known_ops[typeof(SpecialFunctions.besselk)] = (:cmplx_kn, 2, nothing)
end

end
6 changes: 5 additions & 1 deletion src/absint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,11 @@ end
end

function should_recurse(@nospecialize(typ2), @nospecialize(arg_t::LLVM.LLVMType), byref::GPUCompiler.ArgumentCC, dl::LLVM.DataLayout)::Bool
sz = sizeof(dl, arg_t)
sz = if arg_t == LLVM.IntType(1)
1
else
sizeof(dl, arg_t)
end
if byref != GPUCompiler.BITS_VALUE
if sz != sizeof(Int)
throw(AssertionError("non bits type $byref of $typ2 has size $sz != sizeof(Int) from arg type $arg_t"))
Expand Down
51 changes: 45 additions & 6 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ const known_ops = Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,Data
name, arity, toinject = cmplx_known_ops[func]
Tys = (Complex{Float32}, Complex{Float64})
if length(sparam_vals) == arity
if name == :cmplx_jn || name == :cmplx_yn
if (sparam_vals[2] Tys) && sparam_vals[2].parameters[1] == sparam_vals[1]
return name, toinject, sparam_vals[2]
end
end
T = first(sparam_vals)
if (T isa Type)
T = T::Type
Expand Down Expand Up @@ -3956,6 +3961,9 @@ end
lowerConvention = false
end
k_name = LLVM.name(llvmfn)
if !has_fn_attr(llvmfn, EnumAttribute("nofree"))
push!(LLVM.function_attributes(llvmfn), EnumAttribute("nofree"))
end
end

name = string(name)
Expand Down Expand Up @@ -5218,12 +5226,12 @@ end
# JIT
##

function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType))
function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType), prepost::String)
if job.config.params.ABI <: InlineABI
return CompileResult(
Val((Symbol(mod), Symbol(adjoint_name))),
Val((Symbol(mod), Symbol(primal_name))),
TapeType,
TapeType
)
end

Expand Down Expand Up @@ -5261,7 +5269,7 @@ end
const DumpPostOpt = Ref(false)

# actual compilation
function _thunk(job, postopt::Bool = true)
function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, String, Union{String, Nothing}, Type, String}
mod, meta = codegen(:llvm, job; optimize = false)
adjointf, augmented_primalf = meta.adjointf, meta.augmented_primalf

Expand All @@ -5279,7 +5287,12 @@ function _thunk(job, postopt::Bool = true)
end

# Run post optimization pipeline
if postopt
prepost = if postopt
mstr = if job.config.params.ABI <: InlineABI
""
else
string(mod)
end
if job.config.params.ABI <: FFIABI || job.config.params.ABI <: NonGenABI
post_optimze!(mod, JIT.get_tm())
if DumpPostOpt[]
Expand All @@ -5288,12 +5301,17 @@ function _thunk(job, postopt::Bool = true)
else
propagate_returned!(mod)
end
mstr
else
""
end
return (mod, adjoint_name, primal_name, meta.TapeType)
return (mod, adjoint_name, primal_name, meta.TapeType, prepost)
end

const cache = Dict{UInt,CompileResult}()

const autodiff_cache = Dict{Ptr{Cvoid},Tuple{String, String}}()

const cache_lock = ReentrantLock()
@inline function cached_compilation(@nospecialize(job::CompilerJob))::CompileResult
key = hash(job)
Expand All @@ -5305,6 +5323,12 @@ const cache_lock = ReentrantLock()
if obj === nothing
asm = _thunk(job)
obj = _link(job, asm...)
if obj.adjoint isa Ptr{Nothing}
autodiff_cache[obj.adjoint] = (asm[2], asm[5])
end
if obj.primal isa Ptr{Nothing} && asm[3] isa String
autodiff_cache[obj.primal] = (asm[3], asm[5])
end
cache[key] = obj
end
obj
Expand Down Expand Up @@ -5549,7 +5573,22 @@ function thunk_generator(world::UInt, source::LineNumberNode, @nospecialize(FA::
# new_ci.min_world = min_world[]
new_ci.min_world = world
new_ci.max_world = max_world[]
new_ci.edges = Core.MethodInstance[mi]

edges = Core.MethodInstance[mi]

if Mode == API.DEM_ForwardMode
push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.forward)}, world))
Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(EnzymeRules.forward))
else
push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_primal)}, world))
end

push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.inactive)}, world))
push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{Val{0}}, world))
Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(Val(0)))

new_ci.edges = edges

# XXX: setting this edge does not give us proper method invalidation, see
# JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel.
# invoking `code_llvm` also does the necessary codegen, as does calling the
Expand Down
Loading

0 comments on commit 81df5ec

Please sign in to comment.