diff --git a/src/host/indexing.jl b/src/host/indexing.jl index 0e6e0bdb..3d136b66 100644 --- a/src/host/indexing.jl +++ b/src/host/indexing.jl @@ -7,6 +7,10 @@ export allowscalar, @allowscalar, assertscalar @enum ScalarIndexing ScalarAllowed ScalarWarn ScalarWarned ScalarDisallowed +# if the user explicitly calls allowscalar, use that setting for all new tasks +# XXX: use context variables to inherit the parent task's setting, once available. +const default_scalar_indexing = Ref{Union{Nothing,ScalarIndexing}}(nothing) + """ allowscalar() do # code that can use scalar indexing @@ -24,7 +28,9 @@ function allowscalar(allow::Bool=true) if allow Base.depwarn("allowscalar([true]) is deprecated, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.", :allowscalar) end - task_local_storage(:ScalarIndexing, allow ? ScalarAllowed : ScalarDisallowed) + setting = allow ? ScalarAllowed : ScalarDisallowed + task_local_storage(:ScalarIndexing, setting) + default_scalar_indexing[] = setting return end @@ -36,11 +42,7 @@ error will be thrown ([`allowscalar`](@ref)). """ function assertscalar(op = "operation") val = get!(task_local_storage(), :ScalarIndexing) do - if isinteractive() - ScalarWarn - else - ScalarDisallowed - end + something(default_scalar_indexing[], isinteractive() ? ScalarWarn : ScalarDisallowed) end desc = """Invocation of $op resulted in scalar indexing of a GPU array. This is typically caused by calling an iterating implementation of a method. @@ -51,7 +53,7 @@ function assertscalar(op = "operation") error("""Scalar indexing is disallowed. $desc""") elseif val == ScalarWarn - @warn("""Performing scalar indexing. + @warn("""Performing scalar indexing on task $(current_task()). $desc""") task_local_storage(:ScalarIndexing, ScalarWarned) end