Mock Enzyme plugin #636
base: 09-26-make_gpuinterpreter_extensible
Warning This pull request is not mergeable via GitHub because a downstack PR is open. Once all requirements are satisfied, merge this PR as a stack on Graphite.
This stack of pull requests is managed by Graphite.
7c4dc5c to e00cfb7
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@                        Coverage Diff                        @@
##   09-26-make_gpuinterpreter_extensible    #636     +/-   ##
===============================================================
- Coverage                         72.29%   61.78%   -10.51%
===============================================================
  Files                                24       24
  Lines                              3353     3274       -79
===============================================================
- Hits                               2424     2023      -401
- Misses                              929     1251      +322
e00cfb7 to 0296364
args = [Core.Compiler.Argument(i) for i in 2:length(stmt.args)] # f, args...
idx = 0

# 0. Enzyme proper: Desugar args
primal_args = args
primal_argtypes = match.spec_types.parameters[2:end]

adjoint_rt = info.rt
adjoint_args = args # TODO
adjoint_argtypes = primal_argtypes

# 1: Since Julia's inliner goes bottom up we need to pretend that we inlined the deferred call
expr = Expr(:foreigncall,
    "extern gpuc.lookup",
    Ptr{Cvoid},
    Core.svec(#=meta=# Any, #=mi=# Any, #=f=# Any, primal_argtypes...), # Must use Any for MethodInstance or ftype
    0,
    QuoteNode(:llvmcall),
    deferred_info.meta,
    case.invoke,
    primal_args...
)
ptr = insert_node!(ir, (idx += 1), NewInstruction(expr, Ptr{Cvoid}))
@aviatesk does this look correct?
`handle_call!` isn’t meant to be overloaded, so I think this approach is preferred:
function Core.Compiler.src_inlining_policy(interp::GPUInterpreter#=or EnzymeInterpreter?=#,
    @nospecialize(src), info::AutodiffCallInfo, stmt_flag::UInt32)
    # Goal:
    # The IR we want to return here is:
    # unpack the args ..
    # ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...)
    # ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid}, args...}, ptr, adjoint_args...)
    ir = Core.Compiler.IRCode() # contains a placeholder
    ...
    return ir
end
By overloading `src_inlining_policy` (or `inlining_policy` in older versions), we can apply this custom inlining to const-propped call sites and semi-concrete interpreted call sites as well.
Haha, then I have misunderstood the comment in:
Ah, I just realized that overloading `src_inlining_policy` wouldn't be enough. We need to overload `retrieve_ir_for_inlining` too, but it doesn't take `info::CallInfo`, so maybe we need to tweak the interface.
But I believe this approach (overloading `inlining_policy`) works at least for pre-1.11.
Haha, then I have misunderstood the comment in:
It seems like I’ve ended up betraying my past self.
Yeah I think you once told me to extend handle_call
052a118 to 4ef6019
3cfc5e1 to 1b8d203
funcT = LLVM.called_type(call)
funcT = LLVM.FunctionType(LLVM.return_type(funcT), LLVM.parameters(funcT)[3:end])
direct_call = LLVM.call!(builder, funcT, target, ops[3:end - 1]) # why is the -1 necessary
the last op is the called value
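To make the reply concrete: LLVM keeps the arguments of a call instruction first and the callee as the final operand, so the slice has to stop at `end - 1`. A minimal plain-Julia sketch follows, with a vector of symbols standing in for the operands of the call; the operand names are hypothetical and no LLVM.jl API is used here.

```julia
# Hypothetical stand-in for the operands of the lookup call.
# The two leading entries mirror the meta and MethodInstance slots,
# and the trailing entry mirrors the called value that LLVM stores last.
ops = [:meta, :mi, :f, :x, :callee]

# Drop the two bookkeeping operands in front and the callee at the end,
# leaving only the arguments that the direct call should receive.
direct_args = ops[3:end - 1]
```

Slicing with `ops[3:end]` instead would pass the callee itself as an extra argument.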
end

function mock_enzyme!(@nospecialize(job), intrinsic, mod::LLVM.Module)
I think presently Enzyme does more than that as well. To a rough approximation, it does the following as its entire compilation step:
- "Before anything else happens"
  - Set each llvmf to know about its worldage, methodinstance, and return type: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6703
  - Make things inline ready [e.g. remove some tbaa which is broken]: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6705
  - Rewrite some nvvm and related intrinsics: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6709
  - Mark various type unstable calls as inactive and change inttoptr'd ccalls into calls by name [storing the actual int value to later restore]: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6714
  - Replace unhandled blas calls with fallback: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6755
  - Annotate types and activities: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6894
  - Mark custom rules and related as noinline: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7035
  - Lower calling convention of function being differentiated: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7465
- Optimization pipeline: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7514
  Currently we use a modified optimization pipeline that also adds new passes which we found to be critical for performance (namely the new jl_inst_simplify pass, among others, for interprocedural dead arg elim).
- AD
  - First we run a Julia analysis pass if the fn differentiated was a closure and it was requested that we error if it is written to: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7659
  - Upgrading some memcpy's to load/store: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7799
  - Actually generating the derivatives: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7801
  - Inverse of the preserve nvvm pass above: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7897
  - Restoring the actual inttoptr => function name from above: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8074
  - Other immediate post Enzyme passes: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8085 and https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8087 and https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8105 and https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8129
  - Post Enzyme optimization (https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8131):
    - This includes running new passes to fix garbage collection/etc but presumably can just be scheduled
(; fargs, argtypes) = arginfo

@assert f === autodiff
@assert f === autodiff
end

function abstract_call_known(meta::Nothing, interp::GPUInterpreter, f::typeof(autodiff),
`abstract_call_known` with this signature would never be called from `Core.Compiler`, so this overload would do nothing?
I have a double extension problem. GPUCompiler provides a `GPUInterpreter`.
Both GPUCompiler/CUDA and Enzyme want to modify the rules being applied.
But when we are applying Enzyme to CUDA code we must "inherit" the rules from CUDA. Up to now Enzyme had an `EnzymeInterpreter`, but I would like to get rid of that.
But Enzyme rules shouldn't apply to CUDA code by default. However, I also need to teach GPUCompiler about `autodiff` in an extensible manner, such that:
function kernel(args...)
    autodiff(f, ....)
end
@cuda kernel(args...)
works.
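One way to picture the double extension problem is dispatch on a plugin tag carried by a single interpreter type. The sketch below is plain Julia and only illustrates the dispatch shape. `MockInterpreter`, `AbstractPluginMeta`, `handle_call` and `dispatch_call` are hypothetical names, the `MockEnzymeMeta` struct is a local stand-in, and `autodiff` is a placeholder that differentiates nothing. None of this is GPUCompiler, CUDA or Enzyme API.

```julia
# Hypothetical stand-ins, not GPUCompiler or Enzyme API.
abstract type AbstractPluginMeta end
struct MockEnzymeMeta <: AbstractPluginMeta end

# One interpreter type; `meta === nothing` means plain GPU compilation.
struct MockInterpreter{M<:Union{Nothing,AbstractPluginMeta}}
    meta::M
end

# Placeholder so that `typeof(autodiff)` exists; it just calls `f`.
autodiff(f, args...) = f(args...)

# Default: every meta, including plugins, inherits the GPU rules.
handle_call(meta, interp::MockInterpreter, f, args...) = :gpu_rules

# The base interpreter only needs to recognize `autodiff` and defer it.
handle_call(::Nothing, interp::MockInterpreter, ::typeof(autodiff), args...) = :deferred_autodiff

# A plugin rule applies only when the plugin's meta is active.
handle_call(::MockEnzymeMeta, interp::MockInterpreter, ::typeof(sin), args...) = :enzyme_rule

# Entry point: forward the interpreter's meta as the first dispatch argument.
dispatch_call(interp::MockInterpreter, f, args...) = handle_call(interp.meta, interp, f, args...)

gpu = MockInterpreter(nothing)
enz = MockInterpreter(MockEnzymeMeta())
```

Under these assumptions, `dispatch_call(gpu, sin, 1.0)` falls through to the GPU rules, `dispatch_call(gpu, autodiff, sin, 1.0)` is deferred, and `dispatch_call(enz, sin, 1.0)` picks up the plugin rule while `dispatch_call(enz, cos, 1.0)` still inherits the GPU rules.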
4ef6019 to 1638cc2
1b8d203 to d906034
Make sure that the new infrastructure can handle the complicated song and dance Enzyme needs to do.