Skip to content

Commit

Permalink
Merge #251
Browse files Browse the repository at this point in the history
251: Backport #249 r=vchuravy a=vchuravy

Backport #249.

@ali-ramadhan I noticed that you only tested `CPU()` execution,
but the changes are limited to CUDAKernels.





Co-authored-by: bors[bot] <26634292+bors[bot]@users.noreply.github.com>
Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com>
  • Loading branch information
bors[bot] and vchuravy authored May 15, 2021
2 parents 2cff888 + d049410 commit 1e51126
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 17 deletions.
20 changes: 10 additions & 10 deletions lib/CUDAKernels/src/CUDAKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,11 @@ KernelAbstractions.generate_overdubs(@__MODULE__, CUDACtx)
# CUDA specific method rewrites
###

# Overdubs of `^` under CUDACtx, one method per (base, exponent) type combination:
# Float64/Float64, Float32/Float32, Float64/Int32, Float32/Int32 and
# Union{Float32, Float64}/Int64.
# NOTE(review): each signature is listed twice below, first forwarding to `CUDA.pow`
# and then calling `^` directly. This looks like a rendered diff, so presumably the
# `CUDA.pow` group is the removed version -- confirm against the repository.
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float64, y::Float64) = CUDA.pow(x, y)
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float32, y::Float32) = CUDA.pow(x, y)
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float64, y::Int32) = CUDA.pow(x, y)
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float32, y::Int32) = CUDA.pow(x, y)
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Union{Float32, Float64}, y::Int64) = CUDA.pow(x, y)
# Same five signatures again, this time forwarding to a plain `^(x, y)` call.
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float64, y::Float64) = ^(x, y)
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float32, y::Float32) = ^(x, y)
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float64, y::Int32) = ^(x, y)
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float32, y::Int32) = ^(x, y)
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Union{Float32, Float64}, y::Int64) = ^(x, y)

# libdevice.jl
const cudafuns = (:cos, :cospi, :sin, :sinpi, :tan,
Expand All @@ -300,16 +300,16 @@ const cudafuns = (:cos, :cospi, :sin, :sinpi, :tan,
# For every symbol `f` in `cudafuns`, generate (via `@eval`) a CUDACtx overdub method
# that intercepts `Base.f` when called with a single Float32 or Float64 argument.
for f in cudafuns
@eval function Cassette.overdub(ctx::CUDACtx, ::typeof(Base.$f), x::Union{Float32, Float64})
# Ask the compiler to inline the generated method.
@Base._inline_meta
# NOTE(review): two consecutive `return` lines are shown; as written only the first
# could ever execute. Presumably this is a rendered diff in which `CUDA.$f` is the
# removed line and `Base.$f` the added one -- confirm against the repository.
return CUDA.$f(x)
return Base.$f(x)
end
end

# CUDACtx overdubs for `sincos` (returned as a `(sin, cos)` tuple) and for complex `exp`.
# NOTE(review): each of these signatures appears twice, once forwarding to `CUDA.*`
# and once to `Base.*`. Presumably a rendered diff showing the removed lines first --
# confirm against the repository.
@inline Cassette.overdub(::CUDACtx, ::typeof(sincos), x::Union{Float32, Float64}) = (CUDA.sin(x), CUDA.cos(x))
@inline Cassette.overdub(::CUDACtx, ::typeof(exp), x::Union{ComplexF32, ComplexF64}) = CUDA.exp(x)
@inline Cassette.overdub(::CUDACtx, ::typeof(sincos), x::Union{Float32, Float64}) = (Base.sin(x), Base.cos(x))
@inline Cassette.overdub(::CUDACtx, ::typeof(exp), x::Union{ComplexF32, ComplexF64}) = Base.exp(x)

# SpecialFunctions overdubs. `gamma` is listed once (forwarding to `CUDA.tgamma`);
# `erf` and `erfc` are each listed twice, forwarding first to `CUDA.*` and then to
# `SpecialFunctions.*` (same NOTE(review) as above applies).
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.gamma), x::Union{Float32, Float64}) = CUDA.tgamma(x)
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erf), x::Union{Float32, Float64}) = CUDA.erf(x)
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erfc), x::Union{Float32, Float64}) = CUDA.erfc(x)
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erf), x::Union{Float32, Float64}) = SpecialFunctions.erf(x)
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erfc), x::Union{Float32, Float64}) = SpecialFunctions.erfc(x)

@static if Base.isbindingresolved(CUDA, :emit_shmem) && Base.isdefined(CUDA, :emit_shmem)
const emit_shmem = CUDA.emit_shmem
Expand Down
3 changes: 3 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ function generate_overdubs(mod, Ctx)

@inline Cassette.overdub(::$Ctx, ::typeof(Base.literal_pow), f::F, x, p) where F = Base.literal_pow(f, x, p)

@inline Cassette.overdub(::$Ctx, ::typeof(Base.throw_boundserror), args...) = Base.throw_boundserror(args...)
@inline Cassette.overdub(::$Ctx, ::typeof(Base.Math.throw_exp_domainerror), args...) = Base.Math.throw_exp_domainerror(args...)

function Cassette.overdub(::$Ctx, ::typeof(:), start::T, step::T, stop::T) where T<:Union{Float16,Float32,Float64}
lf = (stop-start)/step
if lf < 0
Expand Down
53 changes: 48 additions & 5 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,69 @@ end
A[1] = 2^11
end

