diff --git a/base/linalg/generic.jl b/base/linalg/generic.jl index b8e26ecdbb4bd..63390c372c0dc 100644 --- a/base/linalg/generic.jl +++ b/base/linalg/generic.jl @@ -14,6 +14,13 @@ function generic_scale!(X::AbstractArray, s::Number) X end +function generic_scale!(s::Number, X::AbstractArray) + @simd for I in eachindex(X) + @inbounds X[I] = s*X[I] + end + X +end + function generic_scale!(C::AbstractArray, X::AbstractArray, s::Number) if length(C) != length(X) throw(DimensionMismatch("first array has length $(length(C)) which does not match the length of the second, $(length(X)).")) @@ -29,10 +36,28 @@ function generic_scale!(C::AbstractArray, X::AbstractArray, s::Number) end C end + +function generic_scale!(C::AbstractArray, s::Number, X::AbstractArray) + if length(C) != length(X) + throw(DimensionMismatch("first array has length $(length(C)) which does not +match the length of the second, $(length(X)).")) + end + if size(C) == size(X) + for I in eachindex(C, X) + @inbounds C[I] = s*X[I] + end + else + for (IC, IX) in zip(eachindex(C), eachindex(X)) + @inbounds C[IC] = s*X[IX] + end + end + C +end + scale!(C::AbstractArray, s::Number, X::AbstractArray) = generic_scale!(C, X, s) -scale!(C::AbstractArray, X::AbstractArray, s::Number) = generic_scale!(C, X, s) +scale!(C::AbstractArray, X::AbstractArray, s::Number) = generic_scale!(C, s, X) scale!(X::AbstractArray, s::Number) = generic_scale!(X, s) -scale!(s::Number, X::AbstractArray) = generic_scale!(X, s) +scale!(s::Number, X::AbstractArray) = generic_scale!(s, X) cross(a::AbstractVector, b::AbstractVector) = [a[2]*b[3]-a[3]*b[2], a[3]*b[1]-a[1]*b[3], a[1]*b[2]-a[2]*b[1]] diff --git a/base/sparse/linalg.jl b/base/sparse/linalg.jl index 80d7d8616fad6..02231b7ab7252 100644 --- a/base/sparse/linalg.jl +++ b/base/sparse/linalg.jl @@ -815,7 +815,13 @@ function scale!(C::SparseMatrixCSC, A::SparseMatrixCSC, b::Number) C end -scale!(C::SparseMatrixCSC, b::Number, A::SparseMatrixCSC) = scale!(C, A, b) +function scale!(C::SparseMatrixCSC, b::Number, A::SparseMatrixCSC) + size(A)==size(C) || throw(DimensionMismatch()) + copyinds!(C, A) + resize!(C.nzval, length(A.nzval)) + scale!(C.nzval, b, A.nzval) + C +end scale!(A::SparseMatrixCSC, b::Number) = (scale!(A.nzval, b); A) scale!(b::Number, A::SparseMatrixCSC) = (scale!(b, A.nzval); A) diff --git a/test/linalg/generic.jl b/test/linalg/generic.jl index 3dee2b2577ac4..690e6da9c4138 100644 --- a/test/linalg/generic.jl +++ b/test/linalg/generic.jl @@ -1,7 +1,26 @@ # This file is a part of Julia. License is MIT: http://julialang.org/license +import Base: * using Base.Test +# A custom Quaternion type with minimal defined interface and methods. +# Used to test scale and scale! methods to show non-commutativity. +immutable Quaternion{T<:Real} <: Number + s::T + v1::T + v2::T + v3::T + norm::Bool +end +Quaternion(s::Real, v1::Real, v2::Real, v3::Real, n::Bool = false) = + Quaternion( promote(s, v1, v2, v3)..., n) +Quaternion(a::Vector) = Quaternion(0, a[1], a[2], a[3]) +(*)(q::Quaternion, w::Quaternion) = Quaternion(q.s*w.s - q.v1*w.v1 - q.v2*w.v2 - q.v3*w.v3, + q.s*w.v1 + q.v1*w.s + q.v2*w.v3 - q.v3*w.v2, + q.s*w.v2 - q.v1*w.v3 + q.v2*w.s + q.v3*w.v1, + q.s*w.v3 + q.v1*w.v2 - q.v2*w.v1 + q.v3*w.s, + q.norm && w.norm) + debug = false srand(123) @@ -119,6 +138,12 @@ b = randn(Base.LinAlg.SCAL_CUTOFF) # make sure we try BLAS path @test isequal(scale(BigFloat[1.0], 2.0im), Complex{BigFloat}[2.0im]) @test isequal(scale(BigFloat[1.0], 2.0f0im), Complex{BigFloat}[2.0im]) +# test scale and scale! for non-commutative multiplication +q = Quaternion([0.44567, 0.755871, 0.882548, 0.423612]) +qmat = [] +push!(qmat, Quaternion([0.015007, 0.355067, 0.418645, 0.318373])) +@test scale!(q, copy(qmat)) != scale!(copy(qmat), q) + # test ops on Numbers for elty in [Float32,Float64,Complex64,Complex128] a = rand(elty)