Skip to content

Commit

Permalink
More progress on SSymmetricCompact.
Browse files Browse the repository at this point in the history
* Extract out @pure triangularnumber and triangularroot functions, use to simplify code
* Add similar_type overload
* Speed up == overload
* Generalize +, - overloads with two SSymetricCompact matrices by overloading _fill
* More complete list of scalar-array ops
* Make transpose recursive
* Overloads for rand, randn, randexp
  • Loading branch information
tkoolen committed Jan 15, 2018
1 parent 69334a4 commit 7dfe1cc
Showing 1 changed file with 83 additions and 32 deletions.
115 changes: 83 additions & 32 deletions src/SSymmetricCompact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,17 @@ struct SSymmetricCompact{N, T, L} <: StaticMatrix{N, N, T}
end
end

@inline function (::Type{SSymmetricCompact{N, T}})(lowertriangle::SVector{L}) where {N, T, L}
SSymmetricCompact{N, T, L}(lowertriangle)
end
lowertriangletype(::Type{SSymmetricCompact{N, T, L}}) where {N, T, L} = SVector{L, T}
lowertriangletype(::Type{<:SSymmetricCompact}) = SVector
Base.@pure triangularnumber(N::Int) = div(N * (N + 1), 2)
Base.@pure triangularroot(L::Int) = div(isqrt(8 * L + 1) - 1, 2) # from quadratic formula

@inline function (::Type{SSymmetricCompact{N}})(lowertriangle::SVector{L, T}) where {N, T, L}
SSymmetricCompact{N, T, L}(lowertriangle)
end
@inline (::Type{SSymmetricCompact{N, T}})(lowertriangle::SVector{L}) where {N, T, L} = SSymmetricCompact{N, T, L}(lowertriangle)
@inline (::Type{SSymmetricCompact{N}})(lowertriangle::SVector{L, T}) where {N, T, L} = SSymmetricCompact{N, T, L}(lowertriangle)

@generated function SSymmetricCompact(lowertriangle::SVector{L, T}) where {T, L}
N = div(isqrt(8 * L + 1) - 1, 2) # from quadratic formula
quote
@_inline_meta
SSymmetricCompact{$N, T, L}(lowertriangle)
end
@inline function SSymmetricCompact(lowertriangle::SVector{L, T}) where {T, L}
N = triangularroot(L)
SSymmetricCompact{N, T, L}(lowertriangle)
end

@generated function (::Type{SSymmetricCompact{N, T, L}})(a::Tuple) where {N, T, L}
Expand All @@ -45,12 +42,9 @@ end
end
end

@generated function (::Type{SSymmetricCompact{N, T}})(a::Tuple) where {N, T}
L = div(N * (N + 1), 2)
quote
@_inline_meta
SSymmetricCompact{N, T, $L}(a)
end
@inline function (::Type{SSymmetricCompact{N, T}})(a::Tuple) where {N, T}
L = triangularnumber(N)
SSymmetricCompact{N, T, L}(a)
end

@inline (::Type{SSymmetricCompact{N}})(a::NTuple{M, T}) where {N, T, M} = SSymmetricCompact{N, T}(a)
Expand All @@ -59,15 +53,23 @@ end
@inline (::Type{SSC})(a::SSymmetricCompact) where {SSC<:SSymmetricCompact} = SSC(a.lowertriangle)
@inline (::Type{SSC})(a::SSC) where {SSC<:SSymmetricCompact} = SSC(a.lowertriangle)

lowertriangletype(::Type{SSymmetricCompact{N, T, L}}) where {N, T, L} = SVector{L, T}
lowertriangletype(::Type{<:SSymmetricCompact}) = SVector

@inline (::Type{SSC})(a::AbstractVector) where {SSC <: SSymmetricCompact} = SSC(convert(lowertriangletype(SSC), a))
@inline (::Type{SSC})(a::Tuple) where {SSC <: SSymmetricCompact} = SSymmetricCompact(convert(lowertriangletype(SSC), a))

convert(::Type{SSC}, a::SSC) where {SSC<:SSymmetricCompact} = a
convert(::Type{SSC}, a::SSC) where {SSC<:SSymmetricCompact} = a # TODO: needed?
# TODO: more convert methods?

# TODO: is the following a good idea?
@inline function similar_type(::Type{SSC}, ::Type{T}, ::Size{S}) where {SSC<:SSymmetricCompact, T, S<:Tuple{Int, Int}}
if S[1] === S[2]
N = S[1]
L = triangularnumber(N)
SSymmetricCompact{N, T, L}
else
default_similar_type(T, S, length_val(S))
end
end

