Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve zero checks in sparse and reduce #9325

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,6 @@ function count(pred::Union(Function,Func{1}), a::AbstractArray)
end

immutable NotEqZero <: Func{1} end
call(::NotEqZero, x) = x != 0
call(::NotEqZero, x) = x != zero(x)

countnz(a) = count(NotEqZero(), a)
4 changes: 2 additions & 2 deletions base/sparse/csparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function sparse{Tv,Ti<:Integer}(I::AbstractVector{Ti}, J::AbstractVector{Ti},
Rnz[1] = 1
nz = 0
for k=1:N
if V[k] != 0
if V[k] != zero(Tv)
Rnz[I[k]+1] += 1
nz += 1
end
Expand All @@ -49,7 +49,7 @@ function sparse{Tv,Ti<:Integer}(I::AbstractVector{Ti}, J::AbstractVector{Ti},
((iind > 0) && (jind > 0)) || throw(BoundsError())
p = Wj[iind]
Vk = V[k]
if Vk != 0
if Vk != zero(Tv)
Wj[iind] += 1
Rx[p] = Vk
Ri[p] = jind
Expand Down
6 changes: 3 additions & 3 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ function findn{Tv,Ti}(S::SparseMatrixCSC{Tv,Ti})

count = 1
@inbounds for col = 1 : S.n, k = S.colptr[col] : (S.colptr[col+1]-1)
if S.nzval[k] != 0
if S.nzval[k] != zero(Tv)
I[count] = S.rowval[k]
J[count] = col
count += 1
Expand All @@ -338,7 +338,7 @@ function findnz{Tv,Ti}(S::SparseMatrixCSC{Tv,Ti})

count = 1
@inbounds for col = 1 : S.n, k = S.colptr[col] : (S.colptr[col+1]-1)
if S.nzval[k] != 0
if S.nzval[k] != zero(Tv)
I[count] = S.rowval[k]
J[count] = col
V[count] = S.nzval[k]
Expand Down Expand Up @@ -1224,7 +1224,7 @@ function setindex!{T,Ti}(A::SparseMatrixCSC{T,Ti}, v, i0::Integer, i1::Integer)
v = convert(T, v)
r1 = int(A.colptr[i1])
r2 = int(A.colptr[i1+1]-1)
if v == 0 #either do nothing or delete entry if it exists
if v == zero(T) #either do nothing or delete entry if it exists
if r1 <= r2
r1 = searchsortedfirst(A.rowval, i0, r1, r2, Forward)
if (r1 <= r2) && (A.rowval[r1] == i0)
Expand Down
17 changes: 17 additions & 0 deletions test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -472,3 +472,20 @@ end
# issue #8976
@test conj(sparse([1im])) == sparse(conj([1im]))
@test conj!(sparse([1im])) == sparse(conj!([1im]))

# Test proper handling of zeros for user types
immutable SpTestVal
value::Float64
end
(==)(x::SpTestVal,y::SpTestVal) = (x.value == y.value)
Base.zero(x::SpTestVal) = SpTestVal(0)
Base.zero(::Type{SpTestVal}) = SpTestVal(0)
A = sparse([1,2,3],[1,2,3],[SpTestVal(1),SpTestVal(0),SpTestVal(3)])
@test nnz(A) == 2 # zeros should be stripped by sparse
A[2,2] = SpTestVal(0)
@test nnz(A) == 2 # zeros should be stripped by setindex
A = SparseMatrixCSC(3,3,[1,2,3,4],[1,2,3],
[SpTestVal(1.0),SpTestVal(0.0),SpTestVal(3.0)])
@test countnz(A) == 2
r,c,v = findnz(A)
@test length(r) == length(c) == length(v) == 2