# Test kernel: stores `B[1]` raised to the literal exponent `2` into `A[1]`.
# `compiler_testsuite` inspects this kernel's typed code for `invoke`s of `overdub`,
# so the literal-exponent form of the expression is deliberate.
@kernel function square(A, B)
A[1] = B[1]^2
end

# Test kernel: replaces `A[1]` with `A[1]` raised to the runtime exponent `B[1]`.
# `compiler_testsuite` compiles it with a Float64 `A` and both Float64 and Int32 `B`.
@kernel function pow(A, B)
A[1] = A[1]^B[1]
end

# Test kernel: stores the result of `Base.Checked.checked_add(a, b)` (the
# overflow-checking addition from Base) into `A[1]`.
# `compiler_testsuite` only compiles it when `VERSION >= v"1.5"`.
@kernel function checked(A, a, b)
A[1] = Base.Checked.checked_add(a, b)
end

function compiler_testsuite()
# Predicate over typed-IR statements: returns `true` only when `stmt` is an
# `:invoke` expression whose first argument is a `Core.MethodInstance` belonging to
# a method named `overdub`. A matching statement is echoed with `@show` so the
# offending call is visible in the test log. Everything else yields `false`.
function check_for_overdub(stmt)
    # Anything that is not an expression cannot be an invoke.
    stmt isa Expr || return false
    # Only statically resolved calls (`:invoke`) are of interest.
    stmt.head == :invoke || return false

    # The first argument of an `:invoke` is asserted to be the target method instance.
    target = first(stmt.args)::Core.MethodInstance
    target.def.name === :overdub || return false

    @show stmt
    return true
end

# Compiler test suite. `backend` is called as `backend()` to build the device each
# kernel is instantiated for, and `ArrayT` is the array type used to allocate the
# kernel arguments. Each case fetches typed code with `@ka_code_typed` and asserts
# on its statements.
# NOTE(review): this listing interleaves two versions of several lines (pairs of
# consecutive `let` headers and pairs of consecutive `@test` lines), so the
# `let`/`end` pairs do not balance as shown. Presumably it is a rendered diff --
# confirm against the repository before treating it as compilable source.
function compiler_testsuite(backend, ArrayT)
# Build a context by hand for the CPU `index` kernel over a static 128-element
# range with work-groups of 8 and no dynamic bounds check.
kernel = index(CPU(), DynamicSize(), DynamicSize())
iterspace = NDRange{1, StaticSize{(128,)}, StaticSize{(8,)}}();
ctx = KernelAbstractions.mkcontext(kernel, 1, nothing, iterspace, Val(KernelAbstractions.NoDynamicCheck()))

# Overdubbing the global-index helper with CartesianIndex(1) must yield (1,).
@test KernelAbstractions.Cassette.overdub(ctx, KernelAbstractions.__index_Global_NTuple, CartesianIndex(1)) == (1,)

# Case: `literal_pow` kernel on an Int array.
let (CI, rt) = @ka_code_typed literal_pow(CPU())(zeros(Int,1), ndrange=1)
A = ArrayT{Int}(undef, 1)
let (CI, rt) = @ka_code_typed literal_pow(backend())(A, ndrange=1)
# test that there is no invoke of overdub
@test !any(check_for_overdub, CI.code)
end

# Case: `square` kernel (literal exponent) on a Float64 array used as both arguments.
A = ArrayT{Float64}(undef, 1)
let (CI, rt) = @ka_code_typed square(backend())(A, A, ndrange=1)
# test that there is no invoke of overdub
@test !any(check_for_overdub, CI.code)
end

# Case: `pow` kernel with Float64 base and Float64 exponent.
A = ArrayT{Float64}(undef, 1)
B = ArrayT{Float64}(undef, 1)
let (CI, rt) = @ka_code_typed pow(backend())(A, B, ndrange=1)
# test that there is no invoke of overdub
@test !any(check_for_overdub, CI.code)
end

# Case: `pow` kernel with Float64 base and Int32 exponent.
A = ArrayT{Float64}(undef, 1)
B = ArrayT{Int32}(undef, 1)
let (CI, rt) = @ka_code_typed pow(backend())(A, B, ndrange=1)
# test that there is no invoke of overdub
@test !any(stmt->(stmt isa Expr) && stmt.head == :invoke, CI.code)
@test !any(check_for_overdub, CI.code)
end

# Case: `checked` kernel, only exercised on Julia 1.5 or newer.
if VERSION >= v"1.5"
let (CI, rt) = @ka_code_typed checked(CPU())(zeros(Int,1), 1, 2, ndrange=1)
A = ArrayT{Int}(undef, 1)
let (CI, rt) = @ka_code_typed checked(backend())(A, 1, 2, ndrange=1)
# test that there is no invoke of overdub
@test !any(stmt->(stmt isa Expr) && stmt.head == :invoke, CI.code)
@test !any(check_for_overdub, CI.code)
end
end
end
4 changes: 2 additions & 2 deletions test/testsuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ function testsuite(backend, backend_str, backend_mod, AT, DAT)
end
end

if backend == CPU
if backend_str != "ROCM"
@testset "Compiler" begin
compiler_testsuite()
compiler_testsuite(backend, AT)
end
end

Expand Down

0 comments on commit 1e51126

Please sign in to comment.