@inline indextuple(::T) where {T <: SSymmetricCompact} = indextuple(T)
@generated function indextuple(::Type{<:SSymmetricCompact{N}}) where N
indexmat = zeros(Int, N, N)
Expand Down Expand Up @@ -117,17 +119,66 @@ LinAlg.issymmetric(a::SSymmetricCompact) = true

# TODO: factorize

@inline ==(a::SSymmetricCompact, b::SSymmetricCompact) = a.lowertriangle == b.lowertriangle
@inline -(a::SSymmetricCompact) = SSymmetricCompact(-a.lowertriangle)
@inline +(a::SSymmetricCompact, b::SSymmetricCompact) = SSymmetricCompact(a.lowertriangle + b.lowertriangle)
@inline -(a::SSymmetricCompact, b::SSymmetricCompact) = SSymmetricCompact(a.lowertriangle - b.lowertriangle)
# TODO: a.lowertriangle == b.lowertriangle is slow (used by SDiagonal). SMatrix etc. actually use AbstractArray fallback (also slow)
@inline ==(a::SSymmetricCompact, b::SSymmetricCompact) = mapreduce(==, (x, y) -> x && y, a.lowertriangle, b.lowertriangle)
@generated function _map(f, ::Size{S}, a::SSymmetricCompact...) where {S}
N = S[1]
L = triangularnumber(N)
exprs = Vector{Expr}(L)
for i 1:L
tmp = [:(a[$j][$i]) for j 1:length(a)]
exprs[i] = :(f($(tmp...)))
end
return quote
@_inline_meta
@inbounds return SSymmetricCompact(SVector(tuple($(exprs...))))
end
end

# Scalar-array. TODO: overload broadcast instead, once API has stabilized a bit
@inline +(a::Number, b::SSymmetricCompact) = SSymmetricCompact(a + b.lowertriangle)
@inline +(a::SSymmetricCompact, b::Number) = SSymmetricCompact(a.lowertriangle + b)

@inline *(x::Number, a::SSymmetricCompact) = SSymmetricCompact(x * a.lowertriangle)
@inline *(a::SSymmetricCompact, x::Number) = SSymmetricCompact(a.lowertriangle * x)
@inline /(a::SSymmetricCompact, x::Number)= SSymmetricCompact(a.lowertriangle / x)
@inline -(a::Number, b::SSymmetricCompact) = SSymmetricCompact(a - b.lowertriangle)
@inline -(a::SSymmetricCompact, b::Number) = SSymmetricCompact(a.lowertriangle - b)

@inline conj(a::SSymmetricCompact) = SSymmetricCompact(conj(a.lowertriangle))
@inline transpose(a::SSymmetricCompact) = a
@inline *(a::Number, b::SSymmetricCompact) = SSymmetricCompact(a * b.lowertriangle)
@inline *(a::SSymmetricCompact, b::Number) = SSymmetricCompact(a.lowertriangle * b)

@inline /(a::SSymmetricCompact, b::Number) = SSymmetricCompact(a.lowertriangle / b)
@inline \(a::Number, b::SSymmetricCompact) = SSymmetricCompact(a \ b.lowertriangle)

# TODO: operations With UniformScaling

@inline transpose(a::SSymmetricCompact) = SSymmetricCompact(transpose.(a.lowertriangle))
@inline adjoint(a::SSymmetricCompact) = conj(a)

#TODO: eye, one, zero
#TODO: one, eye

# _fill covers fill, zeros, and ones:
@generated function _fill(val, ::Size{s}, ::Type{SSC}) where {s, SSC<:SSymmetricCompact}
N = s[1]
L = triangularnumber(N)
v = [:val for i = 1:L]
return quote
@_inline_meta
$SSC(SVector(tuple($(v...))))
end
end

@generated function _rand(randfun, rng::AbstractRNG, ::Type{SSC}) where {N, SSC<:SSymmetricCompact{N}}
T = eltype(SSC)
if T == Any
T = Float64
end
L = triangularnumber(N)
v = [:(randfun(rng, $T)) for i = 1:L]
return quote
@_inline_meta
$SSC(SVector(tuple($(v...))))
end
end

@inline rand(rng::AbstractRNG, ::Type{SSC}) where {SSC<:SSymmetricCompact} = _rand(rand, rng, SSC)
@inline randn(rng::AbstractRNG, ::Type{SSC}) where {SSC<:SSymmetricCompact} = _rand(randn, rng, SSC)
@inline randexp(rng::AbstractRNG, ::Type{SSC}) where {SSC<:SSymmetricCompact} = _rand(randexp, rng, SSC)

0 comments on commit 7dfe1cc

Please sign in to comment.