Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make I consistently provide identity behavior in type #24396

Merged
merged 1 commit into from
Oct 31, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,9 @@ This section lists changes that do not have deprecation warnings.
* All command line arguments passed via `-e`, `-E`, and `-L` will be executed in the order
given on the command line ([#23665]).

* `I` now yields `UniformScaling{Bool}(true)` rather than `UniformScaling{Int64}(1)`
to better preserve types in operations involving `I` ([#24396]).

* The return type of `reinterpret` has changed to `ReinterpretArray`. `reinterpret` on sparse
arrays has been discontinued.

Expand Down
20 changes: 10 additions & 10 deletions base/linalg/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ julia> [1 2im 3; 1im 2 3] * I
0+1im 2+0im 3+0im
```
"""
const I = UniformScaling(1)
const I = UniformScaling(true)

eltype(::Type{UniformScaling{T}}) where {T} = T
ndims(J::UniformScaling) = 2
Expand Down Expand Up @@ -99,7 +99,7 @@ for (t1, t2) in ((:UnitUpperTriangular, :UpperTriangular),
($op)(UL::$t2, J::UniformScaling) = ($t2)(($op)(UL.data, J))

function ($op)(UL::$t1, J::UniformScaling)
ULnew = copy_oftype(UL.data, promote_type(eltype(UL), eltype(J)))
ULnew = copy_oftype(UL.data, Base.Broadcast._broadcast_eltype($op, UL, J))
for i = 1:size(ULnew, 1)
ULnew[i,i] = ($op)(1, J.λ)
end
Expand All @@ -110,7 +110,7 @@ for (t1, t2) in ((:UnitUpperTriangular, :UpperTriangular),
end

function (-)(J::UniformScaling, UL::Union{UpperTriangular,UnitUpperTriangular})
ULnew = similar(parent(UL), promote_type(eltype(J), eltype(UL)))
ULnew = similar(parent(UL), Base.Broadcast._broadcast_eltype(-, J, UL))
n = size(ULnew, 1)
ULold = UL.data
for j = 1:n
Expand All @@ -126,7 +126,7 @@ function (-)(J::UniformScaling, UL::Union{UpperTriangular,UnitUpperTriangular})
return UpperTriangular(ULnew)
end
function (-)(J::UniformScaling, UL::Union{LowerTriangular,UnitLowerTriangular})
ULnew = similar(parent(UL), promote_type(eltype(J), eltype(UL)))
ULnew = similar(parent(UL), Base.Broadcast._broadcast_eltype(-, J, UL))
n = size(ULnew, 1)
ULold = UL.data
for j = 1:n
Expand All @@ -142,28 +142,28 @@ function (-)(J::UniformScaling, UL::Union{LowerTriangular,UnitLowerTriangular})
return LowerTriangular(ULnew)
end

function (+)(A::AbstractMatrix{TA}, J::UniformScaling{TJ}) where {TA,TJ}
function (+)(A::AbstractMatrix, J::UniformScaling)
n = checksquare(A)
B = similar(A, promote_type(TA,TJ))
B = similar(A, Base.Broadcast._broadcast_eltype(+, A, J))
copy!(B,A)
@inbounds for i = 1:n
B[i,i] += J.λ
end
B
end

function (-)(A::AbstractMatrix{TA}, J::UniformScaling{TJ}) where {TA,TJ<:Number}
function (-)(A::AbstractMatrix, J::UniformScaling)
n = checksquare(A)
B = similar(A, promote_type(TA,TJ))
B = similar(A, Base.Broadcast._broadcast_eltype(-, A, J))
copy!(B, A)
@inbounds for i = 1:n
B[i,i] -= J.λ
end
B
end
function (-)(J::UniformScaling{TJ}, A::AbstractMatrix{TA}) where {TA,TJ<:Number}
function (-)(J::UniformScaling, A::AbstractMatrix)
n = checksquare(A)
B = convert(AbstractMatrix{promote_type(TJ,TA)}, -A)
B = convert(AbstractMatrix{Base.Broadcast._broadcast_eltype(-, J, A)}, -A)
@inbounds for j = 1:n
B[j,j] += J.λ
end
Expand Down
2 changes: 1 addition & 1 deletion test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1919,7 +1919,7 @@ test5536(a::Union{Real, AbstractArray}) = "Non-splatting"
# issue #6142
import Base: +
mutable struct A6142 <: AbstractMatrix{Float64}; end
+(x::A6142, y::UniformScaling{TJ}) where {TJ} = "UniformScaling method called"
+(x::A6142, y::UniformScaling) = "UniformScaling method called"
+(x::A6142, y::AbstractArray) = "AbstractArray method called"
@test A6142() + I == "UniformScaling method called"
+(x::A6142, y::AbstractRange) = "AbstractRange method called" #16324 ambiguity
Expand Down
13 changes: 12 additions & 1 deletion test/linalg/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ end
end

@testset "det and logdet" begin
@test det(I) === 1
@test det(I) === true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only bit that gives me pause; I wonder if we should have prod([true,true,true]) === 1 instead of true. We already have sum([true,true,true]) === 3 so it would kind of make sense for the type of prod to match, even though 0 and 1 would be the only possible value.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps warrants a dedicated issue? :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done: #24425

@test det(1.0I) === 1.0
@test det(0I) === 0
@test det(0.0I) === 0.0
Expand Down Expand Up @@ -216,3 +216,14 @@ end
@test alltwos != 2I != alltwos # test generic path / inequality off diag
@test rdenseI != I != rdenseI # test square matrix check
end

@testset "operations involving I should preserve eltype" begin
@test isa(Int8(1) + I, Int8)
@test isa(Float16(1) + I, Float16)
@test eltype(Int8(1)I) == Int8
@test eltype(Float16(1)I) == Float16
@test eltype(fill(Int8(1), 2, 2)I) == Int8
@test eltype(fill(Float16(1), 2, 2)I) == Float16
@test eltype(fill(Int8(1), 2, 2) + I) == Int8
@test eltype(fill(Float16(1), 2, 2) + I) == Float16
end