Skip to content

Commit

Permalink
Test and fix CUDA method replacement (#253)
Browse files Browse the repository at this point in the history

Co-authored-by: Ali Ramadhan <ali.hh.ramadhan@gmail.com>
  • Loading branch information
vchuravy and ali-ramadhan authored May 14, 2021
1 parent 1496f4e commit af1b933
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 13 deletions.
14 changes: 10 additions & 4 deletions lib/CUDAKernels/src/CUDAKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,12 @@ KernelAbstractions.generate_overdubs(@__MODULE__, CUDACtx)
# CUDA specific method rewrites
###

# Under `CUDACtx`, forward `^` directly to the original `^` for the
# (base, exponent) type combinations listed below, rather than letting
# Cassette recurse into the implementation of `^`.
# One inlined pass-through method is generated per signature.
for (B, E) in (
    (Float64, Float64),
    (Float32, Float32),
    (Float64, Int32),
    (Float32, Int32),
    (Union{Float32, Float64}, Int64),
)
    @eval @inline Cassette.overdub(::CUDACtx, ::typeof(^), x::$B, y::$E) = ^(x, y)
end

# libdevice.jl
const cudafuns = (:cos, :cospi, :sin, :sinpi, :tan,
:acos, :asin, :atan,
Expand All @@ -300,12 +306,12 @@ for f in cudafuns
end
end

# `sincos` and complex `exp` overdubs for `CUDACtx`.
# NOTE(review): this listing is a rendered diff without +/- markers. The first
# two definitions (calling `CUDA.sin`/`CUDA.cos` and `CUDA.exp`) look like the
# removed side of the change; the last two are the same signatures forwarding
# to `Base.sin`/`Base.cos` and `Base.exp` instead — confirm against the commit.
@inline Cassette.overdub(::CUDACtx, ::typeof(sincos), x::Union{Float32, Float64}) = (CUDA.sin(x), CUDA.cos(x))
@inline Cassette.overdub(::CUDACtx, ::typeof(exp), x::Union{ComplexF32, ComplexF64}) = CUDA.exp(x)
# Same signatures as above; `sincos` is returned as the tuple (sin(x), cos(x)).
@inline Cassette.overdub(::CUDACtx, ::typeof(sincos), x::Union{Float32, Float64}) = (Base.sin(x), Base.cos(x))
@inline Cassette.overdub(::CUDACtx, ::typeof(exp), x::Union{ComplexF32, ComplexF64}) = Base.exp(x)

# Overdubs for `SpecialFunctions.gamma`, `erf` and `erfc` on Float32/Float64
# under `CUDACtx`.
# `gamma` is mapped to `CUDA.tgamma`.
# NOTE(review): this listing is a rendered diff without +/- markers. The
# `CUDA.erf`/`CUDA.erfc` definitions look like the removed side of the change,
# and the ones forwarding to `SpecialFunctions.erf`/`erfc` the added side —
# confirm against the commit. `gamma` still calls into `CUDA` while `erf`/`erfc`
# no longer do; verify that this asymmetry is intended.
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.gamma), x::Union{Float32, Float64}) = CUDA.tgamma(x)
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erf), x::Union{Float32, Float64}) = CUDA.erf(x)
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erfc), x::Union{Float32, Float64}) = CUDA.erfc(x)
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erf), x::Union{Float32, Float64}) = SpecialFunctions.erf(x)
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erfc), x::Union{Float32, Float64}) = SpecialFunctions.erfc(x)

@static if Base.isbindingresolved(CUDA, :emit_shmem) && Base.isdefined(CUDA, :emit_shmem)
const emit_shmem = CUDA.emit_shmem
Expand Down
3 changes: 3 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ function generate_overdubs(mod, Ctx)

# Pass-through: `Base.literal_pow` is called directly with its original
# arguments instead of being traced under the interpolated context type `$Ctx`.
@inline Cassette.overdub(::$Ctx, ::typeof(Base.literal_pow), f::F, x, p) where F = Base.literal_pow(f, x, p)

# Pass-through for the error-throwing helpers: forward all arguments unchanged
# to the original `Base` functions.
@inline Cassette.overdub(::$Ctx, ::typeof(Base.throw_boundserror), args...) = Base.throw_boundserror(args...)
@inline Cassette.overdub(::$Ctx, ::typeof(Base.Math.throw_exp_domainerror), args...) = Base.Math.throw_exp_domainerror(args...)

