Skip to content

Commit

Permalink
Try #253:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored May 14, 2021
2 parents 1496f4e + 8b22893 commit c44e28a
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 10 deletions.
14 changes: 10 additions & 4 deletions lib/CUDAKernels/src/CUDAKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,12 @@ KernelAbstractions.generate_overdubs(@__MODULE__, CUDACtx)
# CUDA specific method rewrites
###

@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float64, y::Float64) = ^(x, y)
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float32, y::Float32) = ^(x, y)
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float64, y::Int32) = ^(x, y)
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float32, y::Int32) = ^(x, y)
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Union{Float32, Float64}, y::Int64) = ^(x, y)

# libdevice.jl
const cudafuns = (:cos, :cospi, :sin, :sinpi, :tan,
:acos, :asin, :atan,
Expand All @@ -300,12 +306,12 @@ for f in cudafuns
end
end

@inline Cassette.overdub(::CUDACtx, ::typeof(sincos), x::Union{Float32, Float64}) = (CUDA.sin(x), CUDA.cos(x))
@inline Cassette.overdub(::CUDACtx, ::typeof(exp), x::Union{ComplexF32, ComplexF64}) = CUDA.exp(x)
@inline Cassette.overdub(::CUDACtx, ::typeof(sincos), x::Union{Float32, Float64}) = (Base.sin(x), Base.cos(x))
@inline Cassette.overdub(::CUDACtx, ::typeof(exp), x::Union{ComplexF32, ComplexF64}) = Base.exp(x)

@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.gamma), x::Union{Float32, Float64}) = CUDA.tgamma(x)
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erf), x::Union{Float32, Float64}) = CUDA.erf(x)
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erfc), x::Union{Float32, Float64}) = CUDA.erfc(x)
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erf), x::Union{Float32, Float64}) = SpecialFunctions.erf(x)
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erfc), x::Union{Float32, Float64}) = SpecialFunctions.erfc(x)

@static if Base.isbindingresolved(CUDA, :emit_shmem) && Base.isdefined(CUDA, :emit_shmem)
const emit_shmem = CUDA.emit_shmem
Expand Down
11 changes: 7 additions & 4 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,28 @@ end
A[1] = Base.Checked.checked_add(a, b)
end

function compiler_testsuite()
function compiler_testsuite(backend, ArrayT)
kernel = index(CPU(), DynamicSize(), DynamicSize())
iterspace = NDRange{1, StaticSize{(128,)}, StaticSize{(8,)}}();
ctx = KernelAbstractions.mkcontext(kernel, 1, nothing, iterspace, Val(KernelAbstractions.NoDynamicCheck()))
CTX = KernelAbstractions.cassette(kernel)
@test KernelAbstractions.Cassette.overdub(CTX, KernelAbstractions.__index_Global_NTuple, ctx, CartesianIndex(1)) == (1,)

let (CI, rt) = @ka_code_typed literal_pow(CPU())(zeros(Int,1), ndrange=1)
A = ArrayT{Int}(undef, 1)
let (CI, rt) = @ka_code_typed literal_pow(backend())(A, ndrange=1)
# test that there is no invoke of overdub
@test !any(stmt->(stmt isa Expr) && stmt.head == :invoke, CI.code)
end

let (CI, rt) = @ka_code_typed square(CPU())(zeros(1), zeros(1), ndrange=1)
A = ArrayT{Float64}(undef, 1)
let (CI, rt) = @ka_code_typed square(backend())(A, A, ndrange=1)
# test that there is no invoke of overdub
@test !any(stmt->(stmt isa Expr) && stmt.head == :invoke, CI.code)
end

if VERSION >= v"1.5"
let (CI, rt) = @ka_code_typed checked(CPU())(zeros(Int,1), 1, 2, ndrange=1)
A = ArrayT{Int}(undef, 1)
let (CI, rt) = @ka_code_typed checked(backend())(A, 1, 2, ndrange=1)
# test that there is no invoke of overdub
@test !any(stmt->(stmt isa Expr) && stmt.head == :invoke, CI.code)
end
Expand Down
4 changes: 2 additions & 2 deletions test/testsuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ function testsuite(backend, backend_str, backend_mod, AT, DAT)
end
end

if backend == CPU
if backend_str != "ROCM"
@testset "Compiler" begin
compiler_testsuite()
compiler_testsuite(backend, AT)
end
end

Expand Down

0 comments on commit c44e28a

Please sign in to comment.