Skip to content

Commit

Permalink
Try #163:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Jan 6, 2021
2 parents 4ec8066 + 4a3acad commit 3784d67
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 4 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
Expand Down
100 changes: 99 additions & 1 deletion src/KernelAbstractions.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
module KernelAbstractions

export @kernel
export @Const, @localmem, @private, @uniform, @synchronize, @index, groupsize, @print
export @Const, @localmem, @private, @uniform, @synchronize, @index, groupsize, @print, @printf
export Device, GPU, CPU, CUDADevice, Event, MultiEvent, NoneEvent
export async_copy!


using MacroTools
using Printf
using StaticArrays
using Cassette
using Adapt
Expand All @@ -28,6 +29,7 @@ and then invoked on the arguments.
- [`@uniform`](@ref)
- [`@synchronize`](@ref)
- [`@print`](@ref)
- [`@printf`](@ref)
# Example:
Expand Down Expand Up @@ -236,6 +238,32 @@ macro print(items...)
end
end

# When a function with a variable-length argument list is called, the variable
# arguments are passed using C's old "default argument promotions". These say that
# types char and short int are automatically promoted to int, and type float is
# automatically promoted to double. Therefore, varargs functions will never receive
# arguments of type char, short int, or float.

"""
    promote_c_argument(arg)

Apply C's default argument promotions to `arg` before it is handed to a
printf-style varargs call.

- `Float32` is widened to `Cdouble`.
- 8- and 16-bit integers, signed or unsigned, are widened to `Cint`; every value of
  these types is representable in a C `int`.
- Anything else is returned unchanged.

The integer method is spelled with fixed-width types rather than `Cchar`/`Cshort`
because `Cchar` is an alias whose signedness depends on the platform.
"""
promote_c_argument(arg) = arg
promote_c_argument(arg::Cfloat) = Cdouble(arg)
promote_c_argument(arg::Union{Int8,UInt8,Int16,UInt16}) = Cint(arg)

"""
    @printf(fmt::String, args...)

Unified formatted print statement.

`fmt` must be a string literal; it is lifted into the type domain as
`Val(Symbol(fmt))`. Every element of `args` is passed through `promote_c_argument`
before being forwarded to `__printf`.

# Platform differences
- `GPU`: the call is forwarded to `CUDA._cuprintf`
- `CPU`: the call expands to `Printf.@printf(fmt, args...)`
"""
macro printf(fmt::String, args...)
    fmt_val = Val(Symbol(fmt))

    # Escape each user argument so it is evaluated in the caller's scope, then wrap
    # it in the C argument promotion.
    promoted = [:(promote_c_argument($(esc(arg)))) for arg in args]

    return :(__printf($fmt_val, $(promoted...)))
end