function Cassette.overdub(::$Ctx, ::typeof(:), start::T, step::T, stop::T) where T<:Union{Float16,Float32,Float64}
lf = (stop-start)/step
if lf < 0
Expand Down
48 changes: 41 additions & 7 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,65 @@ end
A[1] = B[1]^2
end

# Test kernel: raises the first element of `A` to the power stored in the first
# element of `B` and writes the result back to `A[1]`. The exponent is a runtime
# value (not a literal), so this goes through `^` rather than a literal power.
@kernel function pow(A, B)
A[1] = A[1]^B[1]
end

# Test kernel: stores `Base.Checked.checked_add(a, b)` into the first element
# of `A`.
@kernel function checked(A, a, b)
A[1] = Base.Checked.checked_add(a, b)
end

function compiler_testsuite()
# Return `true` when `stmt` is an `:invoke` expression whose target
# `MethodInstance` belongs to a method named `overdub`; return `false` for
# every other statement. A matching statement is printed with `@show` so the
# offending call is visible in the test log.
function check_for_overdub(stmt)
    # Only expressions can be invokes; literals, symbols, etc. never match.
    stmt isa Expr || return false
    stmt.head == :invoke || return false

    # The first argument of an `:invoke` expression is the method instance.
    target = first(stmt.args)::Core.MethodInstance
    target.def.name === :overdub || return false

    @show stmt
    return true
end

function compiler_testsuite(backend, ArrayT)
kernel = index(CPU(), DynamicSize(), DynamicSize())
iterspace = NDRange{1, StaticSize{(128,)}, StaticSize{(8,)}}();
ctx = KernelAbstractions.mkcontext(kernel, 1, nothing, iterspace, Val(KernelAbstractions.NoDynamicCheck()))
CTX = KernelAbstractions.cassette(kernel)
@test KernelAbstractions.Cassette.overdub(CTX, KernelAbstractions.__index_Global_NTuple, ctx, CartesianIndex(1)) == (1,)

let (CI, rt) = @ka_code_typed literal_pow(CPU())(zeros(Int,1), ndrange=1)
A = ArrayT{Int}(undef, 1)
let (CI, rt) = @ka_code_typed literal_pow(backend())(A, ndrange=1)
# test that there is no invoke of overdub
@test !any(check_for_overdub, CI.code)
end

A = ArrayT{Float64}(undef, 1)
let (CI, rt) = @ka_code_typed square(backend())(A, A, ndrange=1)
# test that there is no invoke of overdub
@test !any(check_for_overdub, CI.code)
end

A = ArrayT{Float64}(undef, 1)
B = ArrayT{Float64}(undef, 1)
let (CI, rt) = @ka_code_typed pow(backend())(A, B, ndrange=1)
# test that there is no invoke of overdub
@test !any(stmt->(stmt isa Expr) && stmt.head == :invoke, CI.code)
@test !any(check_for_overdub, CI.code)
end

let (CI, rt) = @ka_code_typed square(CPU())(zeros(1), zeros(1), ndrange=1)
A = ArrayT{Float64}(undef, 1)
B = ArrayT{Int32}(undef, 1)
let (CI, rt) = @ka_code_typed pow(backend())(A, B, ndrange=1)
# test that there is no invoke of overdub
@test !any(stmt->(stmt isa Expr) && stmt.head == :invoke, CI.code)
@test !any(check_for_overdub, CI.code)
end

if VERSION >= v"1.5"
let (CI, rt) = @ka_code_typed checked(CPU())(zeros(Int,1), 1, 2, ndrange=1)
A = ArrayT{Int}(undef, 1)
let (CI, rt) = @ka_code_typed checked(backend())(A, 1, 2, ndrange=1)
# test that there is no invoke of overdub
@test !any(stmt->(stmt isa Expr) && stmt.head == :invoke, CI.code)
@test !any(check_for_overdub, CI.code)
end
end
end
4 changes: 2 additions & 2 deletions test/testsuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ function testsuite(backend, backend_str, backend_mod, AT, DAT)
end
end

if backend == CPU
if backend_str != "ROCM"
@testset "Compiler" begin
compiler_testsuite()
compiler_testsuite(backend, AT)
end
end

Expand Down

0 comments on commit af1b933

Please sign in to comment.