diff --git a/src/abstractsparse.jl b/src/abstractsparse.jl index 1db6a340..39d44696 100644 --- a/src/abstractsparse.jl +++ b/src/abstractsparse.jl @@ -125,3 +125,54 @@ julia> findnz(A) function findnz end widelength(x::AbstractSparseArray) = prod(Int64.(size(x))) + + +const _restore_scalar_indexing = Expr[] +const _destroy_scalar_indexing = Expr[] +""" + @RCI f + +records the function `f` to be overwritten (and restored) with `allowscalar(::Bool)`. This is an +experimental feature. + +Note that it will evaluate the function in the top level of the package. The original code for `f` +is stored in `_restore_scalar_indexing` and a function that has the same definition as `f` but +returns an error is stored in `_destroy_scalar_indexing`. +""" +macro RCI(exp) + # evaluate to not push any broken code in the arrays when developping this package. + # also ensures that restore has the exact same effect. + @eval $exp + if length(exp.args) == 2 && exp.head ∈ (:function, :(=)) + push!(_restore_scalar_indexing, exp) + push!(_destroy_scalar_indexing, + Expr(exp.head, + exp.args[1], + :(error("scalar indexing was turned off")))) + else + error("can't parse expression") + end + return +end + +""" + allowscalar(::Bool) + +An experimental function that allows one to disable and re-enable scalar indexing for sparse matrices and vectors. + +`allowscalar(false)` will disable scalar indexing for sparse matrices and vectors. +`allowscalar(true)` will restore the original scalar indexing functionality. + +Since this function overwrites existing definitions, it will lead to recompilation. It is useful mainly when testing +code for devices such as [GPUs](https://cuda.juliagpu.org/stable/usage/workflow/), where the presence of scalar indexing can lead to substantial slowdowns. +Disabling scalar indexing during such tests can help identify performance bottlenecks quickly. +""" +allowscalar(p::Bool) = if p + for i in _restore_scalar_indexing + @eval $i + end +else + for i in _destroy_scalar_indexing + @eval $i + end +end diff --git a/src/sparsematrix.jl b/src/sparsematrix.jl index 674a2353..1b2ffb34 100644 --- a/src/sparsematrix.jl +++ b/src/sparsematrix.jl @@ -2261,9 +2261,9 @@ function rangesearch(haystack::AbstractRange, needle) (rem==0 && 1<=i+1<=length(haystack)) ? i+1 : 0 end -getindex(A::AbstractSparseMatrixCSC, I::Tuple{Integer,Integer}) = getindex(A, I[1], I[2]) +@RCI getindex(A::AbstractSparseMatrixCSC, I::Tuple{Integer,Integer}) = getindex(A, I[1], I[2]) -function getindex(A::AbstractSparseMatrixCSC{T}, i0::Integer, i1::Integer) where T +@RCI function getindex(A::AbstractSparseMatrixCSC{T}, i0::Integer, i1::Integer) where T @boundscheck checkbounds(A, i0, i1) r1 = Int(getcolptr(A)[i1]) r2 = Int(getcolptr(A)[i1+1]-1) @@ -2709,7 +2709,7 @@ getindex(A::AbstractSparseMatrixCSC, I::AbstractVector{Bool}, J::AbstractVector{ ## setindex! # dispatch helper for #29034 -setindex!(A::AbstractSparseMatrixCSC, _v, _i::Integer, _j::Integer) = _setindex_scalar!(A, _v, _i, _j) +@RCI setindex!(A::AbstractSparseMatrixCSC, _v, _i::Integer, _j::Integer) = _setindex_scalar!(A, _v, _i, _j) function _setindex_scalar!(A::AbstractSparseMatrixCSC{Tv,Ti}, _v, _i::Integer, _j::Integer) where {Tv,Ti<:Integer} v = convert(Tv, _v) diff --git a/src/sparsevector.jl b/src/sparsevector.jl index a81f226f..9354943f 100644 --- a/src/sparsevector.jl +++ b/src/sparsevector.jl @@ -334,7 +334,7 @@ end ### Element access -function setindex!(x::SparseVector{Tv,Ti}, v::Tv, i::Ti) where {Tv,Ti<:Integer} +@RCI function setindex!(x::SparseVector{Tv,Ti}, v::Tv, i::Ti) where {Tv,Ti<:Integer} checkbounds(x, i) nzind = nonzeroinds(x) nzval = nonzeros(x) @@ -352,7 +352,7 @@ function setindex!(x::SparseVector{Tv,Ti}, v::Tv, i::Ti) where {Tv,Ti<:Integer} x end -setindex!(x::SparseVector{Tv,Ti}, v, i::Integer) where {Tv,Ti<:Integer} = +@RCI setindex!(x::SparseVector{Tv,Ti}, v, i::Integer) where {Tv,Ti<:Integer} = setindex!(x, convert(Tv, v), convert(Ti, i)) @@ -839,7 +839,7 @@ function _spgetindex(m::Int, nzind::AbstractVector{Ti}, nzval::AbstractVector{Tv (ii <= m && nzind[ii] == i) ? nzval[ii] : zero(Tv) end -function getindex(x::AbstractSparseVector, i::Integer) +@RCI function getindex(x::AbstractSparseVector, i::Integer) checkbounds(x, i) _spgetindex(nnz(x), nonzeroinds(x), nonzeros(x), i) end diff --git a/test/allowscalar.jl b/test/allowscalar.jl new file mode 100644 index 00000000..fa35f35c --- /dev/null +++ b/test/allowscalar.jl @@ -0,0 +1,24 @@ +using Test, SparseArrays + +@testset "allowscalar" begin + A = sprandn(10, 20, 0.9) + A[1, 1] = 2 + @test A[1, 1] == 2 + SparseArrays.allowscalar(false) + @test_throws Any A[1, 1] + @test_throws Any A[1, 1] = 2 + SparseArrays.allowscalar(true) + @test A[1, 1] == 2 + A[1, 1] = 3 + @test A[1, 1] == 3 + + B = sprandn(10, 0.9) + B[1] = 2 + @test B[1] == 2 + SparseArrays.allowscalar(false) + @test_throws Any B[1] + SparseArrays.allowscalar(true) + @test B[1] == 2 + B[1] = 3 + @test B[1] == 3 +end diff --git a/test/testgroups b/test/testgroups index e8e8b8a0..e61a1591 100644 --- a/test/testgroups +++ b/test/testgroups @@ -1,3 +1,4 @@ +allowscalar ambiguous higherorderfns sparsematrix_ops