Skip to content

Commit

Permalink
Allow scalar iteration in the REPL on 1.9+.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Feb 6, 2023
1 parent 6c1ea64 commit 50bcf01
Showing 1 changed file with 52 additions and 39 deletions.
91 changes: 52 additions & 39 deletions lib/GPUArraysCore/src/GPUArraysCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const AbstractGPUVecOrMat{T} = Union{AbstractGPUArray{T, 1}, AbstractGPUArray{T,
const WrappedGPUArray{T,N} = WrappedArray{T,N,AbstractGPUArray,AbstractGPUArray{T,N}}
const AnyGPUArray{T,N} = Union{AbstractGPUArray{T,N}, WrappedGPUArray{T,N}}


## broadcasting

"""
Expand All @@ -35,6 +36,7 @@ this supertype.
"""
abstract type AbstractGPUArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end


## scalar iteration

export allowscalar, @allowscalar, assertscalar
Expand All @@ -45,57 +47,37 @@ export allowscalar, @allowscalar, assertscalar
# XXX: use context variables to inherit the parent task's setting, once available.
const default_scalar_indexing = Ref{Union{Nothing,ScalarIndexing}}(nothing)

"""
allowscalar() do
# code that can use scalar indexing
end
Denote which operations can use scalar indexing.
See also: [`@allowscalar`](@ref).
"""
function allowscalar(f::Base.Callable)
task_local_storage(f, :ScalarIndexing, ScalarAllowed)
end

"""
allowscalar(::Bool)
Calling this with `false` replaces the default warning about scalar indexing
(show once per session) with an error.
Instead of calling this with `true`, the preferred style is to allow this locally.
This can be done with the `allowscalar(::Function)` method (with a `do` block)
or with the [`@allowscalar`](@ref) macro.
Writes to `task_local_storage` for `:ScalarIndexing`. The default is `:ScalarWarn`,
and this function sets `:ScalarAllowed` or `:ScalarDisallowed`.
"""
function allowscalar(allow::Bool=true)
if allow
Base.depwarn("allowscalar([true]) is deprecated, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.", :allowscalar)
end
setting = allow ? ScalarAllowed : ScalarDisallowed
task_local_storage(:ScalarIndexing, setting)
default_scalar_indexing[] = setting
return
end

"""
assertscalar(op::String)
Assert that a certain operation `op` performs scalar indexing. If this is not allowed, an
error will be thrown ([`allowscalar`](@ref)).
"""
function assertscalar(op = "operation")
# try to detect the REPL
@static if VERSION >= v"1.10.0-DEV.444" || v"1.9-beta4" <= VERSION < v"1.10-"
if isdefined(Base, :active_repl) && current_task() == Base.active_repl.frontend_task
# we always allow scalar iteration on the REPL's frontend task,
# where we often trigger scalar indexing by displaying GPU objects.
return false
end
default_behavior = ScalarDisallowed
else
# we can't detect the REPL, but it will only be used in interactive sessions,
# so default to allowing scalar indexing there (but warn).
default_behavior = isinteractive() ? ScalarWarn : ScalarDisallowed
end

val = get!(task_local_storage(), :ScalarIndexing) do
something(default_scalar_indexing[], isinteractive() ? ScalarWarn : ScalarDisallowed)
something(default_scalar_indexing[], default_behavior)
end
desc = """Invocation of $op resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar."""
and therefore should be avoided.
If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question."""
if val == ScalarDisallowed
error("""Scalar indexing is disallowed.
$desc""")
Expand All @@ -107,6 +89,34 @@ function assertscalar(op = "operation")
return
end

"""
allowscalar([true])
allowscalar([true]) do
...
end
Use this function to allow or disallow scalar indexing, either globall or for the
duration of the do block.
See also: [`@allowscalar`](@ref).
"""
allowscalar

function allowscalar(f::Base.Callable)
task_local_storage(f, :ScalarIndexing, ScalarAllowed)
end

function allowscalar(allow::Bool=true)
if allow
@warn """It's not recommended to use allowscalar([true]) to allow scalar indexing.
Instead, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.""" maxlog=1
end
setting = allow ? ScalarAllowed : ScalarDisallowed
task_local_storage(:ScalarIndexing, setting)
default_scalar_indexing[] = setting
return
end

"""
@allowscalar() begin
# code that can use scalar indexing
Expand All @@ -124,6 +134,9 @@ macro allowscalar(ex)
end
end


## other

"""
backend(T::Type)
backend(x)
Expand Down

0 comments on commit 50bcf01

Please sign in to comment.