Skip to content

Commit

Permalink
Avoid Applied in ApplyArray (#264)
Browse files Browse the repository at this point in the history
* Avoid Applied in MulArray

* add applied_axes etc.

* Update lazyconcat.jl

* axes calls size

* Update inv.jl

* Update inv.jl

* tests pass

* increase coverage

* matrix function check applied axes

* Update lazyapplying.jl

* Update applytests.jl

* empty vcat

* Update concattests.jl
  • Loading branch information
dlfivefifty authored Jul 23, 2023
1 parent 7009bd1 commit 1ebb490
Show file tree
Hide file tree
Showing 12 changed files with 164 additions and 111 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "LazyArrays"
uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02"
version = "1.4.1"
version = "1.5.0"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
96 changes: 59 additions & 37 deletions src/lazyapplying.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ end
@inline Applied{Style}(f::F, args::Args) where {Style,F,Args<:Tuple} = Applied{Style,F,Args}(f, args)
@inline Applied{Style}(A::Applied) where Style = Applied{Style}(A.f, A.args)

ndims(a::Applied) = applied_ndims(a.f, a.args...)
eltype(a::Applied) = applied_eltype(a.f, a.args...)
axes(a::Applied) = applied_axes(a.f, a.args...)
size(a::Applied) = applied_size(a.f, a.args...)

call(a) = a.f
call(_, a) = a.f
Expand All @@ -57,11 +61,19 @@ call(_, a) = a.f
@inline arguments(::DualLayout{ML}, a) where ML = arguments(ML(), a)
@inline arguments(a::AbstractArray) = arguments(MemoryLayout(a), a)

@inline check_applied_axes(A::Applied) = nothing
@inline check_applied_axes(_...) = nothing

# following repeated due to unexplained allocations
@inline function instantiate(A::Applied{Style}) where Style
check_applied_axes(A)
Applied{Style}(A.f, map(instantiate, A.args))
iargs = map(instantiate, A.args)
check_applied_axes(A.f, iargs...)
Applied{Style}(A.f, iargs)
end

@inline function applied_instantiate(f, args...)
iargs = map(instantiate, args)
check_applied_axes(f, iargs...)
f, iargs
end

@inline _typesof() = ()
Expand Down Expand Up @@ -152,10 +164,8 @@ for f in (:exp, :sin, :cos, :sqrt)
@eval ApplyStyle(::typeof($f), ::Type{<:AbstractMatrix}) = MatrixFunctionStyle{typeof($f)}()
end

function check_applied_axes(A::Applied{<:MatrixFunctionStyle})
length(A.args) == 1 || throw(ArgumentError("MatrixFunctions only defined with 1 arg"))
axes(A.args[1],1) == axes(A.args[1],2) || throw(DimensionMismatch("matrix is not square: dimensions are $axes(A.args[1])"))
end
@inline matrixfunction_check_applied_axes(a::AbstractMatrix) = axes(a,1) == axes(a,2) || throw(DimensionMismatch("matrix is not square: dimensions are $(axes(a))"))
@inline matrixfunction_check_applied_axes(a...) = nothing

for op in (:axes, :size)
@eval begin
Expand Down Expand Up @@ -192,19 +202,27 @@ const ApplyMatrix{T, F, Args<:Tuple} = ApplyArray{T, 2, F, Args}

LazyArray(A::Applied) = ApplyArray(A)

ApplyArray{T,N,F,Args}(M::Applied) where {T,N,F,Args} = ApplyArray{T,N,F,Args}(M.f, M.args)
ApplyArray{T,N}(M::Applied{Style,F,Args}) where {T,N,Style,F,Args} = ApplyArray{T,N,F,Args}(instantiate(M))
ApplyArray{T}(M::Applied) where {T} = ApplyArray{T,ndims(M)}(M)
ApplyArray(M::Applied) = ApplyArray{eltype(M)}(M)
ApplyVector(M::Applied) = ApplyVector{eltype(M)}(M)
ApplyMatrix(M::Applied) = ApplyMatrix{eltype(M)}(M)
@inline ApplyArray{T,N,F,Args}(M::Applied) where {T,N,F,Args} = ApplyArray{T,N,F,Args}(M.f, M.args)
@inline ApplyArray{T,N}(M::Applied{Style,F,Args}) where {T,N,Style,F,Args} = ApplyArray{T,N,F,Args}(instantiate(M))
@inline ApplyArray{T}(M::Applied) where {T} = ApplyArray{T,ndims(M)}(M)
@inline ApplyArray(M::Applied) = ApplyArray{eltype(M)}(M)

ApplyArray(f, factors...) = ApplyArray(applied(f, factors...))
ApplyArray{T}(f, factors...) where T = ApplyArray{T}(applied(f, factors...))
ApplyArray{T,N}(f, factors...) where {T,N} = ApplyArray{T,N}(applied(f, factors...))

ApplyVector(f, factors...) = ApplyVector(applied(f, factors...))
ApplyMatrix(f, factors...) = ApplyMatrix(applied(f, factors...))
@inline ApplyArray(f, factors...) = ApplyArray{applied_eltype(f, factors...)}(f, factors...)
@inline ApplyArray{T}(f, factors...) where T = ApplyArray{T, applied_ndims(f, factors...)}(f, factors...)
@inline function ApplyArray{T,N}(f, factors...) where {T,N}
f̃, args = applied_instantiate(f, factors...)
ApplyArray{T,N,typeof(f̃),typeof(args)}(f̃, args)
end

@inline ApplyVector(f, factors...) = ApplyVector{applied_eltype(f, factors...)}(f, factors...)
@inline ApplyMatrix(f, factors...) = ApplyMatrix{applied_eltype(f, factors...)}(f, factors...)

ApplyArray(A::AbstractArray{T,N}) where {T,N} = ApplyArray{T,N}(call(A), arguments(A)...)
ApplyArray{T}(A::AbstractArray{V,N}) where {T,V,N} = ApplyArray{T,N}(call(A), arguments(A)...)
ApplyArray{T,N}(A::AbstractArray{V,N}) where {T,V,N} = ApplyArray{T,N}(call(A), arguments(A)...)
ApplyMatrix(A::AbstractMatrix{T}) where T = ApplyMatrix{T}(call(A), arguments(A)...)
ApplyVector(A::AbstractVector{T}) where T = ApplyVector{T}(call(A), arguments(A)...)

convert(::Type{AbstractArray{T}}, A::ApplyArray{T}) where T = A
convert(::Type{AbstractArray{T}}, A::ApplyArray{<:Any,N}) where {T,N} = ApplyArray{T,N}(A.f, A.args...)
Expand All @@ -216,8 +234,12 @@ AbstractArray{T}(A::ApplyArray{<:Any,N}) where {T,N} = ApplyArray{T,N}(A.f, map(
AbstractArray{T,N}(A::ApplyArray{T,N}) where {T,N} = copy(A)
AbstractArray{T,N}(A::ApplyArray{<:Any,N}) where {T,N} = ApplyArray{T,N}(A.f, map(copy,A.args)...)

@inline axes(A::ApplyArray) = axes(Applied(A))
@inline size(A::ApplyArray) = map(length, axes(A))
@inline axes(A::ApplyArray) = applied_axes(A.f, A.args...)
@inline size(A::ApplyArray) = applied_size(A.f, A.args...)

@inline applied_axes(f, args...) = map(oneto, applied_size(f, args...))



# immutable arrays don't need to copy.
# Some special cases like vcat overload setindex! and therefore
Expand All @@ -236,13 +258,16 @@ for F in (:exp, :log, :sqrt, :cos, :sin, :tan, :csc, :sec, :cot,
:acosh, :asinh, :atanh, :acsch, :asech, :acoth,
:acos, :asin, :atan, :acsc, :asec, :acot)
@eval begin
ndims(M::Applied{LazyArrayApplyStyle,typeof($F)}) = ndims(first(M.args))
axes(M::Applied{LazyArrayApplyStyle,typeof($F)}) = axes(first(M.args))
size(M::Applied{LazyArrayApplyStyle,typeof($F)}) = size(first(M.args))
eltype(M::Applied{LazyArrayApplyStyle,typeof($F)}) = eltype(first(M.args))
@inline applied_ndims(M::typeof($F), a) = ndims(a)
@inline applied_axes(::typeof($F), a) = axes(a)
@inline applied_size(::typeof($F), a) = size(a)
@inline applied_eltype(::typeof($F), a) = float(eltype(a))
check_applied_axes(::typeof($F), a...) = matrixfunction_check_applied_axes(a...)
end
end



###
# show
###
Expand All @@ -252,10 +277,12 @@ _applyarray_summary(io::IO, C) = _applyarray_summary(io::IO, C.f, arguments(C))
function _applyarray_summary(io::IO, f, args)
print(io, f)
print(io, "(")
summary(io, first(args))
for a in tail(args)
print(io, ", ")
summary(io, a)
if !isempty(args)
summary(io, first(args))
for a in tail(args)
print(io, ", ")
summary(io, a)
end
end
print(io, ")")
end
Expand Down Expand Up @@ -302,7 +329,6 @@ MemoryLayout(::Type{ApplyArray{T,N,F,Args}}) where {T,N,F,Args} =
applylayout(F, tuple_type_memorylayouts(Args)...)

@inline Applied(A::AbstractArray) = Applied(call(A), arguments(A)...)
@inline ApplyArray(A::AbstractArray) = ApplyArray(call(A), arguments(A)...)

function show(io::IO, A::Applied)
print(io, "Applied(", A.f)
Expand Down Expand Up @@ -354,15 +380,11 @@ _base_copyto!(dest::AbstractArray, src::AbstractArray) = Base.invoke(copyto!, NT
# triu/tril
##
for tri in (:tril, :triu)
for op in (:axes, :size)
@eval begin
$op(A::Applied{<:Any,typeof($tri)}) = $op(first(A.args))
$op(A::Applied{<:Any,typeof($tri)}, j) = $op(first(A.args), j)
end
end
@eval begin
ndims(::Applied{<:Any,typeof($tri)}) = 2
eltype(A::Applied{<:Any,typeof($tri)}) = eltype(first(A.args))
applied_axes(::typeof($tri), a, k...) = axes(a)
applied_size(::typeof($tri), a, k...) = size(a)
applied_ndims(::typeof($tri), a, k...) = 2
applied_eltype(::typeof($tri), a, k...) = eltype(a)
$tri(A::LazyMatrix) = ApplyMatrix($tri, A)
$tri(A::LazyMatrix, k::Integer) = ApplyMatrix($tri, A, k)
end
Expand Down
1 change: 1 addition & 0 deletions src/lazybroadcasting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ is returned by `MemoryLayout(A)` if a matrix `A` is a `BroadcastArray`.
"""
struct BroadcastLayout{F} <: AbstractLazyLayout end

@inline tuple_type_memorylayouts(::Type{Tuple{}}) = ()
@inline tuple_type_memorylayouts(::Type{I}) where I<:Tuple = tuple(MemoryLayout(Base.tuple_type_head(I)), tuple_type_memorylayouts(Base.tuple_type_tail(I))...)
@inline tuple_type_memorylayouts(::Type{Tuple{A}}) where {A} = (MemoryLayout(A),)
@inline tuple_type_memorylayouts(::Type{Tuple{A,B}}) where {A,B} = (MemoryLayout(A),MemoryLayout(B))
Expand Down
43 changes: 23 additions & 20 deletions src/lazyconcat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@ function instantiate(A::Applied{DefaultApplyStyle,typeof(vcat)})
Applied{DefaultApplyStyle}(A.f,map(instantiate,A.args))
end

@inline eltype(A::Applied{<:Any,typeof(vcat)}) = promote_type(map(eltype,A.args)...)
@inline eltype(A::Applied{<:Any,typeof(vcat),Tuple{}}) = Any
@inline ndims(A::Applied{<:Any,typeof(vcat),I}) where I = max(1,maximum(map(ndims,A.args)))
@inline ndims(A::Applied{<:Any,typeof(vcat),Tuple{}}) = 1
@inline applied_eltype(::typeof(vcat)) = Any
@inline applied_eltype(::typeof(vcat), args...) = promote_type(map(eltype, args)...)
@inline applied_ndims(::typeof(vcat), args...) = max(1,maximum(map(ndims,args)))
@inline applied_ndims(::typeof(vcat)) = 1
@inline axes(f::Vcat{<:Any,1,Tuple{}}) = (OneTo(0),)
@inline axes(f::Vcat{<:Any,1}) = tuple(oneto(+(map(length,f.args)...)))
@inline axes(f::Vcat{<:Any,2}) = (oneto(+(map(a -> size(a,1), f.args)...)), axes(f.args[1],2))
@inline size(f::Vcat) = map(length, axes(f))


Base.IndexStyle(::Type{<:Vcat{T,1}}) where T = Base.IndexLinear()

function ==(a::Vcat{T,N}, b::Vcat{T,N}) where {N,T}
Expand Down Expand Up @@ -134,9 +137,9 @@ function instantiate(A::Applied{DefaultApplyStyle,typeof(hcat)})
Applied{DefaultApplyStyle}(A.f,map(instantiate,A.args))
end

@inline eltype(A::Applied{<:Any,typeof(hcat)}) = promote_type(map(eltype,A.args)...)
ndims(::Applied{<:Any,typeof(hcat)}) = 2
size(f::Applied{<:Any,typeof(hcat)}) = (size(f.args[1],1), +(map(a -> size(a,2), f.args)...))
@inline applied_eltype(::typeof(hcat), args...) = promote_type(map(eltype,args)...)
@inline applied_ndims(::typeof(hcat), args...) = 2
@inline applied_size(::typeof(hcat), args...) = (size(args[1],1), +(map(a -> size(a,2), args)...))

@inline hcat_getindex(f, k, j::Integer) = hcat_getindex_recursive(f, (k, j), f.args...)

Expand Down Expand Up @@ -184,26 +187,21 @@ end
# Hvcat
####

@inline eltype(A::Applied{<:Any,typeof(hvcat)}) = promote_type(map(eltype,tail(A.args))...)
ndims(::Applied{<:Any,typeof(hvcat)}) = 2
function size(f::Applied{<:Any,typeof(hvcat),<:Tuple{Int,Vararg{Any}}})
n = f.args[1]
sum(size.(f.args[2:n:end],1)),sum(size.(f.args[2:1+n],2))
end

function size(f::Applied{<:Any,typeof(hvcat),<:Tuple{NTuple{N,Int},Vararg{Any}}}) where N
n = f.args[1]
@inline applied_eltype(::typeof(hvcat), a, b...) = promote_type(map(eltype, b)...)
@inline applied_ndims(::typeof(hvcat), args...) = 2
@inline applied_size(::typeof(hvcat), n::Int, b...) = sum(size.(b[1:n:end],1)),sum(size.(b[1:n],2))

@inline function applied_size(::typeof(hvcat), n::NTuple{N,Int}, b...) where N
as = tuple(2, (2 .+ cumsum(Base.front(n)))...)
sum(size.(getindex.(Ref(f.args), as),1)),sum(size.(f.args[2:1+n[1]],2))
sum(size.(getindex.(Ref((n, b...)), as),1)),sum(size.(b[1:n[1]],2))
end


@inline hvcat_getindex(f, k, j::Integer) = hvcat_getindex_recursive(f, (k, j), f.args...)

_hvcat_size(A) = size(A)
_hvcat_size(A::Number) = (1,1)
_hvcat_size(A::AbstractVector) = (size(A,1),1)
@inline _hvcat_size(A) = size(A)
@inline _hvcat_size(A::Number) = (1,1)
@inline _hvcat_size(A::AbstractVector) = (size(A,1),1)

@inline function hvcat_getindex_recursive(f, (k,j)::Tuple{Integer,Integer}, N::Int, A, args...)
T = eltype(f)
Expand Down Expand Up @@ -966,4 +964,9 @@ end

searchsorted(f::Vcat{<:Any,1}, x) = searchsortedfirst(f, x):searchsortedlast(f,x)

###
# vec
###

@inline applied_eltype(::typeof(vec), a) = eltype(a)
@inline applied_axes(::typeof(vec), a) = (oneto(length(a)),)
20 changes: 12 additions & 8 deletions src/lazyoperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ Kron{T}(A...) where T = ApplyArray{T}(kron, A...)
_kron_dims() = 0
_kron_dims(A, B...) = max(ndims(A), _kron_dims(B...))

eltype(A::Applied{<:Any,typeof(kron)}) = promote_type(map(eltype,A.args)...)
ndims(A::Applied{<:Any,typeof(kron)}) = _kron_dims(A.args...)
applied_eltype(::typeof(kron), args...) = promote_type(map(eltype,args)...)
applied_ndims(::typeof(kron), args...) = _kron_dims(args...)

size(K::Kron, j::Int) = prod(size.(K.args, j))
size(a::Kron{<:Any,1}) = (size(a,1),)
Expand Down Expand Up @@ -470,14 +470,18 @@ AccumulateAbstractVector(op, A::AbstractVector{T}) where T = AccumulateAbstractV

for op in (:rot180, :rotl90, :rotr90)
@eval begin
ndims(::Applied{<:Any,typeof($op)}) = 2
eltype(A::Applied{<:Any,typeof($op)}) = eltype(A.args...)
applied_ndims(::typeof($op), a) = 2
applied_eltype(::typeof($op), a) = eltype(a)
end
end
applied_size(::typeof(rot180), a) = size(a)
applied_axes(::typeof(rot180), a) = axes(a)
for op in (:rotl90, :rotr90)
@eval begin
applied_size(::typeof($op), a) = reverse(size(a))
applied_axes(::typeof($op), a) = reverse(axes(a))
end
end
size(A::Applied{<:Any,typeof(rot180)}) = size(A.args...)
axes(A::Applied{<:Any,typeof(rot180)}) = axes(A.args...)
size(A::Applied{<:Any,typeof(rotl90)}) = reverse(size(A.args...))
size(A::Applied{<:Any,typeof(rotr90)}) = reverse(size(A.args...))

getindex(A::Applied{<:Any,typeof(rot180)}, k::Int, j::Int) = A.args[1][end-k+1,end-j+1]
getindex(A::Applied{<:Any,typeof(rotl90)}, k::Int, j::Int) = A.args[1][j,end-k+1]
Expand Down
7 changes: 2 additions & 5 deletions src/linalg/add.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,10 @@ for op in (:+, :-)
@eval begin
size(M::Applied{<:Any, typeof($op)}, p::Int) = size(M)[p]
axes(M::Applied{<:Any, typeof($op)}, p::Int) = axes(M)[p]
ndims(M::Applied{<:Any, typeof($op)}) = ndims(first(M.args))

length(M::Applied{<:Any, typeof($op)}) = prod(size(M))
size(M::Applied{<:Any, typeof($op)}) = length.(axes(M))
axes(M::Applied{<:Any, typeof($op)}) = axes(first(M.args))

eltype(M::Applied{<:Any, typeof($op)}) = promote_type(map(eltype,M.args)...)
applied_size(::typeof($op), args...) = size(first(args))
applied_axes(::typeof($op), args...) = axes(first(args))
end
end

Expand Down
37 changes: 19 additions & 18 deletions src/linalg/inv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,35 +22,36 @@ ndims(A::InvOrPInv) = ndims(parent(A))



size(A::InvOrPInv) = reverse(size(parent(A)))
axes(A::InvOrPInv) = reverse(axes(parent(A)))
size(A::InvOrPInv, k) = size(A)[k]
axes(A::InvOrPInv, k) = axes(A)[k]
eltype(A::InvOrPInv) = Base.promote_op(inv, eltype(parent(A)))
for op in (:inv, :pinv)
@eval begin
@inline applied_size(::typeof($op), a) = reverse(size(a))
@inline applied_axes(::typeof($op), a) = reverse(axes(a))
@inline applied_eltype(::typeof($op), a) = Base.promote_op(inv, eltype(a))
@inline applied_ndims(::typeof($op), a) = 2
end
end

# Use ArrayLayouts.ldiv instead of \
struct LdivStyle <: ApplyStyle end
struct RdivStyle <: ApplyStyle end

ApplyStyle(::typeof(\), ::Type{A}, ::Type{B}) where {A<:AbstractArray,B<:AbstractArray} = LdivStyle()
ApplyStyle(::typeof(/), ::Type{A}, ::Type{B}) where {A<:AbstractArray,B<:AbstractArray} = RdivStyle()
@inline ApplyStyle(::typeof(\), ::Type{A}, ::Type{B}) where {A<:AbstractArray,B<:AbstractArray} = LdivStyle()
@inline ApplyStyle(::typeof(/), ::Type{A}, ::Type{B}) where {A<:AbstractArray,B<:AbstractArray} = RdivStyle()


axes(M::Applied{Style,typeof(\)}) where Style = ldivaxes(M.args...)
axes(M::Applied{Style,typeof(\)}, p::Int) where Style = axes(M)[p]
size(M::Applied{Style,typeof(\)}) where Style = length.(axes(M))
@inline eltype(M::Applied{Style,typeof(\)}) where Style = eltype(Ldiv(M.args...))
@inline ndims(M::Applied{Style,typeof(\)}) where Style = ndims(last(M.args))
@inline applied_axes(::typeof(\), args...) = ldivaxes(args...)
@inline applied_size(::typeof(\), args...) = length.(applied_axes(\, args...))
@inline applied_eltype(::typeof(\), args...) = eltype(Ldiv(args...))
@inline applied_ndims(::typeof(\), args...) = ndims(last(args))


axes(M::Applied{Style,typeof(/)}) where Style = axes(Rdiv(M.args...))
axes(M::Applied{Style,typeof(/)}, p::Int) where Style = axes(M)[p]
size(M::Applied{Style,typeof(/)}) where Style = length.(axes(M))
@inline eltype(M::Applied{Style,typeof(/)}) where Style = eltype(Rdiv(M.args...))
@inline ndims(M::Applied{Style,typeof(/)}) where Style = ndims(first(M.args))
@inline applied_axes(::typeof(/), args...) = axes(Rdiv(args...))
@inline applied_size(::typeof(/), args...) = length.(applied_axes(/, args...))
@inline applied_eltype(::typeof(/), args...) = eltype(Rdiv(args...))
@inline applied_ndims(::typeof(/), args...) = ndims(first(args))


check_applied_axes(A::Applied{<:Any,typeof(\)}) = check_ldiv_axes(A.args...)
check_applied_axes(::typeof(\), args...) = check_ldiv_axes(args...)

######
# PInv/Inv
Expand Down
Loading

2 comments on commit 1ebb490

@dlfivefifty
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/88135

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.5.0 -m "<description of version>" 1ebb4902dc5ae11ebc0614300b3f2cad48350f88
git push origin v1.5.0

Please sign in to comment.