Skip to content

Commit

Permalink
Support Dotu (#274)
Browse files Browse the repository at this point in the history
* Support Dotu

* Update Project.toml

* Special Real broadcast routines

* Update Project.toml
  • Loading branch information
dlfivefifty authored Sep 12, 2023
1 parent e7eba56 commit 3ecdcd0
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 38 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "LazyArrays"
uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02"
version = "1.7"
version = "1.8"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand All @@ -19,7 +19,7 @@ LazyArraysStaticArraysExt = "StaticArrays"

[compat]
Aqua = "0.6"
ArrayLayouts = "1.0"
ArrayLayouts = "1.4.1"
FillArrays = "1.0"
MacroTools = "0.5"
MatrixFactorizations = "1.0, 2.0"
Expand Down
2 changes: 1 addition & 1 deletion src/LazyArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ import ArrayLayouts: MatMulVecAdd, MatMulMatAdd, MulAdd, Lmul, Rmul, Ldiv, Dot,
adjointlayout, sub_materialize, mulreduce,
check_mul_axes, _mul_eltype, check_ldiv_axes, ldivaxes, colsupport, rowsupport,
_fill_lmul!, scalarone, scalarzero, fillzeros, zero!, layout_getindex, _copyto!,
AbstractQLayout, StridedLayout, layout_replace_in_print_matrix
AbstractQLayout, StridedLayout, layout_replace_in_print_matrix, dotu, Dotu

import Base: require_one_based_indexing, oneto

Expand Down
17 changes: 17 additions & 0 deletions src/lazybroadcasting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ BroadcastMatrix(A::BroadcastMatrix) = A
@inline _broadcastarray2broadcasted(_, a) = a
@inline _broadcastarray2broadcasted(lay, a::BroadcastArray) = error("Overload LazyArrays._broadcastarray2broadcasted(::$(lay), _)")
@inline _broadcastarray2broadcasted(::DualLayout{ML}, a) where ML = _broadcastarray2broadcasted(ML(), a)
@inline _broadcastarray2broadcasted(::DualLayout{ML}, a::BroadcastArray) where ML = _broadcastarray2broadcasted(ML(), a)
@inline _broadcastarray2broadcasted(a) = _broadcastarray2broadcasted(MemoryLayout(a), a)
@inline _broadcasted(A) = instantiate(_broadcastarray2broadcasted(A))
broadcasted(A::BroadcastArray) = _broadcasted(A)
Expand Down Expand Up @@ -394,3 +395,19 @@ for op in (:*, :\, :/)
end

permutedims(A::BroadcastArray{T}) where T = BroadcastArray{T}(A.f, map(_permutedims,A.args)...)



####
# Dual broadcast: functions of transpose can also behave like transpose
####

@inline broadcastlayout(::Type{F}, ::DualLayout) where F = DualLayout{BroadcastLayout{F}}()

# real broadcast can use adjoint or transpose. This keeps types simple
sub_materialize(::DualLayout{BroadcastLayout{typeof(real)}}, A::AbstractMatrix{<:Real}) = sub_materialize(view(_adjortrans(A), parentindices(A)[2]))'

_adjortrans(A::SubArray{<:Any,2, <:Any, <:Tuple{Slice,Any}}) = view(_adjortrans(parent(A)), parentindices(A)[2])
_adjortrans(A::Adjoint) = A'
_adjortrans(A::Transpose) = transpose(A)
_adjortrans(A::BroadcastArray{T,N,typeof(real)}) where {T,N} = BroadcastArray{T}(real, _adjortrans(A.args...))
69 changes: 36 additions & 33 deletions src/padded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -327,44 +327,47 @@ _normInf(::PaddedLayout, a) = norm(paddeddata(a),Inf)
_normp(::PaddedLayout, a, p) = norm(paddeddata(a),p)


function copy(D::Dot{layA, layB}) where {layA<:PaddedLayout,layB<:PaddedLayout}
a = paddeddata(D.A)
b = paddeddata(D.B)
T = eltype(D)
if MemoryLayout(a) isa layA && MemoryLayout(b) isa layB
return convert(T, dot(Array(a),Array(b)))
end
length(a) == length(b) && return convert(T, dot(a,b))
# following handles scalars
((length(a) == 1) || (length(b) == 1)) && return convert(T, a[1] * b[1])
m = min(length(a), length(b))
convert(T, dot(view(a, 1:m), view(b, 1:m)))
end

function copy(D::Dot{<:PaddedLayout})
a = paddeddata(D.A)
m = length(a)
v = view(D.B, 1:m)
if MemoryLayout(a) isa PaddedLayout
convert(eltype(D), dot(Array(a), v))
else
convert(eltype(D), dot(a, v))
end
for (Dt, dt) in ((:Dot, :dot), (:Dotu, :dotu))
@eval begin
function copy(D::$Dt{layA, layB}) where {layA<:PaddedLayout,layB<:PaddedLayout}
a = paddeddata(D.A)
b = paddeddata(D.B)
T = eltype(D)
if MemoryLayout(a) isa layA && MemoryLayout(b) isa layB
return convert(T, $dt(Array(a),Array(b)))
end
length(a) == length(b) && return convert(T, $dt(a,b))
# following handles scalars
((length(a) == 1) || (length(b) == 1)) && return convert(T, a[1] * b[1])
m = min(length(a), length(b))
convert(T, $dt(view(a, 1:m), view(b, 1:m)))
end

end
function copy(D::$Dt{<:PaddedLayout})
a = paddeddata(D.A)
m = length(a)
v = view(D.B, 1:m)
if MemoryLayout(a) isa PaddedLayout
convert(eltype(D), $dt(Array(a), v))
else
convert(eltype(D), $dt(a, v))
end

function copy(D::Dot{<:Any, <:PaddedLayout})
b = paddeddata(D.B)
m = length(b)
v = view(D.A, 1:m)
if MemoryLayout(b) isa PaddedLayout
convert(eltype(D), dot(v, Array(b)))
else
convert(eltype(D), dot(v, b))
end

function copy(D::$Dt{<:Any, <:PaddedLayout})
b = paddeddata(D.B)
m = length(b)
v = view(D.A, 1:m)
if MemoryLayout(b) isa PaddedLayout
convert(eltype(D), $dt(v, Array(b)))
else
convert(eltype(D), $dt(v, b))
end
end
end
end


_vcat_sub_arguments(::PaddedLayout, A, V) = _vcat_sub_arguments(ApplyLayout{typeof(vcat)}(), A, V)

_lazy_getindex(dat, kr...) = view(dat, kr...)
Expand Down
13 changes: 11 additions & 2 deletions test/broadcasttests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,7 @@ import Base: broadcasted
a = BroadcastArray(exp, 1:5)
b = randn(5)
@test MemoryLayout(a') isa DualLayout{BroadcastLayout{typeof(exp)}}
@test a'b Vector(a)'b
@test BroadcastArray(a')b [a'b]
@test a'b BroadcastArray(a')b Vector(a)'b
end

@testset "show" begin
Expand Down Expand Up @@ -359,4 +358,14 @@ import Base: broadcasted
@testset "zero-sized views" begin
@test size(copy(view(BroadcastArray(+, 1, randn(1,3)), 1:0, 2:3))) == (0,2)
end

@testset "real adjortrans" begin
a = BroadcastArray(real, ((1:5) .+ im)')
@test a[:,1:3] == (1:3)'
@test a[:,1:3] isa Adjoint{Int,Vector{Int}}

a = BroadcastArray(real, transpose((1:5) .+ im))
@test a[:,1:3] == (1:3)'
@test a[:,1:3] isa Adjoint{Int,Vector{Int}}
end
end

2 comments on commit 3ecdcd0

@dlfivefifty
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/91276

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.8.0 -m "<description of version>" 3ecdcd03d2ee23622c711e5362ae503a166f3a5e
git push origin v1.8.0

Please sign in to comment.