Skip to content

Commit

Permalink
Merge pull request #14425 from sarvjeetsinghghotra/issue-13690
Browse files Browse the repository at this point in the history
RFC: Fixed scale and scale1 methods to not assume commutativity. Issue #13690
  • Loading branch information
tkelman committed Jan 6, 2016
2 parents b17beb3 + 997bcc6 commit bc1c18e
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 3 deletions.
29 changes: 27 additions & 2 deletions base/linalg/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ function generic_scale!(X::AbstractArray, s::Number)
X
end

function generic_scale!(s::Number, X::AbstractArray)
@simd for I in eachindex(X)
@inbounds X[I] = s*X[I]
end
X
end

function generic_scale!(C::AbstractArray, X::AbstractArray, s::Number)
if length(C) != length(X)
throw(DimensionMismatch("first array has length $(length(C)) which does not match the length of the second, $(length(X))."))
Expand All @@ -29,10 +36,28 @@ function generic_scale!(C::AbstractArray, X::AbstractArray, s::Number)
end
C
end

function generic_scale!(C::AbstractArray, s::Number, X::AbstractArray)
if length(C) != length(X)
throw(DimensionMismatch("first array has length $(length(C)) which does not
match the length of the second, $(length(X))."))
end
if size(C) == size(X)
for I in eachindex(C, X)
@inbounds C[I] = s*X[I]
end
else
for (IC, IX) in zip(eachindex(C), eachindex(X))
@inbounds C[IC] = s*X[IX]
end
end
C
end

scale!(C::AbstractArray, s::Number, X::AbstractArray) = generic_scale!(C, X, s)
scale!(C::AbstractArray, X::AbstractArray, s::Number) = generic_scale!(C, X, s)
scale!(C::AbstractArray, X::AbstractArray, s::Number) = generic_scale!(C, s, X)
scale!(X::AbstractArray, s::Number) = generic_scale!(X, s)
scale!(s::Number, X::AbstractArray) = generic_scale!(X, s)
scale!(s::Number, X::AbstractArray) = generic_scale!(s, X)

cross(a::AbstractVector, b::AbstractVector) = [a[2]*b[3]-a[3]*b[2], a[3]*b[1]-a[1]*b[3], a[1]*b[2]-a[2]*b[1]]

Expand Down
8 changes: 7 additions & 1 deletion base/sparse/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,13 @@ function scale!(C::SparseMatrixCSC, A::SparseMatrixCSC, b::Number)
C
end

scale!(C::SparseMatrixCSC, b::Number, A::SparseMatrixCSC) = scale!(C, A, b)
function scale!(C::SparseMatrixCSC, b::Number, A::SparseMatrixCSC)
size(A)==size(C) || throw(DimensionMismatch())
copyinds!(C, A)
resize!(C.nzval, length(A.nzval))
scale!(C.nzval, b, A.nzval)
C
end

scale!(A::SparseMatrixCSC, b::Number) = (scale!(A.nzval, b); A)
scale!(b::Number, A::SparseMatrixCSC) = (scale!(b, A.nzval); A)
Expand Down
25 changes: 25 additions & 0 deletions test/linalg/generic.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,26 @@
# This file is a part of Julia. License is MIT: http://julialang.org/license

import Base: *
using Base.Test

# A custom Quaternion type with minimal defined interface and methods.
# Used to test scale and scale! methods to show non-commutativity.
immutable Quaternion{T<:Real} <: Number
s::T
v1::T
v2::T
v3::T
norm::Bool
end
Quaternion(s::Real, v1::Real, v2::Real, v3::Real, n::Bool = false) =
Quaternion( promote(s, v1, v2, v3)..., n)
Quaternion(a::Vector) = Quaternion(0, a[1], a[2], a[3])
(*)(q::Quaternion, w::Quaternion) = Quaternion(q.s*w.s - q.v1*w.v1 - q.v2*w.v2 - q.v3*w.v3,
q.s*w.v1 + q.v1*w.s + q.v2*w.v3 - q.v3*w.v2,
q.s*w.v2 - q.v1*w.v3 + q.v2*w.s + q.v3*w.v1,
q.s*w.v3 + q.v1*w.v2 - q.v2*w.v1 + q.v3*w.s,
q.norm && w.norm)

debug = false

srand(123)
Expand Down Expand Up @@ -119,6 +138,12 @@ b = randn(Base.LinAlg.SCAL_CUTOFF) # make sure we try BLAS path
@test isequal(scale(BigFloat[1.0], 2.0im), Complex{BigFloat}[2.0im])
@test isequal(scale(BigFloat[1.0], 2.0f0im), Complex{BigFloat}[2.0im])

# test scale and scale! for non-commutative multiplication
q = Quaternion([0.44567, 0.755871, 0.882548, 0.423612])
qmat = []
push!(qmat, Quaternion([0.015007, 0.355067, 0.418645, 0.318373]))
@test scale!(q, copy(qmat)) != scale!(copy(qmat), q)

# test ops on Numbers
for elty in [Float32,Float64,Complex64,Complex128]
a = rand(elty)
Expand Down

0 comments on commit bc1c18e

Please sign in to comment.