Skip to content

Commit

Permalink
Merge pull request #359 from JuliaGPU/tb/disallow_scalar_indexing
Browse files Browse the repository at this point in the history
Always disallow scalar indexing.
  • Loading branch information
maleadt authored Jun 10, 2021
2 parents 0e243f5 + d5d5885 commit 25ce681
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 86 deletions.
87 changes: 35 additions & 52 deletions src/host/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,47 +1,34 @@
# host-level indexing

export allowscalar, @allowscalar, @disallowscalar, assertscalar
export allowscalar, @allowscalar, assertscalar


# mechanism to disallow scalar operations

@enum ScalarIndexing ScalarAllowed ScalarWarn ScalarWarned ScalarDisallowed

"""
allowscalar(allow=true, warn=true)
allowscalar(allow=true, warn=true) do end
Configure whether scalar indexing is allowed depending on the value of `allow`.
allowscalar() do
# code that can use scalar indexing
end
If allowed, `warn` can be set to throw a single warning instead. Calling this function will
reset the state of the warning, and throw a new warning on subsequent scalar iteration.
Denote which operations can use scalar indexing.
For temporary changes, use the do-block version, or [`@allowscalar`](@ref).
See also: [`@allowscalar`](@ref).
"""
function allowscalar(allow::Bool=true, warn::Bool=true)
val = if allow && !warn
ScalarAllowed
elseif allow
ScalarWarn
else
ScalarDisallowed
end

task_local_storage(:ScalarIndexing, val)
return
function allowscalar(f::Base.Callable)
task_local_storage(f, :ScalarIndexing, ScalarAllowed)
end

@doc (@doc allowscalar) ->
function allowscalar(f::Base.Callable, allow::Bool=true, warn::Bool=false)
val = if allow && !warn
ScalarAllowed
elseif allow
ScalarWarn
# deprecated
function allowscalar(allow::Bool=true)
if allow
Base.depwarn("allowscalar([true]) is deprecated, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.", :allowscalar)
else
ScalarDisallowed
Base.depwarn("allowscalar(false) is deprecated; scalar indexing is now disabled by default.", :allowscalar)
end

task_local_storage(f, :ScalarIndexing, val)
task_local_storage(:ScalarIndexing, allow ? ScalarAllowed : ScalarDisallowed)
return
end

"""
Expand All @@ -52,44 +39,40 @@ error will be thrown ([`allowscalar`](@ref)).
"""
function assertscalar(op = "operation")
val = get!(task_local_storage(), :ScalarIndexing) do
if haskey(ENV, "JULIA_GPU_ALLOWSCALAR")
parse(Bool, ENV["JULIA_GPU_ALLOWSCALAR"]) ? ScalarAllowed : ScalarDisallowed
else
if isinteractive()
ScalarWarn
else
ScalarDisallowed
end
end
desc = """Invocation of $op resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar."""
if val == ScalarDisallowed
error("$op is disallowed")
error("""Scalar indexing is disallowed.
$desc""")
elseif val == ScalarWarn
@warn "Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`"
@warn("""Performing scalar indexing.
$desc""")
task_local_storage(:ScalarIndexing, ScalarWarned)
end
return
end

"""
@allowscalar ex...
@disallowscalar ex...
allowscalar(::Function, ...)
@allowscalar() begin
# code that can use scalar indexing
end
Temporarily allow or disallow scalar iteration.
Denote which operations can use scalar indexing.
Note that this functionality is intended for functionality that is known and allowed to use
scalar iteration (or not), i.e., there is no option to throw a warning. Only use this on
fine-grained expressions.
See also: [`allowscalar`](@ref).
"""
macro allowscalar(ex)
quote
task_local_storage(:ScalarIndexing, ScalarAllowed) do
$(esc(ex))
end
end
end

@doc (@doc @allowscalar) ->
macro disallowscalar(ex)
quote
task_local_storage(:ScalarIndexing, ScalarDisallowed) do
task_local_storage(:ScalarIndexing, true) do
$(esc(ex))
end
end
Expand All @@ -101,15 +84,15 @@ end
Base.IndexStyle(::Type{<:AbstractGPUArray}) = Base.IndexLinear()

function Base.getindex(xs::AbstractGPUArray{T}, I::Integer...) where T
assertscalar("scalar getindex")
assertscalar("getindex")
i = Base._to_linear_index(xs, I...)
x = Array{T}(undef, 1)
copyto!(x, 1, xs, i, 1)
return x[1]
end

function Base.setindex!(xs::AbstractGPUArray{T}, v::T, I::Integer...) where T
assertscalar("scalar setindex!")
assertscalar("setindex!")
i = Base._to_linear_index(xs, I...)
x = T[v]
copyto!(xs, i, x, 1, 1)
Expand Down
2 changes: 1 addition & 1 deletion test/jlarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ function GPUArrays.gpu_call(::JLBackend, f, args, threads::Int, blocks::Int;
ctx = JLKernelContext(threads, blocks)
device_args = jlconvert.(args)
tasks = Array{Task}(undef, threads)
@disallowscalar for blockidx in 1:blocks
for blockidx in 1:blocks
ctx.blockidx = blockidx
for threadidx in 1:threads
thread_ctx = JLKernelContext(ctx, threadidx)
Expand Down
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ include("testsuite.jl")

jl([1])

JLArrays.allowscalar(false)
TestSuite.test(JLArray)
end

Expand Down
41 changes: 9 additions & 32 deletions test/testsuite/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,45 +1,22 @@
@testsuite "indexing scalar" AT->begin
AT <: AbstractGPUArray && @allowscalar @testset "errors and warnings" begin
AT <: AbstractGPUArray && @testset "errors and warnings" begin
x = AT([0])

allowscalar(true, false)
x[1] = 1
@test x[1] == 1

@disallowscalar begin
@test_throws ErrorException x[1]
@test_throws ErrorException x[1] = 1
end

x[1] = 2
@test x[1] == 2

allowscalar(false)
@test_throws ErrorException x[1]
@test_throws ErrorException x[1] = 1
@test_throws ErrorException x[]

@allowscalar begin
x[1] = 3
@test x[1] == 3
x[] = 1
@test x[] == 1
end

@test_throws ErrorException x[]

allowscalar() do
x[1] = 4
@test x[1] == 4
x[] = 2
@test x[] == 2
end

@test_throws ErrorException x[1]
@test_throws ErrorException x[1] = 1

allowscalar(true, false)
x[1]

allowscalar(true, true)
@test_logs (:warn, r"Performing scalar operations on GPU arrays: .*") x[1]
@test_logs x[1]

# NOTE: this inner testset _needs_ to be wrapped with allowscalar
# to make sure its original value is restored.
@test_throws ErrorException x[]
end

@allowscalar @testset "getindex with $T" for T in supported_eltypes()
Expand Down

0 comments on commit 25ce681

Please sign in to comment.