Skip to content

Commit

Permalink
Add promotion for matrix * vector and vector * matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
thofma committed Feb 18, 2023
1 parent 8441a3a commit 5b8af02
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 0 deletions.
25 changes: 25 additions & 0 deletions src/Matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,31 @@ end

==(x::MatElem, y::MatElem) = ==(promote(x, y)...)

# matrix * vec and vec * matrx
function Base.promote(x::MatrixElem{S},
y::Vector{T}) where {S <: NCRingElement,
T <: NCRingElement}
U = promote_rule_sym(S, T)
if U === S
return x, map(base_ring(x), y)
elseif U === T && length(y) != 0
return change_base_ring(parent(y[1]), x), y
else
error("Cannot promote to common type")
end
end

function Base.promote(x::Vector{S},
y::MatrixElem{T}) where {S <: NCRingElement,
T <: NCRingElement}
yy, xx = promote(y, x)
return xx, yy
end

*(x::MatrixElem, y::Vector) = *(promote(x, y)...)

*(x::Vector, y::MatrixElem) = *(promote(x, y)...)

function Base.promote(x::MatElem{S}, y::T) where {S <: NCRingElement, T <: NCRingElement}
U = promote_rule_sym(S, T)
if U === S
Expand Down
17 changes: 17 additions & 0 deletions test/generic/Matrix-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,23 @@ end

@test one(F) == R[1 0; 0 1]
@test R[1 0; 0 1] == one(R)

# vector * matrix
m = [1 2; 3 4]
F = ResidueField(ZZ, 3)
R, t = PolynomialRing(F, "t")
A = matrix(R, m)
B = matrix(F, m)
v = [one(F), 2*one(F)]
vv = [one(R), 2*one(R)]
@test (@inferred A * v) == A * vv
@test (@inferred v * A) == vv * A

@test (@inferred B * vv) == A * vv
@test (@inferred vv * B) == vv * A

@test_throws ErrorException A * Rational{BigInt}[1 ,2]
@test_throws ErrorException Rational{BigInt}[1 ,2] * A
end

@testset "Generic.Mat.permutation" begin
Expand Down
17 changes: 17 additions & 0 deletions test/generic/MatrixAlgebra-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,23 @@ end

@test one(F) == R[1 0; 0 1]
@test R[1 0; 0 1] == one(R)

# vector * matrix
m = [1 2; 3 4]
F = ResidueField(ZZ, 3)
R, t = PolynomialRing(F, "t")
A = MatrixAlgebra(R, 2)(m)
B = MatrixAlgebra(F, 2)(m)
v = [one(F), 2*one(F)]
vv = [one(R), 2*one(R)]
@test (@inferred A * v) == A * vv
@test (@inferred v * A) == vv * A

@test (@inferred B * vv) == A * vv
@test (@inferred vv * B) == vv * A

@test_throws ErrorException A * Rational{BigInt}[1 ,2]
@test_throws ErrorException Rational{BigInt}[1 ,2] * A
end

@testset "Generic.MatAlg.permutation" begin
Expand Down

0 comments on commit 5b8af02

Please sign in to comment.