"""
@index
Expand Down Expand Up @@ -452,6 +480,76 @@ end
end
end

# Results in "Conversion of boxed type String is not allowed"
# @generated function __printf(::Val{fmt}, argspec...) where {fmt}
# arg_exprs = [:( argspec[$i] ) for i in 1:length(argspec)]
# arg_types = [argspec...]

# T_void = LLVM.VoidType(LLVM.Interop.JuliaContext())
# T_int32 = LLVM.Int32Type(LLVM.Interop.JuliaContext())
# T_pint8 = LLVM.PointerType(LLVM.Int8Type(LLVM.Interop.JuliaContext()))

# # create functions
# param_types = LLVMType[convert.(LLVMType, arg_types)...]
# llvm_f, _ = create_function(T_int32, param_types)
# mod = LLVM.parent(llvm_f)
# sfmt = String(fmt)
# # generate IR
# Builder(LLVM.Interop.JuliaContext()) do builder
# entry = BasicBlock(llvm_f, "entry", LLVM.Interop.JuliaContext())
# position!(builder, entry)

# str = globalstring_ptr!(builder, sfmt)

# # construct and fill args buffer
# if isempty(argspec)
# buffer = LLVM.PointerNull(T_pint8)
# else
# argtypes = LLVM.StructType("printf_args", LLVM.Interop.JuliaContext())
# elements!(argtypes, param_types)

# args = alloca!(builder, argtypes)
# for (i, param) in enumerate(parameters(llvm_f))
# p = struct_gep!(builder, args, i-1)
# store!(builder, param, p)
# end

# buffer = bitcast!(builder, args, T_pint8)
# end

# # invoke vprintf and return
# vprintf_typ = LLVM.FunctionType(T_int32, [T_pint8, T_pint8])
# vprintf = LLVM.Function(mod, "vprintf", vprintf_typ)
# chars = call!(builder, vprintf, [str, buffer])

# ret!(builder, chars)
# end

# arg_tuple = Expr(:tuple, arg_exprs...)
# call_function(llvm_f, Int32, Tuple{arg_types...}, arg_tuple)
# end

# Results in "InvalidIRError: compiling kernel
# gpu_kernel_printf(... Reason: unsupported dynamic
# function invocation"
# Generic implementation behind `@printf`: rebuilds a `Printf.@printf` call from the
# format symbol carried in `Val{fmt}` and the trailing arguments.
#
# Inside the generator `items` is the tuple of argument *types*. An argument whose
# type is a `Val` is replaced by its (quoted) type parameter, so the value is baked
# into the generated code; every other argument is read from `items` at run time.
@generated function __printf(::Val{fmt}, items...) where {fmt}
    args = Any[
        T <: Val ? QuoteNode(T.parameters[1]) : :(items[$i])
        for (i, T) in enumerate(items)
    ]
    # `Printf.@printf` needs the format as a string literal at expansion time.
    sfmt = String(fmt)
    return quote
        Printf.@printf($sfmt, $(args...))
    end
end

###
# Backends/Implementation
###
Expand Down
4 changes: 4 additions & 0 deletions src/backends/cpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,10 @@ end
__print(items...)
end

# On the CPU context `__printf` is not traced any further: the call is forwarded
# unchanged to the plain `__printf` implementation.
@inline Cassette.overdub(ctx::CPUCtx, ::typeof(__printf), fmt, items...) =
    __printf(fmt, items...)

generate_overdubs(CPUCtx)

# Don't recurse into these functions
Expand Down
8 changes: 8 additions & 0 deletions src/backends/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,14 @@ end
CUDA._cuprint(args...)
end

# CUDA context, format given as a bare value: wrap it in `Val` and hand the call to
# `CUDA._cuprintf`. A format that already is a `Val` dispatches to the more specific
# `::Val{fmt}` method instead.
@inline Cassette.overdub(ctx::CUDACtx, ::typeof(__printf), fmt, args...) =
    CUDA._cuprintf(Val(fmt), args...)

# CUDA context, format already lifted to the type domain: rebuild the `Val` from the
# type parameter and forward the call to `CUDA._cuprintf`.
@inline Cassette.overdub(ctx::CUDACtx, ::typeof(__printf), ::Val{fmt}, args...) where {fmt} =
    CUDA._cuprintf(Val(fmt), args...)

###
# GPU implementation of const memory
###
Expand Down
32 changes: 29 additions & 3 deletions test/print_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,51 @@ if has_cuda_gpu()
CUDA.allowscalar(false)
end

# Field-less parametric type; the printf kernel only prints its name.
struct Foo{A,B}
end

# Name printed for `Foo` itself or for any of its parameterizations.
get_name(::Type{<:Foo}) = "Foo"

# Kernel that prints a greeting containing the value of `@index(Global)` via `@print`.
@kernel function kernel_print()
    idx = @index(Global)
    @print("Hello from thread ", idx, "!\n")
end

# Kernel exercising `@printf` with a string argument, the value of `@index(Global)`,
# and a type name passed first as a literal and then through `get_name`.
@kernel function kernel_printf()
    idx = @index(Global)
    # Disabled variants, kept for reference:
    # @printf("Hello printf %s thread %d! type = %s.\n", "from", idx, nameof(Foo))
    # @print("Hello printf from thread ", idx, "!\n")
    # @printf("Hello printf %s thread %d! type = %s.\n", "from", idx, string(nameof(Foo)))
    @printf("Hello printf %s thread %d! type = %s.\n", "from", idx, "Foo")
    @printf("Hello printf %s thread %d! type = %s.\n", "from", idx, get_name(Foo))
end

# Instantiate `kernel_print` for `backend` (second constructor argument 4, presumably
# the workgroup size; confirm against the kernel constructor) and launch it once over
# an ndrange of 4. The launch result is returned so the caller can `wait` on it; no
# launch is left behind without something waiting on it.
function test_print(backend)
    kernel = kernel_print(backend, 4)
    return kernel(ndrange=(4,))
end

# Instantiate `kernel_printf` for `backend` and launch it over an ndrange of 4,
# returning whatever the launch returns so the caller can `wait` on it.
function test_printf(backend)
    launch = kernel_printf(backend, 4)
    return launch(ndrange=(4,))
end

# Smoke tests: each step only has to complete without throwing, hence `@test true`.
@testset "print test" begin
    # CPU backend: plain and formatted print kernels.
    wait(test_print(CPU()))
    @test true

    wait(test_printf(CPU()))
    @test true

    # Same two kernels on the CUDA backend, when a GPU is available.
    if has_cuda_gpu()
        wait(test_print(CUDADevice()))
        @test true
        wait(test_printf(CUDADevice()))
        @test true
    end

    # The macros are also invoked directly, outside of any `@kernel` body.
    @print("Why this should work")
    @test true

    @printf("Why this should work")
    @test true
end

0 comments on commit 3784d67

Please sign in to comment.