Skip to content

Commit

Permalink
add allowscalar (#200)
Browse files Browse the repository at this point in the history
* add enable scalar

* set evaluation scope of RCI

* add better docstring

* add space

* Update abstractsparse.jl

* Update abstractsparse.jl

* Update abstractsparse.jl

Co-authored-by: Viral B. Shah <ViralBShah@users.noreply.github.com>
  • Loading branch information
SobhanMP and ViralBShah authored Aug 7, 2022
1 parent 0944c41 commit 090474b
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 6 deletions.
51 changes: 51 additions & 0 deletions src/abstractsparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,54 @@ julia> findnz(A)
function findnz end

widelength(x::AbstractSparseArray) = prod(Int64.(size(x)))


const _restore_scalar_indexing = Expr[]
const _destroy_scalar_indexing = Expr[]
"""
@RCI f
records the function `f` to be overwritten (and restored) with `allowscalar(::Bool)`. This is an
experimental feature.
Note that it will evaluate the function in the top level of the package. The original code for `f`
is stored in `_restore_scalar_indexing` and a function that has the same definition as `f` but
returns an error is stored in `_destroy_scalar_indexing`.
"""
macro RCI(exp)
# evaluate to not push any broken code in the arrays when developping this package.
# also ensures that restore has the exact same effect.
@eval $exp
if length(exp.args) == 2 && exp.head (:function, :(=))
push!(_restore_scalar_indexing, exp)
push!(_destroy_scalar_indexing,
Expr(exp.head,
exp.args[1],
:(error("scalar indexing was turned off"))))
else
error("can't parse expression")
end
return
end

"""
allowscalar(::Bool)
An experimental function that allows one to disable and re-enable scalar indexing for sparse matrices and vectors.
`allowscalar(false)` will disable scalar indexing for sparse matrices and vectors.
`allowscalar(true)` will restore the original scalar indexing functionality.
Since this function overwrites existing definitions, it will lead to recompilation. It is useful mainly when testing
code for devices such as [GPUs](https://cuda.juliagpu.org/stable/usage/workflow/), where the presence of scalar indexing can lead to substantial slowdowns.
Disabling scalar indexing during such tests can help identify performance bottlenecks quickly.
"""
allowscalar(p::Bool) = if p
for i in _restore_scalar_indexing
@eval $i
end
else
for i in _destroy_scalar_indexing
@eval $i
end
end
6 changes: 3 additions & 3 deletions src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2296,9 +2296,9 @@ function rangesearch(haystack::AbstractRange, needle)
(rem==0 && 1<=i+1<=length(haystack)) ? i+1 : 0
end

getindex(A::AbstractSparseMatrixCSC, I::Tuple{Integer,Integer}) = getindex(A, I[1], I[2])
@RCI getindex(A::AbstractSparseMatrixCSC, I::Tuple{Integer,Integer}) = getindex(A, I[1], I[2])

function getindex(A::AbstractSparseMatrixCSC{T}, i0::Integer, i1::Integer) where T
@RCI function getindex(A::AbstractSparseMatrixCSC{T}, i0::Integer, i1::Integer) where T
@boundscheck checkbounds(A, i0, i1)
r1 = Int(getcolptr(A)[i1])
r2 = Int(getcolptr(A)[i1+1]-1)
Expand Down Expand Up @@ -2744,7 +2744,7 @@ getindex(A::AbstractSparseMatrixCSC, I::AbstractVector{Bool}, J::AbstractVector{
## setindex!

# dispatch helper for #29034
setindex!(A::AbstractSparseMatrixCSC, _v, _i::Integer, _j::Integer) = _setindex_scalar!(A, _v, _i, _j)
@RCI setindex!(A::AbstractSparseMatrixCSC, _v, _i::Integer, _j::Integer) = _setindex_scalar!(A, _v, _i, _j)

function _setindex_scalar!(A::AbstractSparseMatrixCSC{Tv,Ti}, _v, _i::Integer, _j::Integer) where {Tv,Ti<:Integer}
v = convert(Tv, _v)
Expand Down
6 changes: 3 additions & 3 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ end

### Element access

function setindex!(x::SparseVector{Tv,Ti}, v::Tv, i::Ti) where {Tv,Ti<:Integer}
@RCI function setindex!(x::SparseVector{Tv,Ti}, v::Tv, i::Ti) where {Tv,Ti<:Integer}
checkbounds(x, i)
nzind = nonzeroinds(x)
nzval = nonzeros(x)
Expand All @@ -357,7 +357,7 @@ function setindex!(x::SparseVector{Tv,Ti}, v::Tv, i::Ti) where {Tv,Ti<:Integer}
x
end

setindex!(x::SparseVector{Tv,Ti}, v, i::Integer) where {Tv,Ti<:Integer} =
@RCI setindex!(x::SparseVector{Tv,Ti}, v, i::Integer) where {Tv,Ti<:Integer} =
setindex!(x, convert(Tv, v), convert(Ti, i))


Expand Down Expand Up @@ -844,7 +844,7 @@ function _spgetindex(m::Int, nzind::AbstractVector{Ti}, nzval::AbstractVector{Tv
(ii <= m && nzind[ii] == i) ? nzval[ii] : zero(Tv)
end

function getindex(x::AbstractSparseVector, i::Integer)
@RCI function getindex(x::AbstractSparseVector, i::Integer)
checkbounds(x, i)
_spgetindex(nnz(x), nonzeroinds(x), nonzeros(x), i)
end
Expand Down
24 changes: 24 additions & 0 deletions test/allowscalar.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using Test, SparseArrays

@testset "allowscalar" begin
A = sprandn(10, 20, 0.9)
A[1, 1] = 2
@test A[1, 1] == 2
SparseArrays.allowscalar(false)
@test_throws Any A[1, 1]
@test_throws Any A[1, 1] = 2
SparseArrays.allowscalar(true)
@test A[1, 1] == 2
A[1, 1] = 3
@test A[1, 1] == 3

B = sprandn(10, 0.9)
B[1] = 2
@test B[1] == 2
SparseArrays.allowscalar(false)
@test_throws Any B[1]
SparseArrays.allowscalar(true)
@test B[1] == 2
B[1] = 3
@test B[1] == 3
end
1 change: 1 addition & 0 deletions test/testgroups
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
allowscalar
ambiguous
higherorderfns
sparsematrix_ops
Expand Down

0 comments on commit 090474b

Please sign in to comment.