Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Storage type #23

Merged
merged 3 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AbstractOperators"
uuid = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c"
version = "0.3"
version = "0.4"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
4 changes: 2 additions & 2 deletions src/calculus/AdjointOperator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ export AdjointOperator
"""
`AdjointOperator(A::AbstractOperator)`

Shorthand constructor:
Shorthand constructor:

`'(A::AbstractOperator)`

Expand All @@ -19,7 +19,7 @@ julia> [DFT(10); DCT(10)]'
"""
struct AdjointOperator{T <: AbstractOperator} <: AbstractOperator
A::T
function AdjointOperator(A::T) where {T<:AbstractOperator}
function AdjointOperator(A::T) where {T<:AbstractOperator}
is_linear(A) == false && error("Cannot transpose a nonlinear operator. You might use `jacobian`")
new{T}(A)
end
Expand Down
18 changes: 9 additions & 9 deletions src/calculus/AffineAdd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ export AffineAdd
"""
`AffineAdd(A::AbstractOperator, d, [sign = true])`

Affine addition to `AbstractOperator` with an array or scalar `d`.
Affine addition to `AbstractOperator` with an array or scalar `d`.

Use `sign = false` to perform subtraction.

Expand All @@ -26,17 +26,17 @@ true
struct AffineAdd{L <: AbstractOperator, D <: Union{AbstractArray, Number}, S} <: AbstractOperator
A::L
d::D
function AffineAdd(A::L, d::D, sign::Bool = true) where {L, D <: AbstractArray}
if size(d) != size(A,1)
function AffineAdd(A::L, d::D, sign::Bool = true) where {L, D <: AbstractArray}
if size(d) != size(A,1)
throw(DimensionMismatch("codomain size of $A not compatible with array `d` of size $(size(d))"))
end
if eltype(d) != codomainType(A)
if eltype(d) != codomainType(A)
error("cannot tilt opertor having codomain type $(codomainType(A)) with array of type $(eltype(d))")
end
new{L,D,sign}(A,d)
end
# scalar
function AffineAdd(A::L, d::D, sign::Bool = true) where {L, D <: Number}
function AffineAdd(A::L, d::D, sign::Bool = true) where {L, D <: Number}
if typeof(d) <: Complex && codomainType(A) <: Real
error("cannot tilt opertor having codomain type $(codomainType(A)) with array of type $(eltype(d))")
end
Expand All @@ -46,12 +46,12 @@ end

# Mappings
# array
function mul!(y::DD, T::AffineAdd{L, D, true}, x) where {L <: AbstractOperator, DD, D}
function mul!(y::DD, T::AffineAdd{L, D, true}, x) where {L <: AbstractOperator, DD, D}
mul!(y,T.A,x)
y .+= T.d
end

function mul!(y::DD, T::AffineAdd{L, D, false}, x) where {L <: AbstractOperator, DD, D}
function mul!(y::DD, T::AffineAdd{L, D, false}, x) where {L <: AbstractOperator, DD, D}
mul!(y,T.A,x)
y .-= T.d
end
Expand All @@ -70,7 +70,7 @@ is_null(L::AffineAdd) = is_null(L.A)
is_eye(L::AffineAdd) = is_diagonal(L.A)
is_diagonal(L::AffineAdd) = is_diagonal(L.A)
is_invertible(L::AffineAdd) = is_invertible(L.A)
is_AcA_diagonal(L::AffineAdd) = is_AcA_diagonal(L.A)
is_AcA_diagonal(L::AffineAdd) = is_AcA_diagonal(L.A)
is_AAc_diagonal(L::AffineAdd) = is_AAc_diagonal(L.A)
is_full_row_rank(L::AffineAdd) = is_full_row_rank(L.A)
is_full_column_rank(L::AffineAdd) = is_full_column_rank(L.A)
Expand All @@ -90,7 +90,7 @@ sign(T::AffineAdd{L,D, true}) where {L,D} = 1

function permute(T::AffineAdd{L,D,S}, p::AbstractVector{Int}) where {L,D,S}
A = permute(T.A,p)
return AffineAdd(A,T.d,S)
return AffineAdd(A,T.d,S)
end

displacement(A::AffineAdd{L,D,true}) where {L,D} = A.d .+ displacement(A.A)
Expand Down
19 changes: 8 additions & 11 deletions src/calculus/Ax_mul_Bx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,10 @@ end

# Constructors
function Ax_mul_Bx(A::AbstractOperator,B::AbstractOperator)
s,t = size(A,1), codomainType(A)
bufA = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
s,t = size(B,1), codomainType(B)
bufB = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
s,t = size(A,2), domainType(A)
bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
bufA = allocateInCodomain(A)
bufB = allocateInCodomain(B)
bufC = allocateInCodomain(B)
bufD = allocateInDomain(A)
Ax_mul_Bx(A,B,bufA,bufB,bufC,bufD)
end

Expand Down Expand Up @@ -95,16 +92,16 @@ end

size(P::Union{Ax_mul_Bx,Ax_mul_BxJac}) = ((size(P.A,1)[1],size(P.B,1)[2]),size(P.A,2))

fun_name(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = fun_name(L.A)*"*"*fun_name(L.B)
fun_name(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = fun_name(L.A)*"*"*fun_name(L.B)

domainType(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = domainType(L.A)
codomainType(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = codomainType(L.A)

# utils
function permute(P::Ax_mul_Bx{L1,L2,C,D},
function permute(P::Ax_mul_Bx{L1,L2,C,D},
p::AbstractVector{Int}) where {L1,L2,C,D <:ArrayPartition}
Ax_mul_Bx(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]) )
Ax_mul_Bx(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]))
end

remove_displacement(P::Ax_mul_Bx) =
remove_displacement(P::Ax_mul_Bx) =
Ax_mul_Bx(remove_displacement(P.A), remove_displacement(P.B), P.bufA, P.bufB, P.bufC, P.bufD)
21 changes: 9 additions & 12 deletions src/calculus/Ax_mul_Bxt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ struct Ax_mul_Bxt{
bufD::D
function Ax_mul_Bxt(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1,L2,C,D}
if ndims(A,1) == 1
if size(A) != size(B)
if size(A) != size(B)
throw(DimensionMismatch("Cannot compose operators"))
end
elseif ndims(A,1) == 2 && ndims(B,1) == 2 && size(A,2) == size(B,2)
if size(A,1)[2] != size(B,1)[2]
if size(A,1)[2] != size(B,1)[2]
throw(DimensionMismatch("Cannot compose operators"))
end
else
Expand All @@ -68,13 +68,10 @@ end

# Constructors
function Ax_mul_Bxt(A::AbstractOperator,B::AbstractOperator)
s,t = size(A,1), codomainType(A)
bufA = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
s,t = size(B,1), codomainType(B)
bufB = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
s,t = size(A,2), domainType(A)
bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
bufA = allocateInCodomain(A)
bufB = allocateInCodomain(B)
bufC = allocateInCodomain(A)
bufD = allocateInDomain(A)
Ax_mul_Bxt(A,B,bufA,bufB,bufC,bufD)
end

Expand Down Expand Up @@ -103,16 +100,16 @@ end

size(P::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = ((size(P.A,1)[1],size(P.B,1)[1]),size(P.A,2))

fun_name(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = fun_name(L.A)*"*"*fun_name(L.B)
fun_name(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = fun_name(L.A)*"*"*fun_name(L.B)

domainType(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = domainType(L.A)
codomainType(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = codomainType(L.A)

# utils
function permute(P::Ax_mul_Bxt{L1,L2,C,D},
function permute(P::Ax_mul_Bxt{L1,L2,C,D},
p::AbstractVector{Int}) where {L1,L2,C,D <:ArrayPartition}
Ax_mul_Bxt(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]) )
end

remove_displacement(P::Ax_mul_Bxt) =
remove_displacement(P::Ax_mul_Bxt) =
Ax_mul_Bxt(remove_displacement(P.A), remove_displacement(P.B), P.bufA, P.bufB, P.bufC, P.bufD)
21 changes: 9 additions & 12 deletions src/calculus/Axt_mul_Bx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ struct Axt_mul_Bx{N,
bufD::D
function Axt_mul_Bx(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1,L2,C,D}
if ndims(A,1) == 1
if size(A) != size(B)
if size(A) != size(B)
throw(DimensionMismatch("Cannot compose operators"))
end
elseif ndims(A,1) == 2 && ndims(B,1) == 2 && size(A,2) == size(B,2)
if size(A,1)[1] != size(B,1)[1]
if size(A,1)[1] != size(B,1)[1]
throw(DimensionMismatch("Cannot compose operators"))
end
else
Expand All @@ -69,13 +69,10 @@ end

# Constructors
function Axt_mul_Bx(A::AbstractOperator,B::AbstractOperator)
s,t = size(A,1), codomainType(A)
bufA = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
s,t = size(B,1), codomainType(B)
bufB = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
s,t = size(A,2), domainType(A)
bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
bufA = allocateInCodomain(A)
bufB = allocateInCodomain(B)
bufC = allocateInCodomain(A)
bufD = allocateInDomain(A)
Axt_mul_Bx(A,B,bufA,bufB,bufC,bufD)
end

Expand Down Expand Up @@ -122,16 +119,16 @@ end
size(P::Union{Axt_mul_Bx{1},Axt_mul_BxJac{1}}) = ((1,),size(P.A,2))
size(P::Union{Axt_mul_Bx{2},Axt_mul_BxJac{2}}) = ((size(P.A,1)[2],size(P.B,1)[2]),size(P.A,2))

fun_name(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = fun_name(L.A)*"*"*fun_name(L.B)
fun_name(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = fun_name(L.A)*"*"*fun_name(L.B)

domainType(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = domainType(L.A)
codomainType(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = codomainType(L.A)

# utils
function permute(P::Axt_mul_Bx{N,L1,L2,C,D},
function permute(P::Axt_mul_Bx{N,L1,L2,C,D},
p::AbstractVector{Int}) where {N,L1,L2,C,D <:ArrayPartition}
Axt_mul_Bx(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]) )
end

remove_displacement(P::Axt_mul_Bx) =
remove_displacement(P::Axt_mul_Bx) =
Axt_mul_Bx(remove_displacement(P.A), remove_displacement(P.B), P.bufA, P.bufB, P.bufC, P.bufD)
19 changes: 9 additions & 10 deletions src/calculus/BroadCast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ julia> B*[1.;2.]
```

"""
struct BroadCast{N,
L <: AbstractOperator,
T <: AbstractArray,
struct BroadCast{N,
L <: AbstractOperator,
T <: AbstractArray,
D <: AbstractArray,
M,
C <: NTuple{M,Colon},
Expand All @@ -36,14 +36,14 @@ struct BroadCast{N,
cols::C
idxs::I

function BroadCast(A::L,dim_out::NTuple{N,Int},bufC::T, bufD::D) where {N,
L<:AbstractOperator,
function BroadCast(A::L,dim_out::NTuple{N,Int},bufC::T, bufD::D) where {N,
L<:AbstractOperator,
T<:AbstractArray,
D<:AbstractArray
}
Base.Broadcast.check_broadcast_shape(dim_out,size(A,1))
if size(A,1) != (1,)
M = length(size(A,1))
M = length(size(A,1))
cols = ([Colon() for i = 1:M]...,)
idxs = CartesianIndices((dim_out[M+1:end]...,))
new{N,L,T,D,M,typeof(cols),typeof(idxs)}(A,dim_out,bufC,bufD,cols,idxs)
Expand All @@ -52,14 +52,13 @@ struct BroadCast{N,
idxs = CartesianIndices((1,))
new{N,L,T,D,M,NTuple{0,Colon},typeof(idxs)}(A,dim_out,bufC,bufD,(),idxs)
end

end
end

# Constructors

BroadCast(A::L, dim_out::NTuple{N,Int}) where {N,L<:AbstractOperator} =
BroadCast(A, dim_out, zeros(codomainType(A),size(A,1)), zeros(domainType(A),size(A,2)) )
BroadCast(A, dim_out, allocateInCodomain(A), allocateInDomain(A))

# Mappings

Expand All @@ -82,7 +81,7 @@ end
function mul!(y::CC, A::AdjointOperator{BroadCast{N,L,T,D,0,C,I}}, b::DD) where {N,L,T,D,C,I,CC,DD}
R = A.A
fill!(y, 0.)
bii = zeros(eltype(b),1)
bii = allocateInCodomain(R.A)
for bi in b
bii[1] = bi
mul!(R.bufD, R.A', bii)
Expand All @@ -92,7 +91,7 @@ function mul!(y::CC, A::AdjointOperator{BroadCast{N,L,T,D,0,C,I}}, b::DD) where
end

#TODO make this more general
#length(dim_out) == size(A,1) e.g. a .= b; size(a) = (m,n) size(b) = (1,n) matrix out, column in
#length(dim_out) == size(A,1) e.g. a .= b; size(a) = (m,n) size(b) = (1,n) matrix out, column in
function mul!(y::CC, A::AdjointOperator{BroadCast{2,L,T,D,2,C,I}}, b::DD) where {L,T,D,C,I,CC,DD}
R = A.A
fill!(y, 0.)
Expand Down
17 changes: 10 additions & 7 deletions src/calculus/Compose.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ export Compose
"""
`Compose(A::AbstractOperator,B::AbstractOperator)`

Shorthand constructor:
Shorthand constructor:

`A*B`
`A*B`

Compose different `AbstractOperator`s. Notice that the domain and codomain of the operators `A` and `B` must match, i.e. `size(A,2) == size(B,1)` and `domainType(A) == codomainType(B)`.

Expand All @@ -28,19 +28,22 @@ end

function Compose(L1::AbstractOperator, L2::AbstractOperator)
if size(L1,2) != size(L2,1)
throw(DimensionMismatch("cannot compose operators"))
throw(DimensionMismatch("cannot compose operators with different domain and codomain sizes"))
end
if domainType(L1) != codomainType(L2)
throw(DomainError())
throw(DomainError((domainType(L1),codomainType(L2)), "cannot compose operators with different domain and codomain types"))
end
Compose( L1, L2, Array{domainType(L1)}(undef,size(L2,1)) )
if domainStorageType(L1) != codomainStorageType(L2)
throw(DomainError((domainStorageType(L1),codomainStorageType(L2)), "cannot compose operators with different input and output storage types"))
end
Compose(L1, L2, allocateInCodomain(L2))
end

Compose(L1::AbstractOperator,L2::AbstractOperator,buf::AbstractArray) =
Compose( (L2,L1), (buf,))
Compose((L2,L1), (buf,))

Compose(L1::Compose, L2::AbstractOperator,buf::AbstractArray) =
Compose( (L2,L1.A...), (buf,L1.buf...))
Compose((L2,L1.A...), (buf,L1.buf...))

Compose(L1::AbstractOperator,L2::Compose, buf::AbstractArray) =
Compose((L2.A...,L1), (L2.buf...,buf))
Expand Down
Loading
Loading