Inference using NN::Chain inside a GPU kernel #720

Closed
xkykai opened this issue Jun 21, 2024 · 4 comments · Fixed by #723

Comments
@xkykai commented Jun 21, 2024

Hello, with @simone-silvestri, we are trying to compute the output of a neural network inside a GPU kernel. This requires passing the NN to the kernel as an argument, but we are running into an issue when adapting the NN to the GPU.

Specifically, we are writing a kernel with KernelAbstractions.jl that takes the NN as an input and evaluates its output pointwise inside the kernel. When we try to launch the kernel, we get the following error:

Argument 5 to your kernel function is of type NNFluxClosure{NN{Chain{@NamedTuple{layer_1::Dense{true, typeof(relu), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(relu), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, ComponentVector{Float32, CuDeviceVector{Float32, 1}, Tuple{Axis{(layer_1 = ViewAxis(1:1536, Axis(weight = ViewAxis(1:1408, ShapedAxis((128, 11))), bias = ViewAxis(1409:1536, ShapedAxis((128, 1))))), layer_2 = ViewAxis(1537:18048, Axis(weight = ViewAxis(1:16384, ShapedAxis((128, 128))), bias = ViewAxis(16385:16512, ShapedAxis((128, 1))))), layer_3 = ViewAxis(18049:18177, Axis(weight = ViewAxis(1:128, ShapedAxis((1, 128))), bias = ViewAxis(129:129, ShapedAxis((1, 1))))))}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, @NamedTuple{∂T∂z::ZeroMeanUnitVarianceScaling{Float64}, ∂S∂z::ZeroMeanUnitVarianceScaling{Float64}, ∂ρ∂z::ZeroMeanUnitVarianceScaling{Float64}, f::ZeroMeanUnitVarianceScaling{Float64}, wb::ZeroMeanUnitVarianceScaling{Float64}, wT::ZeroMeanUnitVarianceScaling{Float64}, wS::ZeroMeanUnitVarianceScaling{Float64}}}, which is not isbits:
  .wT is of type NN{Chain{@NamedTuple{layer_1::Dense{true, typeof(relu), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(relu), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, ComponentVector{Float32, CuDeviceVector{Float32, 1}, Tuple{Axis{(layer_1 = ViewAxis(1:1536, Axis(weight = ViewAxis(1:1408, ShapedAxis((128, 11))), bias = ViewAxis(1409:1536, ShapedAxis((128, 1))))), layer_2 = ViewAxis(1537:18048, Axis(weight = ViewAxis(1:16384, ShapedAxis((128, 128))), bias = ViewAxis(16385:16512, ShapedAxis((128, 1))))), layer_3 = ViewAxis(18049:18177, Axis(weight = ViewAxis(1:128, ShapedAxis((1, 128))), bias = ViewAxis(129:129, ShapedAxis((1, 1))))))}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}} which is not isbits.
    .model is of type Chain{@NamedTuple{layer_1::Dense{true, typeof(relu), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(relu), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}} which is not isbits.
      .name is of type Union{Nothing, String, Symbol} which is not isbits.
  .wS is of type NN{Chain{@NamedTuple{layer_1::Dense{true, typeof(relu), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(relu), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, ComponentVector{Float32, CuDeviceVector{Float32, 1}, Tuple{Axis{(layer_1 = ViewAxis(1:1536, Axis(weight = ViewAxis(1:1408, ShapedAxis((128, 11))), bias = ViewAxis(1409:1536, ShapedAxis((128, 1))))), layer_2 = ViewAxis(1537:18048, Axis(weight = ViewAxis(1:16384, ShapedAxis((128, 128))), bias = ViewAxis(16385:16512, ShapedAxis((128, 1))))), layer_3 = ViewAxis(18049:18177, Axis(weight = ViewAxis(1:128, ShapedAxis((1, 128))), bias = ViewAxis(129:129, ShapedAxis((1, 1))))))}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}} which is not isbits.
    .model is of type Chain{@NamedTuple{layer_1::Dense{true, typeof(relu), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(relu), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}} which is not isbits.
      .name is of type Union{Nothing, String, Symbol} which is not isbits.

We suspect this issue stems from the fact that NAME_TYPE (the type of model.name) is forced to be Union{Nothing, String, Symbol}, and since Symbol is not an isbits type, the NN cannot be passed to the GPU.
Does anyone have a suggestion for how to solve this problem?
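
(For illustration, a minimal sketch with a made-up type, not Lux internals: a struct carrying such a Union-typed name field fails the isbits check because Symbol and String are heap-allocated.)

struct WithName                     # illustrative stand-in for a layer carrying a name
    name::Union{Nothing, String, Symbol}
end

isbitstype(WithName)                # false: the Union admits non-isbits Symbol/String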

@avik-pal (Member) commented

That's an interesting use case I hadn't thought of. I can add a concrete type for name if that solves the issue. Can you give me a script that I can use to test?
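
(A hedged sketch of that idea, again with an illustrative type rather than Lux internals: making the name's type a parameter lets name = nothing keep the whole struct isbits.)

struct ParamName{N}                 # illustrative: the name's type becomes a parameter
    name::N
end

isbitstype(typeof(ParamName(nothing)))   # true: Nothing is an isbits singleton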

@xkykai (Author) commented Jun 23, 2024

Here's a simple test script:

using Lux, LuxCUDA, ComponentArrays, Random, CUDA
using KernelAbstractions

dev = gpu_device()
rng = Random.default_rng(123)

nn = Chain(Dense(4, 4, relu), Dense(4, 1))
ps, st = Lux.setup(rng, nn)

# Move the parameters and states to the GPU and promote to Float64
ps = ps |> ComponentArray |> dev .|> Float64
st = st |> dev

input = randn(4, 100) |> dev .|> Float64

# A forward pass outside the kernel works fine
outp = nn(input, ps, st)

# Evaluate the NN pointwise, one input column per work item, inside the kernel
@kernel function nn_pass(input, output, nn, ps, st)
    i = @index(Global, Linear)
    output[i] = nn(input[:, i], ps, st)[1][1]
end

output = zeros(100) |> dev .|> Float64

loop! = nn_pass(CUDA.CUDABackend(), 100, 100)
loop!(input, output, nn, ps, st)

which gives the following error:

ERROR: InvalidIRError: compiling MethodInstance for gpu_nn_pass(::KernelAbstractions.CompilerMetadata{…}, ::CuDeviceMatrix{…}, ::CuDeviceVector{…}, ::Chain{…}, ::ComponentVector{…}, ::@NamedTuple{}) resulted in invalid LLVM IR
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Stacktrace:
  [1] __matmuladd
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:6
  [2] __fused_dense_bias_activation_impl
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:26
  [3] fused_dense_bias_activation
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:46
  [4] fused_dense_bias_activation
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:38
  [5] Dense
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:357
  [6] Dense
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:353
  [7] apply
    @ C:\Users\xinle\.julia\packages\LuxCore\qiHPC\src\LuxCore.jl:179
  [8] macro expansion
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
  [9] applychain
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
 [10] Chain
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
 [11] macro expansion
    @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
 [12] gpu_nn_pass
    @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [13] gpu_nn_pass
    @ .\none:0
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Stacktrace:
  [1] __matmuladd
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:6
  [2] __fused_dense_bias_activation_impl
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:26
  [3] fused_dense_bias_activation
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:46
  [4] fused_dense_bias_activation
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:38
  [5] Dense
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:357
  [6] Dense
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:353
  [7] apply
    @ C:\Users\xinle\.julia\packages\LuxCore\qiHPC\src\LuxCore.jl:179
  [8] macro expansion
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
  [9] applychain
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
 [10] Chain
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
 [11] macro expansion
    @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
 [12] gpu_nn_pass
    @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [13] gpu_nn_pass
    @ .\none:0
Reason: unsupported call through a literal pointer (call to ijl_alloc_array_2d)
Stacktrace:
  [1] Array
    @ .\boot.jl:479
  [2] Array
    @ .\boot.jl:487
  [3] similar
    @ .\abstractarray.jl:842
  [4] similar
    @ .\subarray.jl:65
  [5] similar
    @ .\reshapedarray.jl:209
  [6] similar
    @ .\abstractarray.jl:833
  [7] muladd
    @ C:\Users\xinle\AppData\Local\Programs\Julia-1.10.1\share\julia\stdlib\v1.10\LinearAlgebra\src\matmul.jl:209
  [8] __matmuladd
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:6
  [9] __fused_dense_bias_activation_impl
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:26
 [10] fused_dense_bias_activation
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:46
 [11] fused_dense_bias_activation
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:38
 [12] Dense
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:357
 [13] Dense
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:353
 [14] apply
    @ C:\Users\xinle\.julia\packages\LuxCore\qiHPC\src\LuxCore.jl:179
 [15] macro expansion
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
 [16] applychain
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
 [17] Chain
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
 [18] macro expansion
    @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
 [19] gpu_nn_pass
    @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [20] gpu_nn_pass
    @ .\none:0
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Stacktrace:
  [1] __matmuladd
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:6
  [2] __fused_dense_bias_activation_impl
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:26
  [3] fused_dense_bias_activation
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:46
  [4] fused_dense_bias_activation
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:38
  [5] Dense
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:357
  [6] Dense
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:353
  [7] apply
    @ C:\Users\xinle\.julia\packages\LuxCore\qiHPC\src\LuxCore.jl:179
  [8] macro expansion
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
  [9] applychain
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
 [10] Chain
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
 [11] macro expansion
    @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
 [12] gpu_nn_pass
    @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [13] gpu_nn_pass
    @ .\none:0
Reason: unsupported call to an unknown function (call to julia.pop_gc_frame)
Stacktrace:
  [1] __matmuladd
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:6
  [2] __fused_dense_bias_activation_impl
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:26
  [3] fused_dense_bias_activation
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:46
  [4] fused_dense_bias_activation
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:38
  [5] Dense
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:357
  [6] Dense
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:353
  [7] apply
    @ C:\Users\xinle\.julia\packages\LuxCore\qiHPC\src\LuxCore.jl:179
  [8] macro expansion
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
  [9] applychain
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
 [10] Chain
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
 [11] macro expansion
    @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
 [12] gpu_nn_pass
    @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [13] gpu_nn_pass
    @ .\none:0
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Reason: unsupported call to an unknown function (call to julia.pop_gc_frame)
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Stacktrace:
 [1] GenericIOBuffer
   @ .\iobuffer.jl:106
 [2] print_to_string
   @ .\strings\io.jl:146
 [3] multiple call sites
   @ unknown:0
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Stacktrace:
 [1] GenericIOBuffer
   @ .\iobuffer.jl:106
 [2] print_to_string
   @ .\strings\io.jl:146
 [3] multiple call sites
   @ unknown:0
Reason: unsupported call through a literal pointer (call to ijl_alloc_string)
Stacktrace:
 [1] _string_n
   @ .\strings\string.jl:90
 [2] StringVector
   @ .\iobuffer.jl:32
 [3] #IOBuffer#469
   @ .\iobuffer.jl:115
 [4] GenericIOBuffer
   @ .\iobuffer.jl:106
 [5] print_to_string
   @ .\strings\io.jl:146
 [6] multiple call sites
   @ unknown:0
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Stacktrace:
 [1] GenericIOBuffer
   @ .\iobuffer.jl:106
 [2] print_to_string
   @ .\strings\io.jl:146
 [3] multiple call sites
   @ unknown:0
Reason: unsupported call through a literal pointer (call to ijl_string_to_array)
Stacktrace:
 [1] unsafe_wrap
   @ .\strings\string.jl:100
 [2] StringVector
   @ .\iobuffer.jl:32
 [3] #IOBuffer#469
   @ .\iobuffer.jl:115
 [4] GenericIOBuffer
   @ .\iobuffer.jl:106
 [5] print_to_string
   @ .\strings\io.jl:146
 [6] multiple call sites
   @ unknown:0
Reason: unsupported call to an unknown function (call to julia.pop_gc_frame)
Stacktrace:
 [1] GenericIOBuffer
   @ .\iobuffer.jl:106
 [2] print_to_string
   @ .\strings\io.jl:146
 [3] multiple call sites
   @ unknown:0
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Stacktrace:
 [1] _getindex
   @ .\multidimensional.jl:889
 [2] getindex
   @ .\abstractarray.jl:1291
 [3] macro expansion
   @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
 [4] gpu_nn_pass
   @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [5] gpu_nn_pass
   @ .\none:0
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Stacktrace:
 [1] _getindex
   @ .\multidimensional.jl:889
 [2] getindex
   @ .\abstractarray.jl:1291
 [3] macro expansion
   @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
 [4] gpu_nn_pass
   @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [5] gpu_nn_pass
   @ .\none:0
Reason: unsupported call through a literal pointer (call to ijl_alloc_array_1d)
Stacktrace:
  [1] Array
    @ .\boot.jl:477
  [2] Array
    @ .\boot.jl:486
  [3] similar
    @ .\abstractarray.jl:842
  [4] similar
    @ .\abstractarray.jl:831
  [5] _unsafe_getindex
    @ .\multidimensional.jl:901
  [6] _getindex
    @ .\multidimensional.jl:889
  [7] getindex
    @ .\abstractarray.jl:1291
  [8] macro expansion
    @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
  [9] gpu_nn_pass
    @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [10] gpu_nn_pass
    @ .\none:0
Reason: unsupported call to an unknown function (call to julia.pop_gc_frame)
Stacktrace:
 [1] _getindex
   @ .\multidimensional.jl:889
 [2] getindex
   @ .\abstractarray.jl:1291
 [3] macro expansion
   @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
 [4] gpu_nn_pass
   @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [5] gpu_nn_pass
   @ .\none:0
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Stacktrace:
 [1] _getindex
   @ .\multidimensional.jl:889
 [2] getindex
   @ .\abstractarray.jl:1291
 [3] macro expansion
   @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
 [4] gpu_nn_pass
   @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [5] gpu_nn_pass
   @ .\none:0
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Stacktrace:
 [1] macro expansion
   @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
 [2] applychain
   @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
 [3] Chain
   @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
 [4] macro expansion
   @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
 [5] gpu_nn_pass
   @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [6] gpu_nn_pass
   @ .\none:0
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Stacktrace:
 [1] macro expansion
   @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
 [2] applychain
   @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
 [3] Chain
   @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
 [4] macro expansion
   @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
 [5] gpu_nn_pass
   @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [6] gpu_nn_pass
   @ .\none:0
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Stacktrace:
 [1] macro expansion
   @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
 [2] applychain
   @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
 [3] Chain
   @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
 [4] macro expansion
   @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
 [5] gpu_nn_pass
   @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [6] gpu_nn_pass
   @ .\none:0
Reason: unsupported call through a literal pointer (call to ijl_reshape_array)
Stacktrace:
  [1] reshape
    @ .\reshapedarray.jl:51
  [2] reshape
    @ .\reshapedarray.jl:119
  [3] reshape
    @ .\reshapedarray.jl:118
  [4] Dense
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:353
  [5] apply
    @ C:\Users\xinle\.julia\packages\LuxCore\qiHPC\src\LuxCore.jl:179
  [6] macro expansion
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
  [7] applychain
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
  [8] Chain
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
  [9] macro expansion
    @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
 [10] gpu_nn_pass
    @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [11] gpu_nn_pass
    @ .\none:0
Reason: unsupported call through a literal pointer (call to ijl_reshape_array)
Stacktrace:
  [1] reshape
    @ .\reshapedarray.jl:51
  [2] reshape
    @ .\reshapedarray.jl:117
  [3] vec
    @ .\abstractarraymath.jl:41
  [4] Dense
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:353
  [5] apply
    @ C:\Users\xinle\.julia\packages\LuxCore\qiHPC\src\LuxCore.jl:179
  [6] macro expansion
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
  [7] applychain
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
  [8] Chain
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
  [9] macro expansion
    @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
 [10] gpu_nn_pass
    @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [11] gpu_nn_pass
    @ .\none:0
Reason: unsupported call to an unknown function (call to julia.pop_gc_frame)
Stacktrace:
 [1] macro expansion
   @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
 [2] applychain
   @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
 [3] Chain
   @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
 [4] macro expansion
   @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
 [5] gpu_nn_pass
   @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [6] gpu_nn_pass
   @ .\none:0
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Stacktrace:
 [1] unsafe_convert
   @ .\subarray.jl:470
 [2] unsafe_convert
   @ .\reshapedarray.jl:296
 [3] gemm!
   @ C:\Users\xinle\AppData\Local\Programs\Julia-1.10.1\share\julia\stdlib\v1.10\LinearAlgebra\src\blas.jl:1524
 [4] multiple call sites
   @ unknown:0
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Stacktrace:
 [1] unsafe_convert
   @ .\subarray.jl:470
 [2] unsafe_convert
   @ .\reshapedarray.jl:296
 [3] gemm!
   @ C:\Users\xinle\AppData\Local\Programs\Julia-1.10.1\share\julia\stdlib\v1.10\LinearAlgebra\src\blas.jl:1524
 [4] multiple call sites
   @ unknown:0
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Stacktrace:
 [1] unsafe_convert
   @ .\subarray.jl:470
 [2] unsafe_convert
   @ .\reshapedarray.jl:296
 [3] gemm!
   @ C:\Users\xinle\AppData\Local\Programs\Julia-1.10.1\share\julia\stdlib\v1.10\LinearAlgebra\src\blas.jl:1524
 [4] multiple call sites
   @ unknown:0
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Stacktrace:
 [1] multiple call sites
   @ unknown:0
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Stacktrace:
 [1] multiple call sites
   @ unknown:0
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Stacktrace:
 [1] multiple call sites
   @ unknown:0
Reason: unsupported call through a literal pointer (call to ijl_alloc_array_2d)
Stacktrace:
  [1] Array
    @ .\boot.jl:479
  [2] Array
    @ .\boot.jl:487
  [3] similar
    @ .\abstractarray.jl:842
  [4] similar
    @ .\subarray.jl:65
  [5] similar
    @ .\reshapedarray.jl:209
  [6] similar
    @ .\abstractarray.jl:833
  [7] __fused_dense_bias_activation_impl
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:28
  [8] fused_dense_bias_activation
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:46
  [9] fused_dense_bias_activation
    @ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:38
 [10] Dense
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:357
 [11] Dense
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:353
 [12] apply
    @ C:\Users\xinle\.julia\packages\LuxCore\qiHPC\src\LuxCore.jl:179
 [13] macro expansion
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
 [14] applychain
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
 [15] Chain
    @ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
 [16] macro expansion
    @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
 [17] gpu_nn_pass
    @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [18] gpu_nn_pass
    @ .\none:0
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Stacktrace:
 [1] _unsafe_getindex
   @ .\multidimensional.jl:902
 [2] _getindex
   @ .\multidimensional.jl:889
 [3] getindex
   @ .\abstractarray.jl:1291
 [4] macro expansion
   @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
 [5] gpu_nn_pass
   @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [6] gpu_nn_pass
   @ .\none:0
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Stacktrace:
 [1] _unsafe_getindex
   @ .\multidimensional.jl:902
 [2] _getindex
   @ .\multidimensional.jl:889
 [3] getindex
   @ .\abstractarray.jl:1291
 [4] macro expansion
   @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
 [5] gpu_nn_pass
   @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [6] gpu_nn_pass
   @ .\none:0
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Stacktrace:
 [1] _unsafe_getindex
   @ .\multidimensional.jl:902
 [2] _getindex
   @ .\multidimensional.jl:889
 [3] getindex
   @ .\abstractarray.jl:1291
 [4] macro expansion
   @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
 [5] gpu_nn_pass
   @ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
 [6] gpu_nn_pass
   @ .\none:0
Reason: unsupported dynamic function invocation (call to _str_sizehint)
Stacktrace:
 [1] print_to_string
   @ .\strings\io.jl:143
 [2] multiple call sites
   @ unknown:0
Reason: unsupported dynamic function invocation (call to print)
Stacktrace:
 [1] print_to_string
   @ .\strings\io.jl:148
 [2] multiple call sites
   @ unknown:0
Reason: unsupported call through a literal pointer (call to ijl_array_grow_end)
Stacktrace:
 [1] _growend!
   @ .\array.jl:1072
 [2] resize!
   @ .\array.jl:1315
 [3] _unsafe_take!
   @ .\iobuffer.jl:445
 [4] print_to_string
   @ .\strings\io.jl:150
 [5] multiple call sites
   @ unknown:0
Reason: unsupported call through a literal pointer (call to ijl_array_del_end)
Stacktrace:
 [1] _deleteend!
   @ .\array.jl:1081
 [2] resize!
   @ .\array.jl:1320
 [3] _unsafe_take!
   @ .\iobuffer.jl:445
 [4] print_to_string
   @ .\strings\io.jl:150
 [5] multiple call sites
   @ unknown:0
Reason: unsupported call through a literal pointer (call to ijl_array_to_string)
Stacktrace:
 [1] String
   @ .\strings\string.jl:67
 [2] print_to_string
   @ .\strings\io.jl:150
 [3] multiple call sites
   @ unknown:0
Reason: unsupported call to an unknown function (call to julia.pop_gc_frame)
Stacktrace:
 [1] multiple call sites
   @ unknown:0
Hint: catch this exception as `err` and call `code_typed(err; interactive = true)` to introspect the erronous code with Cthulhu.jl
Stacktrace:
  [1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}, args::LLVM.Module)
    @ GPUCompiler C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\validation.jl:147
  [2] macro expansion
    @ C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\driver.jl:445 [inlined]
  [3] macro expansion
    @ C:\Users\xinle\.julia\packages\TimerOutputs\Lw5SP\src\TimerOutput.jl:253 [inlined]
  [4] macro expansion
    @ C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\driver.jl:444 [inlined]
  [5]
    @ GPUCompiler C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\utils.jl:92
  [6] emit_llvm
    @ C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\utils.jl:86 [inlined]
  [7]
    @ GPUCompiler C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\driver.jl:134
  [8] codegen
    @ C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\driver.jl:115 [inlined]
  [9]
    @ GPUCompiler C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\driver.jl:111
 [10] compile
    @ C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\driver.jl:103 [inlined]
 [11] #1116
    @ C:\Users\xinle\.julia\packages\CUDA\B2Z5u\src\compiler\compilation.jl:247 [inlined]
 [12] JuliaContext(f::CUDA.var"#1116#1119"{GPUCompiler.CompilerJob{}}; kwargs::@Kwargs{})
    @ GPUCompiler C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\driver.jl:52
 [13] JuliaContext(f::Function)
    @ GPUCompiler C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\driver.jl:42
 [14] compile(job::GPUCompiler.CompilerJob)
    @ CUDA C:\Users\xinle\.julia\packages\CUDA\B2Z5u\src\compiler\compilation.jl:246
 [15] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))     
    @ GPUCompiler C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\execution.jl:128
 [16] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
    @ GPUCompiler C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\execution.jl:103
 [17] macro expansion
    @ C:\Users\xinle\.julia\packages\CUDA\B2Z5u\src\compiler\execution.jl:367 [inlined]
 [18] macro expansion
    @ .\lock.jl:267 [inlined]
 [19] cufunction(f::typeof(gpu_nn_pass), tt::Type{Tuple{…}}; kwargs::@Kwargs{always_inline::Bool, maxthreads::Int64})
    @ CUDA C:\Users\xinle\.julia\packages\CUDA\B2Z5u\src\compiler\execution.jl:362
 [20] macro expansion
    @ C:\Users\xinle\.julia\packages\CUDA\B2Z5u\src\compiler\execution.jl:112 [inlined]
 [21] (::KernelAbstractions.Kernel{…})(::CuArray{…}, ::Vararg{…}; ndrange::Nothing, workgroupsize::Nothing)
    @ CUDA.CUDAKernels C:\Users\xinle\.julia\packages\CUDA\B2Z5u\src\CUDAKernels.jl:119
 [22] (::KernelAbstractions.Kernel{…})(::CuArray{…}, ::Vararg{…})
    @ CUDA.CUDAKernels C:\Users\xinle\.julia\packages\CUDA\B2Z5u\src\CUDAKernels.jl:105
 [23] top-level scope
    @ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:25
Some type information was truncated. Use `show(err)` to see complete types.

Hope this is helpful. Thank you so much!

@avik-pal (Member) commented

I don't think writing the kernel out like that is going to work; LuxLib/NNlib uses lots of operations that are incompatible with running inside a kernel. I will update your code and show how to do this in the linked PR.

Essentially, all that needs to be done is to make the individual batches and the parameters StaticArrays; then they can be used inside the kernel.
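
(A quick illustrative check, not from the original comment: static arrays are isbits, which is exactly what a GPU kernel argument requires.)

using StaticArrays

isbitstype(typeof(@SArray zeros(4, 1)))   # true: SArrays store their data inline
isbitstype(typeof(zeros(4, 1)))           # false: regular Arrays are heap-allocated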

@avik-pal (Member) commented

using Lux, LuxCUDA, ComponentArrays, Random, StaticArrays
using KernelAbstractions

dev = gpu_device()
rng = Random.default_rng(123)

nn = Chain(Dense(4, 4, relu), Dense(4, 1))
ps, st = Lux.setup(rng, nn) |> f64

# Convert every array in the parameters/states into an isbits SArray
tosarray(x::AbstractArray) = SArray{Tuple{size(x)...}}(x)
ps_static = Lux.recursive_map(tosarray, ps)
st_static = Lux.recursive_map(tosarray, st)

# Each element is a static batch, so the GPU arrays have isbits elements
input = [@SArray(rand(Float64, 4, 1)) for i in 1:1024] |> dev;
output = [@SArray(zeros(Float64, 1, 1)) for i in 1:1024] |> dev;

@kernel function nn_pass!(output, model, input, ps, st)
    i = @index(Global, Linear)
    output[i] = first(model(input[i], ps, st))
end

backend = KernelAbstractions.get_backend(output)
loop! = nn_pass!(backend)

loop!(output, nn, input, ps_static, st_static; ndrange=length(input))
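
(A hypothetical follow-up check, not part of the original comment: synchronize, copy the results back, and compare one entry against a direct forward pass.)

KernelAbstractions.synchronize(backend)
out_cpu = Array(output)                      # Vector of 1×1 SMatrix results
x1 = Array(input)[1]                         # first static input batch
first(nn(x1, ps_static, st_static)) ≈ out_cpu[1]   # should hold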
