Inference using NN :: Chain inside a GPU kernel #720
That's an interesting use case I hadn't thought of. I can add concrete type for
Here's a simple test script:

using Lux, LuxCUDA, ComponentArrays, Random, CUDA
using KernelAbstractions
dev = gpu_device()
rng = Random.default_rng(123)
nn = Chain(Dense(4, 4, relu), Dense(4, 1))
ps, st = Lux.setup(rng, nn)
ps = ps |> ComponentArray |> dev .|> Float64
st = st |> dev
input = randn(4, 100) |> dev .|> Float64
outp = nn(input, ps, st)
@kernel function nn_pass(input, output, nn, ps, st)
    i = @index(Global, Linear)
    output[i] = nn(input[:, i], ps, st)[1][1]
end
output = zeros(100) |> dev .|> Float64
loop! = nn_pass(CUDA.CUDABackend(), 100, 100)
loop!(input, output, nn, ps, st)

which gives an error of

ERROR: InvalidIRError: compiling MethodInstance for gpu_nn_pass(::KernelAbstractions.CompilerMetadata{…}, ::CuDeviceMatrix{…}, ::CuDeviceVector{…}, ::Chain{…}, ::ComponentVector{…}, ::@NamedTuple{…}) resulted in invalid LLVM IR
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Stacktrace:
[1] __matmuladd
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:6
[2] __fused_dense_bias_activation_impl
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:26
[3] fused_dense_bias_activation
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:46
[4] fused_dense_bias_activation
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:38
[5] Dense
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:357
[6] Dense
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:353
[7] apply
@ C:\Users\xinle\.julia\packages\LuxCore\qiHPC\src\LuxCore.jl:179
[8] macro expansion
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
[9] applychain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
[10] Chain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
[11] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[12] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[13] gpu_nn_pass
@ .\none:0
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Stacktrace:
[1] __matmuladd
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:6
[2] __fused_dense_bias_activation_impl
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:26
[3] fused_dense_bias_activation
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:46
[4] fused_dense_bias_activation
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:38
[5] Dense
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:357
[6] Dense
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:353
[7] apply
@ C:\Users\xinle\.julia\packages\LuxCore\qiHPC\src\LuxCore.jl:179
[8] macro expansion
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
[9] applychain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
[10] Chain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
[11] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[12] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[13] gpu_nn_pass
@ .\none:0
Reason: unsupported call through a literal pointer (call to ijl_alloc_array_2d)
Stacktrace:
[1] Array
@ .\boot.jl:479
[2] Array
@ .\boot.jl:487
[3] similar
@ .\abstractarray.jl:842
[4] similar
@ .\subarray.jl:65
[5] similar
@ .\reshapedarray.jl:209
[6] similar
@ .\abstractarray.jl:833
[7] muladd
@ C:\Users\xinle\AppData\Local\Programs\Julia-1.10.1\share\julia\stdlib\v1.10\LinearAlgebra\src\matmul.jl:209
[8] __matmuladd
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:6
[9] __fused_dense_bias_activation_impl
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:26
[10] fused_dense_bias_activation
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:46
[11] fused_dense_bias_activation
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:38
[12] Dense
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:357
[13] Dense
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:353
[14] apply
@ C:\Users\xinle\.julia\packages\LuxCore\qiHPC\src\LuxCore.jl:179
[15] macro expansion
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
[16] applychain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
[17] Chain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
[18] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[19] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[20] gpu_nn_pass
@ .\none:0
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Stacktrace:
[1] __matmuladd
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:6
[2] __fused_dense_bias_activation_impl
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:26
[3] fused_dense_bias_activation
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:46
[4] fused_dense_bias_activation
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:38
[5] Dense
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:357
[6] Dense
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:353
[7] apply
@ C:\Users\xinle\.julia\packages\LuxCore\qiHPC\src\LuxCore.jl:179
[8] macro expansion
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
[9] applychain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
[10] Chain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
[11] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[12] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[13] gpu_nn_pass
@ .\none:0
Reason: unsupported call to an unknown function (call to julia.pop_gc_frame)
Stacktrace:
[1] __matmuladd
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:6
[2] __fused_dense_bias_activation_impl
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:26
[3] fused_dense_bias_activation
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:46
[4] fused_dense_bias_activation
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:38
[5] Dense
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:357
[6] Dense
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:353
[7] apply
@ C:\Users\xinle\.julia\packages\LuxCore\qiHPC\src\LuxCore.jl:179
[8] macro expansion
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
[9] applychain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
[10] Chain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
[11] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[12] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[13] gpu_nn_pass
@ .\none:0
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Reason: unsupported call to an unknown function (call to julia.pop_gc_frame)
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Stacktrace:
[1] GenericIOBuffer
@ .\iobuffer.jl:106
[2] print_to_string
@ .\strings\io.jl:146
[3] multiple call sites
@ unknown:0
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Stacktrace:
[1] GenericIOBuffer
@ .\iobuffer.jl:106
[2] print_to_string
@ .\strings\io.jl:146
[3] multiple call sites
@ unknown:0
Reason: unsupported call through a literal pointer (call to ijl_alloc_string)
Stacktrace:
[1] _string_n
@ .\strings\string.jl:90
[2] StringVector
@ .\iobuffer.jl:32
[3] #IOBuffer#469
@ .\iobuffer.jl:115
[4] GenericIOBuffer
@ .\iobuffer.jl:106
[5] print_to_string
@ .\strings\io.jl:146
[6] multiple call sites
@ unknown:0
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Stacktrace:
[1] GenericIOBuffer
@ .\iobuffer.jl:106
[2] print_to_string
@ .\strings\io.jl:146
[3] multiple call sites
@ unknown:0
Reason: unsupported call through a literal pointer (call to ijl_string_to_array)
Stacktrace:
[1] unsafe_wrap
@ .\strings\string.jl:100
[2] StringVector
@ .\iobuffer.jl:32
[3] #IOBuffer#469
@ .\iobuffer.jl:115
[4] GenericIOBuffer
@ .\iobuffer.jl:106
[5] print_to_string
@ .\strings\io.jl:146
[6] multiple call sites
@ unknown:0
Reason: unsupported call to an unknown function (call to julia.pop_gc_frame)
Stacktrace:
[1] GenericIOBuffer
@ .\iobuffer.jl:106
[2] print_to_string
@ .\strings\io.jl:146
[3] multiple call sites
@ unknown:0
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Stacktrace:
[1] _getindex
@ .\multidimensional.jl:889
[2] getindex
@ .\abstractarray.jl:1291
[3] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[4] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[5] gpu_nn_pass
@ .\none:0
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Stacktrace:
[1] _getindex
@ .\multidimensional.jl:889
[2] getindex
@ .\abstractarray.jl:1291
[3] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[4] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[5] gpu_nn_pass
@ .\none:0
Reason: unsupported call through a literal pointer (call to ijl_alloc_array_1d)
Stacktrace:
[1] Array
@ .\boot.jl:477
[2] Array
@ .\boot.jl:486
[3] similar
@ .\abstractarray.jl:842
[4] similar
@ .\abstractarray.jl:831
[5] _unsafe_getindex
@ .\multidimensional.jl:901
[6] _getindex
@ .\multidimensional.jl:889
[7] getindex
@ .\abstractarray.jl:1291
[8] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[9] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[10] gpu_nn_pass
@ .\none:0
Reason: unsupported call to an unknown function (call to julia.pop_gc_frame)
Stacktrace:
[1] _getindex
@ .\multidimensional.jl:889
[2] getindex
@ .\abstractarray.jl:1291
[3] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[4] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[5] gpu_nn_pass
@ .\none:0
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Stacktrace:
[1] _getindex
@ .\multidimensional.jl:889
[2] getindex
@ .\abstractarray.jl:1291
[3] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[4] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[5] gpu_nn_pass
@ .\none:0
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Stacktrace:
[1] macro expansion
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
[2] applychain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
[3] Chain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
[4] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[5] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[6] gpu_nn_pass
@ .\none:0
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Stacktrace:
[1] macro expansion
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
[2] applychain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
[3] Chain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
[4] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[5] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[6] gpu_nn_pass
@ .\none:0
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Stacktrace:
[1] macro expansion
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
[2] applychain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
[3] Chain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
[4] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[5] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[6] gpu_nn_pass
@ .\none:0
Reason: unsupported call through a literal pointer (call to ijl_reshape_array)
Stacktrace:
[1] reshape
@ .\reshapedarray.jl:51
[2] reshape
@ .\reshapedarray.jl:119
[3] reshape
@ .\reshapedarray.jl:118
[4] Dense
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:353
[5] apply
@ C:\Users\xinle\.julia\packages\LuxCore\qiHPC\src\LuxCore.jl:179
[6] macro expansion
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
[7] applychain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
[8] Chain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
[9] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[10] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[11] gpu_nn_pass
@ .\none:0
Reason: unsupported call through a literal pointer (call to ijl_reshape_array)
Stacktrace:
[1] reshape
@ .\reshapedarray.jl:51
[2] reshape
@ .\reshapedarray.jl:117
[3] vec
@ .\abstractarraymath.jl:41
[4] Dense
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:353
[5] apply
@ C:\Users\xinle\.julia\packages\LuxCore\qiHPC\src\LuxCore.jl:179
[6] macro expansion
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
[7] applychain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
[8] Chain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
[9] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[10] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[11] gpu_nn_pass
@ .\none:0
Reason: unsupported call to an unknown function (call to julia.pop_gc_frame)
Stacktrace:
[1] macro expansion
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
[2] applychain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
[3] Chain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
[4] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[5] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[6] gpu_nn_pass
@ .\none:0
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Stacktrace:
[1] unsafe_convert
@ .\subarray.jl:470
[2] unsafe_convert
@ .\reshapedarray.jl:296
[3] gemm!
@ C:\Users\xinle\AppData\Local\Programs\Julia-1.10.1\share\julia\stdlib\v1.10\LinearAlgebra\src\blas.jl:1524
[4] multiple call sites
@ unknown:0
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Stacktrace:
[1] unsafe_convert
@ .\subarray.jl:470
[2] unsafe_convert
@ .\reshapedarray.jl:296
[3] gemm!
@ C:\Users\xinle\AppData\Local\Programs\Julia-1.10.1\share\julia\stdlib\v1.10\LinearAlgebra\src\blas.jl:1524
[4] multiple call sites
@ unknown:0
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Stacktrace:
[1] unsafe_convert
@ .\subarray.jl:470
[2] unsafe_convert
@ .\reshapedarray.jl:296
[3] gemm!
@ C:\Users\xinle\AppData\Local\Programs\Julia-1.10.1\share\julia\stdlib\v1.10\LinearAlgebra\src\blas.jl:1524
[4] multiple call sites
@ unknown:0
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Stacktrace:
[1] multiple call sites
@ unknown:0
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Stacktrace:
[1] multiple call sites
@ unknown:0
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Stacktrace:
[1] multiple call sites
@ unknown:0
Reason: unsupported call through a literal pointer (call to ijl_alloc_array_2d)
Stacktrace:
[1] Array
@ .\boot.jl:479
[2] Array
@ .\boot.jl:487
[3] similar
@ .\abstractarray.jl:842
[4] similar
@ .\subarray.jl:65
[5] similar
@ .\reshapedarray.jl:209
[6] similar
@ .\abstractarray.jl:833
[7] __fused_dense_bias_activation_impl
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\impl\fused_dense.jl:28
[8] fused_dense_bias_activation
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:46
[9] fused_dense_bias_activation
@ C:\Users\xinle\.julia\packages\LuxLib\Q3elb\src\api\dense.jl:38
[10] Dense
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:357
[11] Dense
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\basic.jl:353
[12] apply
@ C:\Users\xinle\.julia\packages\LuxCore\qiHPC\src\LuxCore.jl:179
[13] macro expansion
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:0
[14] applychain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:498
[15] Chain
@ C:\Users\xinle\.julia\packages\Lux\FMcuc\src\layers\containers.jl:496
[16] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[17] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[18] gpu_nn_pass
@ .\none:0
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Stacktrace:
[1] _unsafe_getindex
@ .\multidimensional.jl:902
[2] _getindex
@ .\multidimensional.jl:889
[3] getindex
@ .\abstractarray.jl:1291
[4] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[5] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[6] gpu_nn_pass
@ .\none:0
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Stacktrace:
[1] _unsafe_getindex
@ .\multidimensional.jl:902
[2] _getindex
@ .\multidimensional.jl:889
[3] getindex
@ .\abstractarray.jl:1291
[4] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[5] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[6] gpu_nn_pass
@ .\none:0
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Stacktrace:
[1] _unsafe_getindex
@ .\multidimensional.jl:902
[2] _getindex
@ .\multidimensional.jl:889
[3] getindex
@ .\abstractarray.jl:1291
[4] macro expansion
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:19
[5] gpu_nn_pass
@ C:\Users\xinle\.julia\packages\KernelAbstractions\zPAn3\src\macros.jl:95
[6] gpu_nn_pass
@ .\none:0
Reason: unsupported dynamic function invocation (call to _str_sizehint)
Stacktrace:
[1] print_to_string
@ .\strings\io.jl:143
[2] multiple call sites
@ unknown:0
Reason: unsupported dynamic function invocation (call to print)
Stacktrace:
[1] print_to_string
@ .\strings\io.jl:148
[2] multiple call sites
@ unknown:0
Reason: unsupported call through a literal pointer (call to ijl_array_grow_end)
Stacktrace:
[1] _growend!
@ .\array.jl:1072
[2] resize!
@ .\array.jl:1315
[3] _unsafe_take!
@ .\iobuffer.jl:445
[4] print_to_string
@ .\strings\io.jl:150
[5] multiple call sites
@ unknown:0
Reason: unsupported call through a literal pointer (call to ijl_array_del_end)
Stacktrace:
[1] _deleteend!
@ .\array.jl:1081
[2] resize!
@ .\array.jl:1320
[3] _unsafe_take!
@ .\iobuffer.jl:445
[4] print_to_string
@ .\strings\io.jl:150
[5] multiple call sites
@ unknown:0
Reason: unsupported call through a literal pointer (call to ijl_array_to_string)
Stacktrace:
[1] String
@ .\strings\string.jl:67
[2] print_to_string
@ .\strings\io.jl:150
[3] multiple call sites
@ unknown:0
Reason: unsupported call to an unknown function (call to julia.pop_gc_frame)
Stacktrace:
[1] multiple call sites
@ unknown:0
Hint: catch this exception as `err` and call `code_typed(err; interactive = true)` to introspect the erronous code with Cthulhu.jl
Stacktrace:
[1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}, args::LLVM.Module)
@ GPUCompiler C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\validation.jl:147
[2] macro expansion
@ C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\driver.jl:445 [inlined]
[3] macro expansion
@ C:\Users\xinle\.julia\packages\TimerOutputs\Lw5SP\src\TimerOutput.jl:253 [inlined]
[4] macro expansion
@ C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\driver.jl:444 [inlined]
[5]
@ GPUCompiler C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\utils.jl:92
[6] emit_llvm
@ C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\utils.jl:86 [inlined]
[7]
@ GPUCompiler C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\driver.jl:134
[8] codegen
@ C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\driver.jl:115 [inlined]
[9]
@ GPUCompiler C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\driver.jl:111
[10] compile
@ C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\driver.jl:103 [inlined]
[11] #1116
@ C:\Users\xinle\.julia\packages\CUDA\B2Z5u\src\compiler\compilation.jl:247 [inlined]
[12] JuliaContext(f::CUDA.var"#1116#1119"{GPUCompiler.CompilerJob{…}}; kwargs::@Kwargs{})
@ GPUCompiler C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\driver.jl:52
[13] JuliaContext(f::Function)
@ GPUCompiler C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\driver.jl:42
[14] compile(job::GPUCompiler.CompilerJob)
@ CUDA C:\Users\xinle\.julia\packages\CUDA\B2Z5u\src\compiler\compilation.jl:246
[15] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
@ GPUCompiler C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\execution.jl:128
[16] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
@ GPUCompiler C:\Users\xinle\.julia\packages\GPUCompiler\kqxyC\src\execution.jl:103
[17] macro expansion
@ C:\Users\xinle\.julia\packages\CUDA\B2Z5u\src\compiler\execution.jl:367 [inlined]
[18] macro expansion
@ .\lock.jl:267 [inlined]
[19] cufunction(f::typeof(gpu_nn_pass), tt::Type{Tuple{…}}; kwargs::@Kwargs{always_inline::Bool, maxthreads::Int64})
@ CUDA C:\Users\xinle\.julia\packages\CUDA\B2Z5u\src\compiler\execution.jl:362
[20] macro expansion
@ C:\Users\xinle\.julia\packages\CUDA\B2Z5u\src\compiler\execution.jl:112 [inlined]
[21] (::KernelAbstractions.Kernel{…})(::CuArray{…}, ::Vararg{…}; ndrange::Nothing, workgroupsize::Nothing)
@ CUDA.CUDAKernels C:\Users\xinle\.julia\packages\CUDA\B2Z5u\src\CUDAKernels.jl:119
[22] (::KernelAbstractions.Kernel{…})(::CuArray{…}, ::Vararg{…})
@ CUDA.CUDAKernels C:\Users\xinle\.julia\packages\CUDA\B2Z5u\src\CUDAKernels.jl:105
[23] top-level scope
@ c:\Users\xinle\MIT\NN_Oceanaigans\Oceananigans.jl\test_lux_cuda.jl:25
Some type information was truncated. Use `show(err)` to see complete types.

Hope this might be helpful. Thank you so much!
I don't think writing the kernel like that is going to work: LuxLib/NNlib use many operations that are incompatible with code running inside a kernel. I will update your code and show how to do this in the linked PR. Essentially, all that needs to be done is to make both the individual batches and the parameters StaticArrays; then they can be used inside the kernel.
using Lux, LuxCUDA, ComponentArrays, Random, StaticArrays
using KernelAbstractions
dev = gpu_device()
rng = Random.default_rng(123)
nn = Chain(Dense(4, 4, relu), Dense(4, 1))
ps, st = Lux.setup(rng, nn) |> f64
# Convert every array leaf of the parameters and states to a statically sized SArray
tosarray(x::AbstractArray) = SArray{Tuple{size(x)...}}(x)
ps_static = Lux.recursive_map(tosarray, ps)
st_static = Lux.recursive_map(tosarray, st)
# Each sample is its own SArray, so the device array holds isbits elements
input = [@SArray(rand(Float64, 4, 1)) for i in 1:1024] |> dev;
output = [@SArray(zeros(Float64, 1, 1)) for i in 1:1024] |> dev;
@kernel function nn_pass!(output, model, input, ps, st)
    i = @index(Global, Linear)
    # The model returns (y, st); keep only the prediction
    output[i] = first(model(input[i], ps, st))
end
backend = KernelAbstractions.get_backend(output)
loop! = nn_pass!(backend)
loop!(output, nn, input, ps_static, st_static; ndrange=length(input))
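The reason the static version can compile is easier to see on the CPU. The following is a minimal sketch, assuming only StaticArrays; `dense_static` is a hypothetical hand-written layer used for illustration, not LuxLib's implementation. With `SArray` inputs the result is itself a statically sized, isbits value, whereas the same arithmetic on ordinary arrays returns a heap-allocated `Vector`, which is the kind of allocation the `ijl_alloc_array_*` entries in the error above complain about.

```julia
using StaticArrays

# Hypothetical hand-written dense layer, for illustration only (not LuxLib's implementation)
dense_static(W, b, x) = max.(W * x .+ b, zero(eltype(x)))

W = @SMatrix rand(4, 4)
b = @SVector rand(4)
x = @SVector rand(4)

y = dense_static(W, b, x)

y isa SVector{4, Float64}   # the result is again statically sized
isbits(y)                   # plain bits, so it needs no heap-allocated array

# The same arithmetic on ordinary arrays produces a heap-allocated Vector
y_dyn = dense_static(Array(W), Array(b), Array(x))
y_dyn isa Vector{Float64}
```

Both calls compute the same numbers; only the container type differs, and that difference is what matters inside a kernel.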
Hello, with @simone-silvestri, we are trying to compute the output of a neural network inside a GPU kernel. Therefore, we need the NN to be passed to the GPU kernel, but we are having an issue with `adapt`ing the NN to the GPU.

In particular, we are writing a kernel using KernelAbstractions.jl that takes the NN as an input and evaluates the NN's output pointwise inside the kernel. When we try to launch the kernel, we get an error.

We suspect this issue stems from the fact that `NAME_TYPE` (the type of `model.name`) is forced to be `Union{Nothing, String, Symbol}`, where `Symbol` is not an isbits type, preventing the NN from being passed to the GPU.

Does anyone have any suggestion to solve this problem?
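The suspicion about isbits types can be checked without a GPU. The sketch below uses two hypothetical structs, `LooseNamed` and `TightNamed`, as stand-ins for a layer with a `name` field; they are assumptions for illustration and not Lux's actual definitions. It compares a field typed as the `Union` above with a field whose type is a concrete parameter.

```julia
# Hypothetical stand-ins for a layer with a `name` field (not Lux's actual types)
struct LooseNamed
    name::Union{Nothing, String, Symbol}   # Union field that includes non-isbits types
end

struct TightNamed{N}
    name::N                                # concrete once N is fixed
end

loose_ok = isbitstype(LooseNamed)            # false: the field may hold a String or Symbol
tight_ok = isbitstype(TightNamed{Nothing})   # true: Nothing is a singleton isbits type

isbitstype(Symbol)   # false
isbitstype(String)   # false
isbitstype(Nothing)  # true
```

Under these assumptions, making the name's type a parameter of the struct is one way a container can stay isbits when no name is set, which appears to be the direction of the "concrete type" suggestion in the comments.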