diff --git a/src/host/indexing.jl b/src/host/indexing.jl index ec497353..af270e1f 100644 --- a/src/host/indexing.jl +++ b/src/host/indexing.jl @@ -1,6 +1,6 @@ # host-level indexing -export allowscalar, @allowscalar, @disallowscalar, assertscalar +export allowscalar, @allowscalar, assertscalar # mechanism to disallow scalar operations @@ -8,40 +8,27 @@ export allowscalar, @allowscalar, @disallowscalar, assertscalar @enum ScalarIndexing ScalarAllowed ScalarWarn ScalarWarned ScalarDisallowed """ - allowscalar(allow=true, warn=true) - allowscalar(allow=true, warn=true) do end - -Configure whether scalar indexing is allowed depending on the value of `allow`. + allowscalar() do + # code that can use scalar indexing + end -If allowed, `warn` can be set to throw a single warning instead. Calling this function will -reset the state of the warning, and throw a new warning on subsequent scalar iteration. +Denote which operations can use scalar indexing. -For temporary changes, use the do-block version, or [`@allowscalar`](@ref). +See also: [`@allowscalar`](@ref). """ -function allowscalar(allow::Bool=true, warn::Bool=true) - val = if allow && !warn - ScalarAllowed - elseif allow - ScalarWarn - else - ScalarDisallowed - end - - task_local_storage(:ScalarIndexing, val) - return +function allowscalar(f::Base.Callable) + task_local_storage(f, :ScalarIndexing, ScalarAllowed) end -@doc (@doc allowscalar) -> -function allowscalar(f::Base.Callable, allow::Bool=true, warn::Bool=false) - val = if allow && !warn - ScalarAllowed - elseif allow - ScalarWarn +# deprecated +function allowscalar(allow::Bool=true) + if allow + Base.depwarn("allowscalar([true]) is deprecated, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.", :allowscalar) else - ScalarDisallowed + Base.depwarn("allowscalar(false) is deprecated; scalar indexing is now disabled by default.", :allowscalar) end - - task_local_storage(f, :ScalarIndexing, val) + task_local_storage(:ScalarIndexing, allow ? ScalarAllowed : ScalarDisallowed) + return end """ @@ -52,44 +39,40 @@ error will be thrown ([`allowscalar`](@ref)). """ function assertscalar(op = "operation") val = get!(task_local_storage(), :ScalarIndexing) do - if haskey(ENV, "JULIA_GPU_ALLOWSCALAR") - parse(Bool, ENV["JULIA_GPU_ALLOWSCALAR"]) ? ScalarAllowed : ScalarDisallowed - else + if isinteractive() ScalarWarn + else + ScalarDisallowed end end + desc = """Invocation of $op resulted in scalar indexing of a GPU array. + This is typically caused by calling an iterating implementation of a method. + Such implementations *do not* execute on the GPU, but very slowly on the CPU, + and therefore are only permitted from the REPL for prototyping purposes. + If you did intend to index this array, annotate the caller with @allowscalar.""" if val == ScalarDisallowed - error("$op is disallowed") + error("""Scalar indexing is disallowed. + $desc""") elseif val == ScalarWarn - @warn "Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`" + @warn("""Performing scalar indexing. + $desc""") task_local_storage(:ScalarIndexing, ScalarWarned) end return end """ - @allowscalar ex... - @disallowscalar ex... - allowscalar(::Function, ...) + @allowscalar() begin + # code that can use scalar indexing + end -Temporarily allow or disallow scalar iteration. +Denote which operations can use scalar indexing. -Note that this functionality is intended for functionality that is known and allowed to use -scalar iteration (or not), i.e., there is no option to throw a warning. Only use this on -fine-grained expressions. +See also: [`allowscalar`](@ref). """ macro allowscalar(ex) quote - task_local_storage(:ScalarIndexing, ScalarAllowed) do - $(esc(ex)) - end - end -end - -@doc (@doc @allowscalar) -> -macro disallowscalar(ex) - quote - task_local_storage(:ScalarIndexing, ScalarDisallowed) do + task_local_storage(:ScalarIndexing, true) do $(esc(ex)) end end @@ -101,7 +84,7 @@ end Base.IndexStyle(::Type{<:AbstractGPUArray}) = Base.IndexLinear() function Base.getindex(xs::AbstractGPUArray{T}, I::Integer...) where T - assertscalar("scalar getindex") + assertscalar("getindex") i = Base._to_linear_index(xs, I...) x = Array{T}(undef, 1) copyto!(x, 1, xs, i, 1) @@ -109,7 +92,7 @@ function Base.getindex(xs::AbstractGPUArray{T}, I::Integer...) where T end function Base.setindex!(xs::AbstractGPUArray{T}, v::T, I::Integer...) where T - assertscalar("scalar setindex!") + assertscalar("setindex!") i = Base._to_linear_index(xs, I...) x = T[v] copyto!(xs, i, x, 1, 1) diff --git a/test/jlarray.jl b/test/jlarray.jl index 0193008d..be148273 100644 --- a/test/jlarray.jl +++ b/test/jlarray.jl @@ -65,7 +65,7 @@ function GPUArrays.gpu_call(::JLBackend, f, args, threads::Int, blocks::Int; ctx = JLKernelContext(threads, blocks) device_args = jlconvert.(args) tasks = Array{Task}(undef, threads) - @disallowscalar for blockidx in 1:blocks + for blockidx in 1:blocks ctx.blockidx = blockidx for threadidx in 1:threads thread_ctx = JLKernelContext(ctx, threadidx) diff --git a/test/runtests.jl b/test/runtests.jl index d097ad42..d2b1dd1e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,7 +8,6 @@ include("testsuite.jl") jl([1]) - JLArrays.allowscalar(false) TestSuite.test(JLArray) end diff --git a/test/testsuite/indexing.jl b/test/testsuite/indexing.jl index 052e1e47..b886e131 100644 --- a/test/testsuite/indexing.jl +++ b/test/testsuite/indexing.jl @@ -1,45 +1,22 @@ @testsuite "indexing scalar" AT->begin - AT <: AbstractGPUArray && @allowscalar @testset "errors and warnings" begin + AT <: AbstractGPUArray && @testset "errors and warnings" begin x = AT([0]) - allowscalar(true, false) - x[1] = 1 - @test x[1] == 1 - - @disallowscalar begin - @test_throws ErrorException x[1] - @test_throws ErrorException x[1] = 1 - end - - x[1] = 2 - @test x[1] == 2 - - allowscalar(false) - @test_throws ErrorException x[1] - @test_throws ErrorException x[1] = 1 + @test_throws ErrorException x[] @allowscalar begin - x[1] = 3 - @test x[1] == 3 + x[] = 1 + @test x[] == 1 end + @test_throws ErrorException x[] + allowscalar() do - x[1] = 4 - @test x[1] == 4 + x[] = 2 + @test x[] == 2 end - @test_throws ErrorException x[1] - @test_throws ErrorException x[1] = 1 - - allowscalar(true, false) - x[1] - - allowscalar(true, true) - @test_logs (:warn, r"Performing scalar operations on GPU arrays: .*") x[1] - @test_logs x[1] - - # NOTE: this inner testset _needs_ to be wrapped with allowscalar - # to make sure its original value is restored. + @test_throws ErrorException x[] end @allowscalar @testset "getindex with $T" for T in supported_eltypes()