Skip to content

Commit

Permalink
Add optimization callbacks that fire on a marker function
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Sep 27, 2024
1 parent 5281e86 commit 828ee63
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 2 deletions.
32 changes: 31 additions & 1 deletion src/optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=1)
tm = llvm_machine(job.config.target)

global current_job
global current_job # ScopedValue?
current_job = job

@dispose pb=NewPMPassBuilder() begin
Expand All @@ -14,6 +14,10 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=
register!(pb, LowerKernelStatePass())
register!(pb, CleanupKernelStatePass())

for (name, callback) in PIPELINE_CALLBACKS
register!(pb, CallbackPass(name, callback))
end

add!(pb, NewPMModulePassManager()) do mpm
buildNewPMPipeline!(mpm, job, opt_level)
end
Expand All @@ -24,6 +28,15 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=
return
end

# TODO: Priority heap to provide order between different plugins
# Registry of optimization-pipeline plugins, keyed by plugin name.
const PIPELINE_CALLBACKS = Dict{String, Any}()

"""
    register_plugin!(name::String, plugin)

Store `plugin` in `PIPELINE_CALLBACKS` under `name` and return `plugin`.

Each name may only be registered once; registering a name that is already
present raises an error and leaves the existing entry untouched.
"""
function register_plugin!(name::String, plugin)
    haskey(PIPELINE_CALLBACKS, name) &&
        error("GPUCompiler plugin with name $name is already registered")
    PIPELINE_CALLBACKS[name] = plugin
    return plugin
end

function buildNewPMPipeline!(mpm, @nospecialize(job::CompilerJob), opt_level)
buildEarlySimplificationPipeline(mpm, job, opt_level)
add!(mpm, AlwaysInlinerPass())
Expand All @@ -41,6 +54,9 @@ function buildNewPMPipeline!(mpm, @nospecialize(job::CompilerJob), opt_level)
add!(fpm, WarnMissedTransformationsPass())
end
end
for (name, callback) in PIPELINE_CALLBACKS
add!(mpm, CallbackPass(name, callback))
end
buildIntrinsicLoweringPipeline(mpm, job, opt_level)
buildCleanupPipeline(mpm, job, opt_level)
end
Expand Down Expand Up @@ -423,3 +439,17 @@ function lower_ptls!(mod::LLVM.Module)
return changed
end
# Construct the module pass named "GPULowerPTLS", which runs `lower_ptls!`.
function LowerPTLSPass()
    return NewPMModulePass("GPULowerPTLS", lower_ptls!)
end


"""
    callback_pass!(name, callback, mod::LLVM.Module)

Invoke `callback(job, marker, mod)` if `mod` contains a function called `name`,
where `marker` is that function and `job` is the global `current_job`.

Returns the callback's result, or `false` when `mod` has no function `name`.
"""
function callback_pass!(name, callback::F, mod::LLVM.Module) where F
    # Resolve (and type-check) the active job first, even if the marker is absent.
    job = current_job::CompilerJob

    fns = functions(mod)
    haskey(fns, name) || return false

    return callback(job, fns[name], mod)
end

# Build a module pass named "CallbackPass<name>" that forwards the module it is
# run on to `callback_pass!` together with the captured `name` and `callback`.
function CallbackPass(name, callback)
    run_callback(mod) = callback_pass!(name, callback, mod)
    return NewPMModulePass("CallbackPass<$name>", run_callback)
end
12 changes: 12 additions & 0 deletions test/ptx_tests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@testitem "PTX" setup=[PTX, Helpers] begin

using LLVM
import InteractiveUtils

############################################################################################

Expand Down Expand Up @@ -276,6 +277,17 @@ end
@test "We did not crash!" != ""
end

@testset "Pipeline callbacks" begin
    # Kernel whose only work is a call to the `gpucompiler.mark` marker.
    function kernel(x)
        PTX.mark(x)
        return
    end

    # Host-side reflection: the marker call is expected to show up in the IR.
    host_ir = sprint() do io
        InteractiveUtils.code_llvm(io, kernel, Tuple{Int})
    end
    @test occursin("gpucompiler.mark", host_ir)

    # Reflection through the PTX test compiler: the marker must no longer appear
    # (presumably removed by the plugin registered in the PTX test setup).
    device_ir = sprint() do io
        PTX.code_llvm(io, kernel, Tuple{Int})
    end
    @test !occursin("gpucompiler.mark", device_ir)
end

@testset "exception arguments" begin
function kernel(a)
unsafe_store!(a, trunc(Int, unsafe_load(a)))
Expand Down
24 changes: 23 additions & 1 deletion test/ptx_testsetup.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@testsetup module PTX

using GPUCompiler

import LLVM

# create a PTX-based test compiler, and generate reflection methods for it

Expand All @@ -16,6 +16,28 @@ end
GPUCompiler.kernel_state_type(@nospecialize(job::PTXCompilerJob)) = PTXKernelState
@inline @generated kernel_state() = GPUCompiler.kernel_state_value(PTXKernelState)

# Call the external symbol `gpucompiler.mark` with `x` via `llvmcall`; used by
# the tests as a marker function for the pipeline-callback plugin.
mark(x) = ccall("extern gpucompiler.mark", llvmcall, Nothing, (Int,), x)

"""
    remove_mark!(job, intrinsic, mod::LLVM.Module)

Delete every user of `intrinsic` that has no uses of its own.

Returns `true` if at least one user was deleted, `false` otherwise. `job` and
`mod` are accepted to match the plugin callback signature but are not used.
"""
function remove_mark!(@nospecialize(job), intrinsic, mod::LLVM.Module)
    deleted_any = false

    for use in LLVM.uses(intrinsic)
        user_val = LLVM.user(use)

        # A user whose own value is still used is left in place;
        # the validator will detect this.
        isempty(LLVM.uses(user_val)) || continue

        LLVM.unsafe_delete!(LLVM.parent(user_val), user_val)
        deleted_any = true
    end

    return deleted_any
end

# Register `remove_mark!` as the plugin for the name "gpucompiler.mark", the
# symbol that `mark` calls. Registration happens once, when this module loads.
GPUCompiler.register_plugin!("gpucompiler.mark", remove_mark!)

# a version of the test runtime that has some side effects, loading the kernel state
# (so that we can test if kernel state arguments are appropriately optimized away)
module PTXTestRuntime
Expand Down

0 comments on commit 828ee63

Please sign in to comment.