diff --git a/NEWS.md b/NEWS.md index e54201ab67..89c862fe12 100644 --- a/NEWS.md +++ b/NEWS.md @@ -5,6 +5,15 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.9.9] – 2023-12-24 + +### Fixed + +* introduced a nonzero `atol` for all point and vector checks that compre to zero. + This makes those checks a bit more relaxed by default and resolves [#630](https://github.com/JuliaManifolds/Manifolds.jl/issues/630). +* `default_estimation_method(M, f)` is deprecated, use `default_approximation_method(M, f)` for your specific method `f` on the manifold `M`. +* `AbstractEstimationMethod` is deprecated, use `AbstractApproximationMethod` instead. + ## [0.9.8] - 2023-11-17 ### Fixed diff --git a/Project.toml b/Project.toml index 490fd5c2bf..54bb294f41 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Manifolds" uuid = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" authors = ["Seth Axen ", "Mateusz Baran ", "Ronny Bergmann ", "Antoine Levitt "] -version = "0.9.8" +version = "0.9.9" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -42,7 +42,7 @@ ManifoldsRecipesBaseExt = ["Colors", "RecipesBase"] ManifoldsTestExt = "Test" [compat] -BoundaryValueDiffEq = "4, 5" +BoundaryValueDiffEq = "4, 5.6.1" Colors = "0.12" Distributions = "0.22.6, 0.23, 0.24, 0.25" Einsum = "0.4" @@ -51,7 +51,7 @@ HybridArrays = "0.4" Kronecker = "0.4, 0.5" LinearAlgebra = "1.6" ManifoldDiff = "0.3.7" -ManifoldsBase = "0.15.0" +ManifoldsBase = "0.15.6" Markdown = "1.6" MatrixEquations = "2.2" OrdinaryDiffEq = "6.31" diff --git a/README.md b/README.md index afd446fba5..8cebb371b0 100644 --- a/README.md +++ b/README.md @@ -53,17 +53,19 @@ If you have any questions regarding the Manifolds.jl ecosystem feel free to reac ## Citation -If you use `Manifolds.jl` in your work, please cite the following +If you use `Manifolds.jl` in your work, please cite the following open access article ```biblatex @article{AxenBaranBergmannRzecki:2023, - AUTHOR = {Seth D. Axen and Mateusz Baran and Ronny Bergmann and Krzysztof Rzecki}, - EPRINT = {2021.08777}, - EPRINTTYPE = {arXiv}, - JOURNAL = {AMS Transactions on Mathematical Software}, - NOTE = {accepted for publication}, - TITLE = {Manifolds.jl: An Extensible {J}ulia Framework for Data Analysis on Manifolds}, - YEAR = {2023} + AUTHOR = {Axen, Seth D. and Baran, Mateusz and Bergmann, Ronny and Rzecki, Krzysztof}, + ARTICLENO = {33}, + DOI = {10.1145/3618296}, + JOURNAL = {ACM Transactions on Mathematical Software}, + MONTH = {dec}, + NUMBER = {4}, + TITLE = {Manifolds.Jl: An Extensible Julia Framework for Data Analysis on Manifolds}, + VOLUME = {49}, + YEAR = {2023} } ``` diff --git a/docs/src/references.bib b/docs/src/references.bib index ecd9632754..0b43a860f6 100644 --- a/docs/src/references.bib +++ b/docs/src/references.bib @@ -89,13 +89,15 @@ @book{AyJostLeSchwachhoefer:2017 PUBLISHER = {Springer Cham} } @article{AxenBaranBergmannRzecki:2023, - AUTHOR = {Seth D. Axen and Mateusz Baran and Ronny Bergmann and Krzysztof Rzecki}, - EPRINT = {2021.08777}, - EPRINTTYPE = {arXiv}, - JOURNAL = {AMS Transactions on Mathematical Software}, - NOTE = {accepted for publication}, - TITLE = {Manifolds.jl: An Extensible {J}ulia Framework for Data Analysis on Manifolds}, - YEAR = {2023} + AUTHOR = {Axen, Seth D. and Baran, Mateusz and Bergmann, Ronny and Rzecki, Krzysztof}, + ARTICLENO = {33}, + DOI = {10.1145/3618296}, + JOURNAL = {ACM Transactions on Mathematical Software}, + MONTH = {dec}, + NUMBER = {4}, + TITLE = {Manifolds.Jl: An Extensible Julia Framework for Data Analysis on Manifolds}, + VOLUME = {49}, + YEAR = {2023} } # # B diff --git a/src/Manifolds.jl b/src/Manifolds.jl index 3ca4e4c428..89c466e728 100644 --- a/src/Manifolds.jl +++ b/src/Manifolds.jl @@ -48,6 +48,7 @@ import ManifoldsBase: check_vector, copy, copyto!, + default_approximation_method, default_inverse_retraction_method, default_retraction_method, default_vector_transport_method, @@ -149,19 +150,14 @@ import ManifoldsBase: submanifold_component, submanifold_components, vector_space_dimension, - vector_transport_along, # just specified in Euclidean - the next 5 as well - vector_transport_along_diff, - vector_transport_along_project, + vector_transport_along, # just specified in Euclidean - the next 5 as well vector_transport_along!, - vector_transport_along_diff!, - vector_transport_along_project!, + vector_transport_along_diff!, # For consistency these are imported, but for now not + vector_transport_along_project!, # overwritten with new definitons. vector_transport_direction, - vector_transport_direction_diff, vector_transport_direction!, vector_transport_direction_diff!, vector_transport_to, - vector_transport_to_diff, - vector_transport_to_project, vector_transport_to!, vector_transport_to_diff!, vector_transport_to_project!, # some overwrite layer 2 @@ -186,6 +182,9 @@ import ManifoldDiff: riemannian_Hessian, riemannian_Hessian! +import Statistics: mean, mean!, median, median!, cov, var +import StatsBase: mean_and_var + using Base.Iterators: repeated using Distributions using Einsum: @einsum @@ -198,6 +197,7 @@ using ManifoldsBase: ℝ, ℂ, ℍ, + AbstractApproximationMethod, AbstractBasis, AbstractDecoratorManifold, AbstractInverseRetractionMethod, @@ -225,6 +225,7 @@ using ManifoldsBase: CotangentSpaceType, CoTFVector, CoTVector, + CyclicProximalPointEstimation, DefaultBasis, DefaultOrthogonalBasis, DefaultOrthonormalBasis, @@ -232,13 +233,18 @@ using ManifoldsBase: DiagonalizingBasisData, DiagonalizingOrthonormalBasis, DifferentiatedRetractionVectorTransport, + EfficientEstimator, EmbeddedManifold, EmptyTrait, EuclideanMetric, ExponentialRetraction, + ExtrinsicEstimation, Fiber, FiberType, FVector, + GeodesicInterpolation, + GeodesicInterpolationWithinRadius, + GradientDescentEstimation, InverseProductRetraction, IsIsometricEmbeddedManifold, IsEmbeddedManifold, @@ -295,6 +301,7 @@ using ManifoldsBase: VectorSpaceFiber, VectorSpaceType, VeeOrthogonalBasis, + WeiszfeldEstimation, @invoke_maker, _euclidean_basis_vector, combine_allocation_promotion_functions, @@ -607,8 +614,9 @@ include("deprecated.jl") export test_manifold export test_group, test_action -# +# Abstract main types export CoTVector, AbstractManifold, AbstractManifoldPoint, TVector +# Manifolds export AbstractSphere, AbstractProjectiveSpace export Euclidean, ArrayProjectiveSpace, @@ -748,9 +756,10 @@ export AbstractInverseRetractionMethod, ShootingInverseRetraction, SoftmaxInverseRetraction # Estimation methods for median and mean -export AbstractEstimationMethod, +export AbstractApproximationMethod, GradientDescentEstimation, CyclicProximalPointEstimation, + EfficientEstimator, GeodesicInterpolation, GeodesicInterpolationWithinRadius, ExtrinsicEstimation @@ -788,9 +797,10 @@ export ×, convert, complex_dot, decorated_manifold, - default_vector_transport_method, + default_approximation_method, default_inverse_retraction_method, default_retraction_method, + default_vector_transport_method, det_local_metric, differential_canonical_project, differential_canonical_project!, diff --git a/src/deprecated.jl b/src/deprecated.jl index 8b13789179..03601238b0 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -1 +1,6 @@ +@deprecate default_estimation_method(M::AbstractManifold, f) default_approximation_method( + M, + f, +) +@deprecate ExtrinsicEstimation() ExtrinsicEstimation(EfficientEstimator()) diff --git a/src/groups/addition_operation.jl b/src/groups/addition_operation.jl index d4a8920e7c..885974cf53 100644 --- a/src/groups/addition_operation.jl +++ b/src/groups/addition_operation.jl @@ -83,8 +83,14 @@ function inv_diff!(::AdditionGroupTrait, G::AbstractDecoratorManifold, Y, p, X) return Y end -function is_identity(::AdditionGroupTrait, G::AbstractDecoratorManifold, q; kwargs...) - return isapprox(G, q, zero(q); kwargs...) +function is_identity( + ::AdditionGroupTrait, + G::AbstractDecoratorManifold, + q::T; + atol::Real=sqrt(prod(representation_size(G))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} + return isapprox(G, q, zero(q); atol=atol, kwargs...) end # resolve ambiguities function is_identity( diff --git a/src/groups/group_action.jl b/src/groups/group_action.jl index 782ce45ecb..d31dd18786 100644 --- a/src/groups/group_action.jl +++ b/src/groups/group_action.jl @@ -116,7 +116,7 @@ where element `a` is acting on `p`, with respect to the group element. Let ``\mathcal G`` be the group acting on manifold ``\mathcal M`` by the action `A`. The action is of element ``g ∈ \mathcal G`` on a point ``p ∈ \mathcal M``. -The differential transforms vector `X` from the tangent space at `a ∈ \mathcal G`, +The differential transforms vector `X` from the tangent space at `a ∈ \mathcal G`, ``X ∈ T_a \mathcal G`` into a tangent space of the manifold ``\mathcal M``. When action on element `p` is written as ``\mathrm{d}τ^p``, with the specified left or right convention, the differential transforms vectors @@ -193,7 +193,7 @@ end A::AbstractGroupAction, pts, p, - mean_method::AbstractEstimationMethod = GradientDescentEstimation(), + mean_method::AbstractApproximationMethod = GradientDescentEstimation(), ) Calculate an action element ``a`` of action `A` that is the mean element of the orbit of `p` @@ -210,7 +210,7 @@ function center_of_orbit( A::AbstractGroupAction, pts::AbstractVector, q, - mean_method::AbstractEstimationMethod=GradientDescentEstimation(), + mean_method::AbstractApproximationMethod=GradientDescentEstimation(), ) alignments = map(p -> optimal_alignment(A, q, p), pts) return mean(base_group(A), alignments, mean_method) diff --git a/src/groups/group_operation_action.jl b/src/groups/group_operation_action.jl index 2a7d06557d..1b9098ed15 100644 --- a/src/groups/group_operation_action.jl +++ b/src/groups/group_operation_action.jl @@ -199,7 +199,7 @@ function center_of_orbit( A::GroupOperationAction, pts::AbstractVector, p, - mean_method::AbstractEstimationMethod, + mean_method::AbstractApproximationMethod, ) μ = mean(A.group, pts, mean_method) return inverse_apply(switch_direction_and_side(A), p, μ) diff --git a/src/groups/special_linear.jl b/src/groups/special_linear.jl index 7ffb5fde1e..e8f8eb05a9 100644 --- a/src/groups/special_linear.jl +++ b/src/groups/special_linear.jl @@ -49,9 +49,15 @@ function check_point(G::SpecialLinear, p; kwargs...) return nothing end -function check_vector(G::SpecialLinear, p, X; kwargs...) +function check_vector( + G::SpecialLinear, + p, + X::T; + atol::Real=sqrt(prod(representation_size(G))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} trX = tr(inverse_translate_diff(G, p, p, X, LeftForwardAction())) - if !isapprox(trX, 0; kwargs...) + if !isapprox(trX, 0; atol=atol, kwargs...) return DomainError( trX, "The matrix $(X) does not lie in the tangent space of $(G) at $(p), since " * diff --git a/src/manifolds/CenteredMatrices.jl b/src/manifolds/CenteredMatrices.jl index 0856c49a52..e5f801797a 100644 --- a/src/manifolds/CenteredMatrices.jl +++ b/src/manifolds/CenteredMatrices.jl @@ -35,9 +35,14 @@ zero. The tolerance for the column sums of `p` can be set using `kwargs...`. """ -function check_point(M::CenteredMatrices, p; kwargs...) +function check_point( + M::CenteredMatrices, + p::T; + atol::Real=sqrt(prod(representation_size(M))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} m, n = get_parameter(M.size) - if !isapprox(sum(p, dims=1), zeros(1, n); kwargs...) + if !isapprox(sum(p, dims=1), zeros(1, n); atol=atol, kwargs...) return DomainError( p, string( @@ -56,9 +61,15 @@ Check whether `X` is a tangent vector to manifold point `p` on the sum to zero and its values are from the correct [`AbstractNumbers`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#number-system). The tolerance for the column sums of `p` and `X` can be set using `kwargs...`. """ -function check_vector(M::CenteredMatrices, p, X; kwargs...) +function check_vector( + M::CenteredMatrices, + p, + X::T; + atol::Real=sqrt(prod(representation_size(M))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} m, n = get_parameter(M.size) - if !isapprox(sum(X, dims=1), zeros(1, n); kwargs...) + if !isapprox(sum(X, dims=1), zeros(1, n); atol=atol, kwargs...) return DomainError( X, "The vector $(X) is not a tangent vector to $(p) on $(M), since its columns do not sum to zero.", diff --git a/src/manifolds/CholeskySpace.jl b/src/manifolds/CholeskySpace.jl index 5127235a07..2d2226c4b4 100644 --- a/src/manifolds/CholeskySpace.jl +++ b/src/manifolds/CholeskySpace.jl @@ -28,10 +28,15 @@ it's size fits the manifold, it is a lower triangular matrix and has positive entries on the diagonal. The tolerance for the tests can be set using the `kwargs...`. """ -function check_point(M::CholeskySpace, p; kwargs...) +function check_point( + M::CholeskySpace, + p::T; + atol::Real=sqrt(prod(representation_size(M))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} cks = check_size(M, p) cks === nothing || return cks - if !isapprox(norm(strictlyUpperTriangular(p)), 0.0; kwargs...) + if !isapprox(norm(strictlyUpperTriangular(p)), 0.0; atol=atol, kwargs...) return DomainError( norm(UpperTriangular(p) - Diagonal(p)), "The point $(p) does not lie on $(M), since it strictly upper triangular nonzero entries", @@ -54,8 +59,14 @@ after [`check_point`](@ref)`(M,p)`, `X` has to have the same dimension as `p` and a symmetric matrix. The tolerance for the tests can be set using the `kwargs...`. """ -function check_vector(M::CholeskySpace, p, X; kwargs...) - if !isapprox(norm(strictlyUpperTriangular(X)), 0.0; kwargs...) +function check_vector( + M::CholeskySpace, + p, + X; + atol::Real=sqrt(prod(representation_size(M)) * eps(float(eltype(p)))), + kwargs..., +) + if !isapprox(norm(strictlyUpperTriangular(X)), 0.0; atol=atol, kwargs...) return DomainError( norm(UpperTriangular(X) - Diagonal(X)), "The matrix $(X) is not a tangent vector at $(p) (represented as an element of the Lie algebra) since it is not lower triangular.", diff --git a/src/manifolds/Circle.jl b/src/manifolds/Circle.jl index bacf0320a1..c6b2d7a2f3 100644 --- a/src/manifolds/Circle.jl +++ b/src/manifolds/Circle.jl @@ -78,8 +78,14 @@ check_vector(::Circle{ℝ}, ::Any...; ::Any...) function check_vector(M::Circle{ℝ}, p, X; kwargs...) return nothing end -function check_vector(M::Circle{ℂ}, p, X; kwargs...) - if !isapprox(abs(complex_dot(p, X)), 0.0; kwargs...) +function check_vector( + M::Circle{ℂ}, + p, + X::T; + atol::Real=sqrt(eps(real(float(number_eltype(T))))), + kwargs..., +) where {T} + if !isapprox(abs(complex_dot(p, X)), 0; atol=atol, kwargs...) return DomainError( abs(complex_dot(p, X)), "The value $(X) is not a tangent vector to $(p) on $(M), since it is not orthogonal in the embedding.", @@ -160,7 +166,7 @@ function Base.exp(M::Circle{ℂ}, p::Number, X::Number, t::Number) end exp!(::Circle{ℝ}, q, p, X) = (q .= sym_rem(p + X)) -exp!(::Circle{ℝ}, q, p, X, t::Number) = (q .= sym_rem(p + t * X)) +exp!(::Circle{ℝ}, q, p, X, t::Number) = (q .= sym_rem(p[] + t * X[])) function exp!(M::Circle{ℂ}, q, p, X) θ = norm(M, p, X) q .= cos(θ) * p + usinc(θ) * X @@ -538,6 +544,11 @@ function parallel_transport_to!(M::Circle{ℂ}, Y, p, X, q) return Y end +# dispatch before allocation +function _vector_transport_direction(M::Circle, p, X, d, ::ParallelTransport) + return parallel_transport_to(M, p, X, exp(M, p, d)) +end + """ volume_density(::Circle, p, X) diff --git a/src/manifolds/Elliptope.jl b/src/manifolds/Elliptope.jl index 18b428c3af..e7d7d7432d 100644 --- a/src/manifolds/Elliptope.jl +++ b/src/manifolds/Elliptope.jl @@ -85,10 +85,16 @@ zero diagonal. The tolerance for the base point check and zero diagonal can be set using the `kwargs...`. Note that symmetric of $X$ holds by construction an is not explicitly checked. """ -function check_vector(M::Elliptope, q, Y; kwargs...) +function check_vector( + M::Elliptope, + q, + Y::T; + atol::Real=sqrt(prod(representation_size(M))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} X = q * Y' + Y * q' n = diag(X) - if !all(isapprox.(n, 0.0; kwargs...)) + if !all(isapprox.(n, 0.0; atol=atol, kwargs...)) return DomainError( n, "The vector $(X) is not a tangent to a point on $(M) (represented py $(q) and $(Y), since its diagonal is nonzero.", diff --git a/src/manifolds/EmbeddedTorus.jl b/src/manifolds/EmbeddedTorus.jl index ae06b8cd86..9adb27d2d2 100644 --- a/src/manifolds/EmbeddedTorus.jl +++ b/src/manifolds/EmbeddedTorus.jl @@ -52,9 +52,9 @@ Check whether `X` is a valid vector tangent to `p` on the [`EmbeddedTorus`](@ref The method checks if the vector `X` is orthogonal to the vector normal to the torus, see [`normal_vector`](@ref). Absolute tolerance can be set using `atol`. """ -function check_vector(M::EmbeddedTorus, p, X; atol=eps(eltype(p)), kwargs...) +function check_vector(M::EmbeddedTorus, p, X; atol::Real=eps(float(eltype(p))), kwargs...) dot_nX = dot(normal_vector(M, p), X) - if !isapprox(dot_nX, 0; atol, kwargs...) + if !isapprox(dot_nX, 0; atol=atol, kwargs...) return DomainError(dot_nX, "The vector $(X) is not tangent to $(p) from $(M).") end return nothing diff --git a/src/manifolds/EssentialManifold.jl b/src/manifolds/EssentialManifold.jl index 8d46f2ffca..6d585d4974 100644 --- a/src/manifolds/EssentialManifold.jl +++ b/src/manifolds/EssentialManifold.jl @@ -125,8 +125,14 @@ exp(::EssentialManifold, ::Any...) get_iterator(::EssentialManifold) = Base.OneTo(2) -function _isapprox(M::EssentialManifold, p, q; kwargs...) - return isapprox(distance(M, p, q), 0.0; kwargs...) +function _isapprox( + M::EssentialManifold, + p, + q::T; + atol::Real=eps(real(float(number_eltype(number_eltype(T))))), + kwargs..., +) where {T} + return isapprox(distance(M, p, q), 0.0; atol=atol, kwargs...) end """ diff --git a/src/manifolds/Euclidean.jl b/src/manifolds/Euclidean.jl index c5173c21c0..f2160ef492 100644 --- a/src/manifolds/Euclidean.jl +++ b/src/manifolds/Euclidean.jl @@ -114,6 +114,11 @@ function check_vector(M::Euclidean{N,𝔽}, p, X; kwargs...) where {N,𝔽} return nothing end +default_approximation_method(::Euclidean, ::typeof(mean)) = EfficientEstimator() +function default_approximation_method(::Euclidean, ::typeof(median), ::Type{<:Number}) + return EfficientEstimator() +end + function det_local_metric( ::MetricManifold{𝔽,<:AbstractManifold,EuclideanMetric}, p, @@ -504,23 +509,65 @@ Return volume of the [`Euclidean`](@ref) manifold, i.e. infinity. """ manifold_volume(::Euclidean) = Inf -Statistics.mean(::Euclidean{Tuple{}}, x::AbstractVector{<:Number}; kwargs...) = mean(x) function Statistics.mean( ::Union{Euclidean{TypeParameter{Tuple{}}},Euclidean{Tuple{}}}, - x::AbstractVector{<:Number}; + x::AbstractVector, + ::EfficientEstimator; kwargs..., ) return mean(x) end function Statistics.mean( ::Union{Euclidean{TypeParameter{Tuple{}}},Euclidean{Tuple{}}}, - x::AbstractVector{<:Number}, - w::AbstractWeights; + x::AbstractVector, + w::AbstractWeights, + ::EfficientEstimator; kwargs..., ) return mean(x, w) end -Statistics.mean(::Euclidean, x::AbstractVector; kwargs...) = mean(x) +# +# When Statistics / Statsbase.mean! is consistent with mean, we can pass this on to them as well +function Statistics.mean!( + ::Euclidean, + y, + x::AbstractVector, + ::EfficientEstimator; + kwargs..., +) + n = length(x) + copyto!(y, first(x)) + @inbounds for j in 2:n + y .+= x[j] + end + y ./= n + return y +end +function Statistics.mean!( + ::Euclidean, + y, + x::AbstractVector, + w::AbstractWeights, + ::EfficientEstimator; + kwargs..., +) + n = length(x) + if length(w) != n + throw( + DimensionMismatch( + "The number of weights ($(length(w))) does not match the number of points for the mean ($(n)).", + ), + ) + end + copyto!(y, first(x)) + y .*= first(w) + @inbounds for j in 2:n + iszero(w[j]) && continue + y .+= w[j] .* x[j] + end + y ./= sum(w) + return y +end function StatsBase.mean_and_var( ::Union{Euclidean{TypeParameter{Tuple{}}},Euclidean{Tuple{}}}, @@ -543,7 +590,8 @@ end function Statistics.median( ::Union{Euclidean{TypeParameter{Tuple{}}},Euclidean{Tuple{}}}, - x::AbstractVector{<:Number}; + x::AbstractVector{<:Number}, + ::EfficientEstimator; kwargs..., ) return median(x) @@ -551,7 +599,8 @@ end function Statistics.median( ::Union{Euclidean{TypeParameter{Tuple{}}},Euclidean{Tuple{}}}, x::AbstractVector{<:Number}, - w::AbstractWeights; + w::AbstractWeights, + ::EfficientEstimator; kwargs..., ) return median(x, w) diff --git a/src/manifolds/FixedRankMatrices.jl b/src/manifolds/FixedRankMatrices.jl index 97bc66fd18..ddc145ea82 100644 --- a/src/manifolds/FixedRankMatrices.jl +++ b/src/manifolds/FixedRankMatrices.jl @@ -324,15 +324,21 @@ Check whether the tangent [`UMVTVector`](@ref) `X` is from the tangent space of [`FixedRankMatrices`](@ref) `M`, i.e. that `v.U` and `v.Vt` are (columnwise) orthogonal to `x.U` and `x.Vt`, respectively, and its dimensions are consistent with `p` and `X.M`, i.e. correspond to `m`-by-`n` matrices of rank `k`. """ -function check_vector(M::FixedRankMatrices, p::SVDMPoint, X::UMVTVector; kwargs...) +function check_vector( + M::FixedRankMatrices, + p::SVDMPoint, + X::UMVTVector; + atol::Real=sqrt(prod(representation_size(M)) * eps(float(eltype(p.U)))), + kwargs..., +) m, n, k = get_parameter(M.size) - if !isapprox(X.U' * p.U, zeros(k, k); kwargs...) + if !isapprox(X.U' * p.U, zeros(k, k); atol=atol, kwargs...) return DomainError( norm(X.U' * p.U - zeros(k, k)), "The tangent vector $(X) is not a tangent vector to $(p) on $(M) since v.U'x.U is not zero. ", ) end - if !isapprox(X.Vt * p.Vt', zeros(k, k); kwargs...) + if !isapprox(X.Vt * p.Vt', zeros(k, k); atol=atol, kwargs...) return DomainError( norm(X.Vt * p.Vt - zeros(k, k)), "The tangent vector $(X) is not a tangent vector to $(p) on $(M) since v.V'x.V is not zero.", diff --git a/src/manifolds/GeneralUnitaryMatrices.jl b/src/manifolds/GeneralUnitaryMatrices.jl index 5f81197b8d..f6dc9e1df2 100644 --- a/src/manifolds/GeneralUnitaryMatrices.jl +++ b/src/manifolds/GeneralUnitaryMatrices.jl @@ -178,7 +178,7 @@ function cos_angles_4d_rotation_matrix(R) return ((a + b) / 4, (a - b) / 4) end -function default_estimation_method(::GeneralUnitaryMatrices{<:Any,ℝ}, ::typeof(mean)) +function default_approximation_method(::GeneralUnitaryMatrices{<:Any,ℝ}, ::typeof(mean)) return GeodesicInterpolationWithinRadius(π / 2 / √2) end diff --git a/src/manifolds/GeneralizedGrassmann.jl b/src/manifolds/GeneralizedGrassmann.jl index 2fe79dd922..74dd12a183 100644 --- a/src/manifolds/GeneralizedGrassmann.jl +++ b/src/manifolds/GeneralizedGrassmann.jl @@ -280,7 +280,7 @@ Compute the Riemannian [`mean`](@ref mean(M::AbstractManifold, args...)) of `x` """ mean(::GeneralizedGrassmann, ::Any...) -function default_estimation_method(::GeneralizedGrassmann, ::typeof(mean)) +function default_approximation_method(::GeneralizedGrassmann, ::typeof(mean)) return GeodesicInterpolationWithinRadius(π / 4) end diff --git a/src/manifolds/Grassmann.jl b/src/manifolds/Grassmann.jl index c8f23a7612..b3edc40ca8 100644 --- a/src/manifolds/Grassmann.jl +++ b/src/manifolds/Grassmann.jl @@ -171,7 +171,7 @@ Compute the Riemannian [`mean`](@ref mean(M::AbstractManifold, args...)) of `x` """ mean(::Grassmann, ::Any...) -function default_estimation_method(::Grassmann, ::typeof(mean)) +function default_approximation_method(::Grassmann, ::typeof(mean)) return GeodesicInterpolationWithinRadius(π / 4) end diff --git a/src/manifolds/Hyperbolic.jl b/src/manifolds/Hyperbolic.jl index 5fd2e15a10..e6090fea59 100644 --- a/src/manifolds/Hyperbolic.jl +++ b/src/manifolds/Hyperbolic.jl @@ -332,7 +332,7 @@ Compute the Riemannian [`mean`](@ref mean(M::AbstractManifold, args...)) of `x` """ mean(::Hyperbolic, ::Any...) -default_estimation_method(::Hyperbolic, ::typeof(mean)) = CyclicProximalPointEstimation() +default_approximation_method(::Hyperbolic, ::typeof(mean)) = CyclicProximalPointEstimation() @doc raw""" project(M::Hyperbolic, p, X) diff --git a/src/manifolds/HyperbolicHyperboloid.jl b/src/manifolds/HyperbolicHyperboloid.jl index f0b15489b8..b5df6543bb 100644 --- a/src/manifolds/HyperbolicHyperboloid.jl +++ b/src/manifolds/HyperbolicHyperboloid.jl @@ -35,8 +35,14 @@ function check_point(M::Hyperbolic, p; kwargs...) return nothing end -function check_vector(M::Hyperbolic, p, X; kwargs...) - if !isapprox(minkowski_metric(p, X), 0.0; kwargs...) +function check_vector( + M::Hyperbolic, + p, + X::T; + atol::Real=sqrt(prod(representation_size(M))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} + if !isapprox(minkowski_metric(p, X), 0; atol=atol, kwargs...) return DomainError( abs(minkowski_metric(p, X)), "The vector $(X) is not a tangent vector to $(p) on $(M), since it is not orthogonal (with respect to the Minkowski inner product) in the embedding.", diff --git a/src/manifolds/KendallsPreShapeSpace.jl b/src/manifolds/KendallsPreShapeSpace.jl index 3a34719ad9..3b9ef08dc6 100644 --- a/src/manifolds/KendallsPreShapeSpace.jl +++ b/src/manifolds/KendallsPreShapeSpace.jl @@ -37,9 +37,14 @@ representation_size(M::KendallsPreShapeSpace) = get_parameter(M.size) Check whether `p` is a valid point on [`KendallsPreShapeSpace`](@ref), i.e. whether each row has zero mean. Other conditions are checked via embedding in [`ArraySphere`](@ref). """ -function check_point(M::KendallsPreShapeSpace, p; atol=sqrt(eps(eltype(p))), kwargs...) +function check_point( + M::KendallsPreShapeSpace, + p; + atol::Real=sqrt(eps(float(eltype(p)))), + kwargs..., +) for p_row in eachrow(p) - if !isapprox(mean(p_row), 0; atol, kwargs...) + if !isapprox(mean(p_row), 0; atol=atol, kwargs...) return DomainError( mean(p_row), "The point $(p) does not lie on the $(M) since one of the rows does not have zero mean.", @@ -55,9 +60,15 @@ end Check whether `X` is a valid tangent vector on [`KendallsPreShapeSpace`](@ref), i.e. whether each row has zero mean. Other conditions are checked via embedding in [`ArraySphere`](@ref). """ -function check_vector(M::KendallsPreShapeSpace, p, X; atol=sqrt(eps(eltype(X))), kwargs...) +function check_vector( + M::KendallsPreShapeSpace, + p, + X; + atol::Real=sqrt(eps(float(eltype(X)))), + kwargs..., +) for X_row in eachrow(X) - if !isapprox(mean(X_row), 0; atol, kwargs...) + if !isapprox(mean(X_row), 0; atol=atol, kwargs...) return DomainError( mean(X_row), "The vector $(X) is not a tangent vector to $(p) on $(M), since one of the rows does not have zero mean.", diff --git a/src/manifolds/MultinomialDoublyStochastic.jl b/src/manifolds/MultinomialDoublyStochastic.jl index d8dfa1dcac..40a4ef3e03 100644 --- a/src/manifolds/MultinomialDoublyStochastic.jl +++ b/src/manifolds/MultinomialDoublyStochastic.jl @@ -60,10 +60,15 @@ end Checks whether `p` is a valid point on the [`MultinomialDoubleStochastic`](@ref)`(n)` `M`, i.e. is a matrix with positive entries whose rows and columns sum to one. """ -function check_point(M::MultinomialDoubleStochastic, p; kwargs...) +function check_point( + M::MultinomialDoubleStochastic, + p::T; + atol::Real=sqrt(prod(representation_size(M))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} n = get_parameter(M.size)[1] r = sum(p, dims=2) - if !isapprox(norm(r - ones(n, 1)), 0.0; kwargs...) + if !isapprox(norm(r - ones(n, 1)), 0.0; atol=atol, kwargs...) return DomainError( r, "The point $(p) does not lie on $M, since its rows do not sum up to one.", @@ -78,9 +83,15 @@ Checks whether `X` is a valid tangent vector to `p` on the [`MultinomialDoubleSt This means, that `p` is valid, that `X` is of correct dimension and sums to zero along any column or row. """ -function check_vector(M::MultinomialDoubleStochastic, p, X; kwargs...) +function check_vector( + M::MultinomialDoubleStochastic, + p, + X::T; + atol::Real=sqrt(prod(representation_size(M))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} r = sum(X, dims=2) # check for stochastic rows - if !isapprox(norm(r), 0.0; kwargs...) + if !isapprox(norm(r), 0.0; atol=atol, kwargs...) return DomainError( r, "The matrix $(X) is not a tangent vector to $(p) on $(M), since its rows do not sum up to zero.", diff --git a/src/manifolds/ProbabilitySimplex.jl b/src/manifolds/ProbabilitySimplex.jl index 8a00af96c2..96d76fa7b9 100644 --- a/src/manifolds/ProbabilitySimplex.jl +++ b/src/manifolds/ProbabilitySimplex.jl @@ -129,8 +129,14 @@ after [`check_point`](@ref check_point(::ProbabilitySimplex, ::Any))`(M,p)`, `X` has to be of same dimension as `p` and its elements have to sum to one. The tolerance for the last test can be set using the `kwargs...`. """ -function check_vector(M::ProbabilitySimplex, p, X; kwargs...) - if !isapprox(sum(X), 0.0; kwargs...) +function check_vector( + M::ProbabilitySimplex, + p, + X::T; + atol::Real=sqrt(prod(representation_size(M))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} + if !isapprox(sum(X), 0.0; atol=atol, kwargs...) return DomainError( sum(X), "The vector $(X) is not a tangent vector to $(p) on $(M), since its elements do not sum up to 0.", @@ -340,7 +346,7 @@ Compute the Riemannian [`mean`](@ref mean(M::AbstractManifold, args...)) of `x` """ mean(::ProbabilitySimplex, ::Any...) -default_estimation_method(::ProbabilitySimplex, ::typeof(mean)) = GeodesicInterpolation() +default_approximation_method(::ProbabilitySimplex, ::typeof(mean)) = GeodesicInterpolation() function parallel_transport_to!(M::ProbabilitySimplex, Y, p, X, q) n = get_parameter(M.size)[1] diff --git a/src/manifolds/ProjectiveSpace.jl b/src/manifolds/ProjectiveSpace.jl index 392215431b..eeec8e880d 100644 --- a/src/manifolds/ProjectiveSpace.jl +++ b/src/manifolds/ProjectiveSpace.jl @@ -126,8 +126,14 @@ Check whether `X` is a tangent vector in the tangent space of `p` on the tangent space of the embedding and that the Frobenius inner product $⟨p, X⟩_{\mathrm{F}} = 0$. """ -function check_vector(M::AbstractProjectiveSpace, p, X; kwargs...) - if !isapprox(dot(p, X), 0; kwargs...) +function check_vector( + M::AbstractProjectiveSpace, + p, + X::T; + atol::Real=sqrt(prod(representation_size(M))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} + if !isapprox(dot(p, X), 0; atol=atol, kwargs...) return DomainError( dot(p, X), "The vector $(X) is not a tangent vector to $(p) on $(M), since it is not" * @@ -391,7 +397,7 @@ using [`GeodesicInterpolationWithinRadius`](@ref). """ mean(::AbstractProjectiveSpace, ::Any...) -function default_estimation_method(::AbstractProjectiveSpace, ::typeof(mean)) +function default_approximation_method(::AbstractProjectiveSpace, ::typeof(mean)) return GeodesicInterpolationWithinRadius(π / 4) end diff --git a/src/manifolds/Rotations.jl b/src/manifolds/Rotations.jl index b262d0114b..1a1aef5776 100644 --- a/src/manifolds/Rotations.jl +++ b/src/manifolds/Rotations.jl @@ -366,6 +366,12 @@ where ``q=\exp_p d``. The formula simplifies to identity for 2-D rotations. """ parallel_transport_direction(M::Rotations, p, X, d) +function parallel_transport_direction(M::Rotations, p, X, d) + expdhalf = exp(d / 2) + q = exp(M, p, d) + return transpose(q) * p * expdhalf * X * expdhalf +end +parallel_transport_direction(::Rotations{TypeParameter{Tuple{2}}}, p, X, d) = X function parallel_transport_direction!(M::Rotations, Y, p, X, d) expdhalf = exp(d / 2) @@ -375,12 +381,6 @@ end function parallel_transport_direction!(::Rotations{TypeParameter{Tuple{2}}}, Y, p, X, d) return copyto!(Y, X) end -function parallel_transport_direction(M::Rotations, p, X, d) - expdhalf = exp(d / 2) - q = exp(M, p, d) - return transpose(q) * p * expdhalf * X * expdhalf -end -parallel_transport_direction(::Rotations{TypeParameter{Tuple{2}}}, p, X, d) = X function parallel_transport_to!(M::Rotations, Y, p, X, q) d = log(M, p, q) diff --git a/src/manifolds/SPDFixedDeterminant.jl b/src/manifolds/SPDFixedDeterminant.jl index 5efae753c0..d72d3f1f8d 100644 --- a/src/manifolds/SPDFixedDeterminant.jl +++ b/src/manifolds/SPDFixedDeterminant.jl @@ -82,8 +82,14 @@ and additionally fulfill ``\operatorname{tr}(X) = 0``. The tolerance for the trace check of `X` can be set using `kwargs...`, which influences the `isapprox`-check. """ -function check_vector(M::SPDFixedDeterminant, p, X; kwargs...) - if !isapprox(tr(X), 0.0; kwargs...) +function check_vector( + M::SPDFixedDeterminant, + p, + X::T; + atol::Real=sqrt(prod(representation_size(M))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} + if !isapprox(tr(X), 0; atol=atol, kwargs...) return DomainError( tr(X), "The vector $(X) is not a tangent vector to $(p) on $(M), since it does not have a zero trace.", diff --git a/src/manifolds/Spectrahedron.jl b/src/manifolds/Spectrahedron.jl index ed7fb88479..082b6f8fa0 100644 --- a/src/manifolds/Spectrahedron.jl +++ b/src/manifolds/Spectrahedron.jl @@ -85,10 +85,16 @@ and a $X$ has to be a symmetric matrix with trace. The tolerance for the base point check and zero diagonal can be set using the `kwargs...`. Note that symmetry of $X$ holds by construction and is not explicitly checked. """ -function check_vector(M::Spectrahedron, q, Y; kwargs...) +function check_vector( + M::Spectrahedron, + q, + Y::T; + atol::Real=sqrt(prod(representation_size(M))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} X = q * Y' + Y * q' n = tr(X) - if !isapprox(n, 0.0; kwargs...) + if !isapprox(n, 0; atol=atol, kwargs...) return DomainError( n, "The vector $(X) is not a tangent to a point on $(M) (represented py $(q) and $(Y), since its trace is nonzero.", diff --git a/src/manifolds/Sphere.jl b/src/manifolds/Sphere.jl index e2fb758159..73f2cd37c2 100644 --- a/src/manifolds/Sphere.jl +++ b/src/manifolds/Sphere.jl @@ -130,8 +130,14 @@ after [`check_point`](@ref)`(M,p)`, `X` has to be of same dimension as `p` and orthogonal to `p`. The tolerance for the last test can be set using the `kwargs...`. """ -function check_vector(M::AbstractSphere, p, X; kwargs...) - if !isapprox(abs(real(dot(p, X))), 0.0; kwargs...) +function check_vector( + M::AbstractSphere, + p, + X::T; + atol::Real=sqrt(prod(representation_size(M))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} + if !isapprox(abs(real(dot(p, X))), 0; atol=atol, kwargs...) return DomainError( abs(dot(p, X)), "The vector $(X) is not a tangent vector to $(p) on $(M), since it is not orthogonal in the embedding.", @@ -413,7 +419,7 @@ Compute the Riemannian [`mean`](@ref mean(M::AbstractManifold, args...)) of `x` """ mean(::AbstractSphere, ::Any...) -function default_estimation_method(::AbstractSphere, ::typeof(mean)) +function default_approximation_method(::AbstractSphere, ::typeof(mean)) return GeodesicInterpolationWithinRadius(π / 2) end diff --git a/src/manifolds/SphereSymmetricMatrices.jl b/src/manifolds/SphereSymmetricMatrices.jl index cdb97c05e5..1c9eecf0f0 100644 --- a/src/manifolds/SphereSymmetricMatrices.jl +++ b/src/manifolds/SphereSymmetricMatrices.jl @@ -35,8 +35,13 @@ i.e. is an `n`-by-`n` symmetric matrix of unit Frobenius norm. The tolerance for the symmetry of `p` can be set using `kwargs...`. """ -function check_point(M::SphereSymmetricMatrices, p; kwargs...) - if !isapprox(norm(p - p'), 0.0; kwargs...) +function check_point( + M::SphereSymmetricMatrices, + p::T; + atol::Real=sqrt(prod(representation_size(M))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} + if !isapprox(norm(p - p'), 0; atol=atol, kwargs...) return DomainError( norm(p - p'), "The point $(p) does not lie on $M, since it is not symmetric.", @@ -54,8 +59,14 @@ of unit Frobenius norm. The tolerance for the symmetry of `p` and `X` can be set using `kwargs...`. """ -function check_vector(M::SphereSymmetricMatrices, p, X; kwargs...) - if !isapprox(norm(X - X'), 0.0; kwargs...) +function check_vector( + M::SphereSymmetricMatrices, + p, + X::T; + atol::Real=sqrt(prod(representation_size(M))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} + if !isapprox(norm(X - X'), 0; atol=atol, kwargs...) return DomainError( norm(X - X'), "The vector $(X) is not a tangent vector to $(p) on $(M), since it is not symmetric.", diff --git a/src/manifolds/Symmetric.jl b/src/manifolds/Symmetric.jl index 3d5a7832bd..989d3dba29 100644 --- a/src/manifolds/Symmetric.jl +++ b/src/manifolds/Symmetric.jl @@ -51,8 +51,8 @@ whether `p` is a symmetric matrix of size `(n,n)` with values from the correspon The tolerance for the symmetry of `p` can be set using `kwargs...`. """ -function check_point(M::SymmetricMatrices{<:Any,𝔽}, p; kwargs...) where {𝔽} - if !isapprox(norm(p - p'), 0.0; kwargs...) +function check_point(M::SymmetricMatrices, p; kwargs...) + if !isapprox(p, p'; kwargs...) return DomainError( norm(p - p'), "The point $(p) does not lie on $M, since it is not symmetric.", @@ -70,8 +70,8 @@ and its values have to be from the correct [`AbstractNumbers`](https://juliamani The tolerance for the symmetry of `X` can be set using `kwargs...`. """ -function check_vector(M::SymmetricMatrices{<:Any,𝔽}, p, X; kwargs...) where {𝔽} - if !isapprox(norm(X - X'), 0.0; kwargs...) +function check_vector(M::SymmetricMatrices, p, X; kwargs...) + if !isapprox(X, X'; kwargs...) return DomainError( norm(X - X'), "The vector $(X) is not a tangent vector to $(p) on $(M), since it is not symmetric.", diff --git a/src/manifolds/SymmetricPositiveDefinite.jl b/src/manifolds/SymmetricPositiveDefinite.jl index 48d367cebd..592f4413eb 100644 --- a/src/manifolds/SymmetricPositiveDefinite.jl +++ b/src/manifolds/SymmetricPositiveDefinite.jl @@ -165,7 +165,7 @@ Lie group. The tolerance for the last test can be set using the `kwargs...`. """ function check_vector(M::SymmetricPositiveDefinite, p, X; kwargs...) - if !isapprox(norm(X - transpose(X)), 0.0; kwargs...) + if !isapprox(X, transpose(X); kwargs...) return DomainError( X, "The vector $(X) is not a tangent to a point on $(M) (represented as an element of the Lie algebra) since its not symmetric.", @@ -289,7 +289,7 @@ Compute the Riemannian [`mean`](@ref mean(M::AbstractManifold, args...)) of `x` """ mean(::SymmetricPositiveDefinite, ::Any) -function default_estimation_method(::SymmetricPositiveDefinite, ::typeof(mean)) +function default_approximation_method(::SymmetricPositiveDefinite, ::typeof(mean)) return GeodesicInterpolation() end diff --git a/src/manifolds/SymmetricPositiveSemidefiniteFixedRank.jl b/src/manifolds/SymmetricPositiveSemidefiniteFixedRank.jl index 0f096829b7..65c774f07c 100644 --- a/src/manifolds/SymmetricPositiveSemidefiniteFixedRank.jl +++ b/src/manifolds/SymmetricPositiveSemidefiniteFixedRank.jl @@ -148,9 +148,15 @@ their distance, if they are not the same, i.e. that $d_{\mathcal M}(p,q) \approx the comparison is performed with the classical `isapprox`. The `kwargs...` are passed on to this accordingly. """ -function _isapprox(M::SymmetricPositiveSemidefiniteFixedRank, p, q; kwargs...) - return isapprox(norm(p - q), 0.0; kwargs...) || - isapprox(distance(M, p, q), 0.0; kwargs...) +function _isapprox( + M::SymmetricPositiveSemidefiniteFixedRank, + p::T, + q; + atol::Real=sqrt(prod(representation_size(M))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} + return isapprox(norm(p - q), 0; atol=atol, kwargs...) || + isapprox(distance(M, p, q), 0; atol=atol, kwargs...) end """ diff --git a/src/manifolds/Symplectic.jl b/src/manifolds/Symplectic.jl index 23cfeca206..67b3492b3f 100644 --- a/src/manifolds/Symplectic.jl +++ b/src/manifolds/Symplectic.jl @@ -212,10 +212,15 @@ Q_{2n} = ```` The tolerance can be set with `kwargs...` (e.g. `atol = 1.0e-14`). """ -function check_point(M::Symplectic, p; kwargs...) +function check_point( + M::Symplectic, + p::T; + atol::Real=sqrt(prod(representation_size(M))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} # Perform check that the matrix lives on the real symplectic manifold: expected_zero = norm(inv(M, p) * p - LinearAlgebra.I) - if !isapprox(expected_zero, zero(eltype(p)); kwargs...) + if !isapprox(expected_zero, 0; atol=atol, kwargs...) return DomainError( expected_zero, ( @@ -245,10 +250,16 @@ The tolerance can be set with `kwargs...` (e.g. `atol = 1.0e-14`). """ check_vector(::Symplectic, ::Any...) -function check_vector(M::Symplectic, p, X; kwargs...) +function check_vector( + M::Symplectic, + p, + X::T; + atol::Real=sqrt(prod(representation_size(M))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} Q = SymplecticMatrix(p, X) tangent_requirement_norm = norm(X' * Q * p + p' * Q * X, 2) - if !isapprox(tangent_requirement_norm, 0.0; kwargs...) + if !isapprox(tangent_requirement_norm, 0; atol=atol, kwargs...) return DomainError( tangent_requirement_norm, ( diff --git a/src/manifolds/SymplecticStiefel.jl b/src/manifolds/SymplecticStiefel.jl index 3a4c1c6ebb..7188e0dafe 100644 --- a/src/manifolds/SymplecticStiefel.jl +++ b/src/manifolds/SymplecticStiefel.jl @@ -103,10 +103,15 @@ Q_{2n} = ```` The tolerance can be set with `kwargs...` (e.g. `atol = 1.0e-14`). """ -function check_point(M::SymplecticStiefel, p; kwargs...) +function check_point( + M::SymplecticStiefel, + p::T; + atol::Real=sqrt(prod(representation_size(M))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {T} # Perform check that the matrix lives on the real symplectic manifold: expected_zero = norm(inv(M, p) * p - I) - if !isapprox(expected_zero, 0; kwargs...) + if !isapprox(expected_zero, 0; atol=atol, kwargs...) return DomainError( expected_zero, ( @@ -142,14 +147,20 @@ The tolerance can be set with `kwargs...` (e.g. `atol = 1.0e-14`). """ check_vector(::SymplecticStiefel, ::Any...) -function check_vector(M::SymplecticStiefel{<:Any,field}, p, X; kwargs...) where {field} +function check_vector( + M::SymplecticStiefel{S,𝔽}, + p, + X::T; + atol::Real=sqrt(prod(representation_size(M))) * eps(real(float(number_eltype(T)))), + kwargs..., +) where {S,T,𝔽} n, k = get_parameter(M.size) # From Bendokat-Zimmermann: T_pSpSt(2n, 2k) = \{p*H | H^{+} = -H \} H = inv(M, p) * X # ∈ ℝ^{2k × 2k}, should be Hamiltonian. - H_star = inv(Symplectic(2k, field), H) + H_star = inv(Symplectic(2k, 𝔽), H) hamiltonian_identity_norm = norm(H + H_star) - if !isapprox(hamiltonian_identity_norm, 0; kwargs...) + if !isapprox(hamiltonian_identity_norm, 0; atol=atol, kwargs...) return DomainError( hamiltonian_identity_norm, ( diff --git a/src/statistics.jl b/src/statistics.jl index 7a1030e60f..6b6dc83e28 100644 --- a/src/statistics.jl +++ b/src/statistics.jl @@ -1,142 +1,16 @@ """ AbstractEstimationMethod -Abstract type for defining statistical estimation methods. +Deprecated alias for `AbstractApproximationMethod` """ -abstract type AbstractEstimationMethod end - -""" - GradientDescentEstimation <: AbstractEstimationMethod - -Method for estimation using gradient descent. -""" -struct GradientDescentEstimation <: AbstractEstimationMethod end - -""" - CyclicProximalPointEstimation <: AbstractEstimationMethod - -Method for estimation using the cyclic proximal point technique. -""" -struct CyclicProximalPointEstimation <: AbstractEstimationMethod end - -""" - ExtrinsicEstimation <: AbstractEstimationMethod - -Method for estimation in the ambient space and projecting to the manifold. - -For [`mean`](@ref) estimation, [`GeodesicInterpolation`](@ref) is used for mean estimation -in the ambient space. -""" -struct ExtrinsicEstimation <: AbstractEstimationMethod end - -""" - WeiszfeldEstimation <: AbstractEstimationMethod - -Method for estimation using the Weiszfeld algorithm for the [`median`](@ref) -""" -struct WeiszfeldEstimation <: AbstractEstimationMethod end +const AbstractEstimationMethod = AbstractApproximationMethod _unit_weights(n::Int) = StatsBase.UnitWeights{Float64}(n) -@doc raw""" - GeodesicInterpolation <: AbstractEstimationMethod - -Repeated weighted geodesic interpolation method for estimating the Riemannian -center of mass. - -The algorithm proceeds with the following simple online update: - -```math -\begin{aligned} -μ_1 &= x_1\\ -t_k &= \frac{w_k}{\sum_{i=1}^k w_i}\\ -μ_{k} &= γ_{μ_{k-1}}(x_k; t_k), -\end{aligned} -``` - -where $x_k$ are points, $w_k$ are weights, $μ_k$ is the $k$th estimate of the -mean, and $γ_x(y; t)$ is the point at time $t$ along the -[`shortest_geodesic`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#ManifoldsBase.shortest_geodesic-Tuple{AbstractManifold,%20Any,%20Any}) -between points $x,y ∈ \mathcal M$. The algorithm -terminates when all $x_k$ have been considered. In the [`Euclidean`](@ref) case, -this exactly computes the weighted mean. - -The algorithm has been shown to converge asymptotically with the sample size for -the following manifolds equipped with their default metrics when all sampled -points are in an open geodesic ball about the mean with corresponding radius -(see [`GeodesicInterpolationWithinRadius`](@ref)): - -* All simply connected complete Riemannian manifolds with non-positive sectional - curvature at radius $∞$ [ChengHoSalehianVemuri:2016](@cite), in particular: - + [`Euclidean`](@ref) - + [`SymmetricPositiveDefinite`](@ref) [HoChengSalehianVemuri:2013](@cite) -* Other manifolds: - + [`Sphere`](@ref): $\frac{π}{2}$ [SalehianEtAl:2015](@cite) - + [`Grassmann`](@ref): $\frac{π}{4}$ [ChakrabortyVemuri:2015](@cite) - + [`Stiefel`](@ref)/[`Rotations`](@ref): $\frac{π}{2 \sqrt 2}$ [ChakrabortyVemuri:2019](@cite) - -For online variance computation, the algorithm additionally uses an analogous -recursion to the weighted Welford algorithm [West:1979](@cite). -""" -struct GeodesicInterpolation <: AbstractEstimationMethod end - -""" - GeodesicInterpolationWithinRadius{T} <: AbstractEstimationMethod - -Estimation of Riemannian center of mass using [`GeodesicInterpolation`](@ref) -with fallback to [`GradientDescentEstimation`](@ref) if any points are outside of a -geodesic ball of specified `radius` around the mean. - -# Constructor - - GeodesicInterpolationWithinRadius(radius) -""" -struct GeodesicInterpolationWithinRadius{T} <: AbstractEstimationMethod - radius::T - - function GeodesicInterpolationWithinRadius(radius::T) where {T} - radius > 0 && return new{T}(radius) - return throw( - DomainError("The radius must be strictly postive, received $(radius)."), - ) - end -end - function Base.show(io::IO, method::GeodesicInterpolationWithinRadius) return print(io, "GeodesicInterpolationWithinRadius($(method.radius))") end -""" - default_estimation_method(M::AbstractManifold, f) - -Specify a default [`AbstractEstimationMethod`](@ref) for an `AbstractManifold` -for a function `f`, e.g. the `median` or the `mean`. - -Note that his function is decorated, so it can inherit from the embedding, for example for the -`IsEmbeddedSubmanifold` trait. -""" -default_estimation_method(M::AbstractManifold, f) - -for mf in [mean, median, cov, var, mean_and_std, mean_and_var] - @eval @trait_function default_estimation_method( - M::AbstractDecoratorManifold, - f::typeof($mf), - ) (no_empty,) - eval( - quote - function default_estimation_method( - ::TraitList{IsEmbeddedSubmanifold}, - M::AbstractDecoratorManifold, - f::typeof($mf), - ) - return default_estimation_method(get_embedding(M), f) - end - end, - ) -end - -@trait_function Statistics.mean(M::AbstractDecoratorManifold, x::AbstractVector) - """ Statistics.cov( M::AbstractManifold, @@ -145,7 +19,7 @@ end tangent_space_covariance_estimator::CovarianceEstimator=SimpleCovariance(; corrected=true, ), - mean_estimation_method::AbstractEstimationMethod=GradientDescentEstimation(), + mean_estimation_method::AbstractApproximationMethod=GradientDescentEstimation(), inverse_retraction_method::AbstractInverseRetractionMethod=default_inverse_retraction_method( M, eltype(x), ), @@ -156,7 +30,7 @@ on a manifold is a rank 2 tensor, the function returns its coefficients in basis the given tangent space basis. See Section 5 of [Pennec:2006](@cite) for details. The mean is calculated using the specified `mean_estimation_method` using -[mean](@ref Statistics.mean(::AbstractManifold, ::AbstractVector, ::AbstractEstimationMethod), +[mean](@ref Statistics.mean(::AbstractManifold, ::AbstractVector, ::AbstractApproximationMethod), and tangent vectors at this mean are calculated using the provided `inverse_retraction_method`. Finally, the covariance matrix in the tangent plane is estimated using the Euclidean space estimator `tangent_space_covariance_estimator`. The type `CovarianceEstimator` is defined @@ -171,7 +45,10 @@ function Statistics.cov( tangent_space_covariance_estimator::CovarianceEstimator=SimpleCovariance(; corrected=true, ), - mean_estimation_method::AbstractEstimationMethod=default_estimation_method(M, cov), + mean_estimation_method::AbstractApproximationMethod=default_approximation_method( + M, + cov, + ), inverse_retraction_method::AbstractInverseRetractionMethod=default_inverse_retraction_method( M, eltype(x), @@ -188,10 +65,16 @@ function Statistics.cov( ) end -function default_estimation_method(::EmptyTrait, ::AbstractDecoratorManifold, ::typeof(cov)) +function default_approximation_method( + ::EmptyTrait, + ::AbstractDecoratorManifold, + ::typeof(cov), +) + return GradientDescentEstimation() +end +function default_approximation_method(::AbstractManifold, ::typeof(cov)) return GradientDescentEstimation() end -default_estimation_method(::AbstractManifold, ::typeof(cov)) = GradientDescentEstimation() @doc raw""" mean(M::AbstractManifold, x::AbstractVector[, w::AbstractWeights]; kwargs...) @@ -209,7 +92,7 @@ In the general case, the [`GradientDescentEstimation`](@ref) is used to compute M::AbstractManifold, x::AbstractVector, [w::AbstractWeights,] - method::AbstractEstimationMethod=default_estimation_method(M); + method::AbstractApproximationMethod=default_approximation_method(M, mean); kwargs..., ) @@ -243,10 +126,25 @@ as the exponential barycenter. The algorithm is further described in[AfsariTronVidal:2013](@cite). """ mean(::AbstractManifold, ::Any...) + +# +# dispatch on method first to allow Euclidean defaults to hit +function Statistics.mean(M::AbstractManifold, x::AbstractVector, kwargs...) + return mean(M, x, default_approximation_method(M, mean, eltype(x)); kwargs...) +end function Statistics.mean( M::AbstractManifold, x::AbstractVector, - method::AbstractEstimationMethod=default_estimation_method(M, mean); + w::AbstractVector, + kwargs..., +) + return mean(M, x, w, default_approximation_method(M, mean, eltype(x)); kwargs...) +end + +function Statistics.mean( + M::AbstractManifold, + x::AbstractVector, + method::AbstractApproximationMethod; kwargs..., ) y = allocate_result(M, mean, x[1]) @@ -256,18 +154,13 @@ function Statistics.mean( M::AbstractManifold, x::AbstractVector, w::AbstractVector, - method::AbstractEstimationMethod=default_estimation_method(M, mean); + method::AbstractApproximationMethod; kwargs..., ) y = allocate_result(M, mean, x[1]) return mean!(M, y, x, w, method; kwargs...) end -function default_estimation_method(::EmptyTrait, ::AbstractManifold, ::typeof(mean)) - return GradientDescentEstimation() -end; -default_estimation_method(::AbstractManifold, ::typeof(mean)) = GradientDescentEstimation(); - @doc raw""" mean!(M::AbstractManifold, y, x::AbstractVector[, w::AbstractWeights]; kwargs...) mean!( @@ -275,7 +168,7 @@ default_estimation_method(::AbstractManifold, ::typeof(mean)) = GradientDescentE y, x::AbstractVector, [w::AbstractWeights,] - method::AbstractEstimationMethod; + method::AbstractApproximationMethod; kwargs..., ) @@ -287,7 +180,7 @@ function Statistics.mean!( M::AbstractManifold, y, x::AbstractVector, - method::AbstractEstimationMethod=default_estimation_method(M, mean); + method::AbstractApproximationMethod=default_approximation_method(M, mean); kwargs..., ) w = _unit_weights(length(x)) @@ -497,8 +390,6 @@ end Estimate the Riemannian center of mass of `x` using [`ExtrinsicEstimation`](@ref), i.e. by computing the mean in the embedding and projecting the result back. -You can specify an `extrinsic_method` to specify which mean estimation method to use in the embedding, -which defaults to [`GeodesicInterpolation`](@ref). See [`mean`](@ref mean(::AbstractManifold, ::AbstractVector, ::AbstractVector, ::GeodesicInterpolation)) for a description of the remaining `kwargs`. @@ -515,26 +406,37 @@ function Statistics.mean!( y, x::AbstractVector, w::AbstractVector, - ::ExtrinsicEstimation; - extrinsic_method::AbstractEstimationMethod=default_estimation_method( - get_embedding(M), - mean, - ), + e::ExtrinsicEstimation; + extrinsic_method::Union{AbstractEstimationMethod,Nothing}=nothing, kwargs..., ) + if !isnothing(extrinsic_method) + Base.depwarn( + "The Keyword Argument `extrinsic_method` is deprecated use `ExtrinsicEstimators` field instead", + :mean!, + ) + e = ExtrinsicEstimation(extrinsic_method) + end embedded_x = map(p -> embed(M, p), x) - embedded_y = mean(get_embedding(M), embedded_x, w, extrinsic_method; kwargs...) + embedded_y = mean(get_embedding(M), embedded_x, w, e.extrinsic_estimation; kwargs...) project!(M, y, embedded_y) return y end +function default_approximation_method(::EmptyTrait, ::AbstractManifold, ::typeof(mean)) + return GradientDescentEstimation() +end; +function default_approximation_method(::AbstractManifold, ::typeof(mean)) + return GradientDescentEstimation() +end; + @doc raw""" median(M::AbstractManifold, x::AbstractVector[, w::AbstractWeights]; kwargs...) median( M::AbstractManifold, x::AbstractVector, [w::AbstractWeights,] - method::AbstractEstimationMethod; + method::AbstractApproximationMethod; kwargs..., ) @@ -553,14 +455,14 @@ Compute the median using the specified `method`. """ Statistics.median(::AbstractManifold, ::Any...) -function default_estimation_method( +function default_approximation_method( ::EmptyTrait, ::AbstractDecoratorManifold, ::typeof(median), ) return CyclicProximalPointEstimation() end -function default_estimation_method(::AbstractManifold, ::typeof(median)) +function default_approximation_method(::AbstractManifold, ::typeof(median)) return CyclicProximalPointEstimation() end @@ -605,14 +507,11 @@ Statistics.median( x::AbstractVector, [w::AbstractWeights,] method::ExtrinsicEstimation; - extrinsic_method = CyclicProximalPointEstimation(), kwargs..., ) Estimate the median of `x` using [`ExtrinsicEstimation`](@ref), i.e. by computing the median in the embedding and projecting the result back. -You can specify an `extrinsic_method` to specify which median estimation method to use in -the embedding, which defaults to [`CyclicProximalPointEstimation`](@ref). See [`median`](@ref median(::AbstractManifold, ::AbstractVector, ::AbstractVector, ::CyclicProximalPointEstimation)) for a description of `kwargs`. @@ -680,10 +579,24 @@ Statistics.median( ::WeiszfeldEstimation, ) +# +# dispatch on the method first before allocating to allow Euclidean defaults to hit +function Statistics.median(M::AbstractManifold, x::AbstractVector; kwargs...) + return median(M, x, default_approximation_method(M, median, eltype(x))) +end function Statistics.median( M::AbstractManifold, x::AbstractVector, - method::AbstractEstimationMethod=default_estimation_method(M, median); + w::AbstractVector; + kwargs..., +) + return median(M, x, w, default_approximation_method(M, median, eltype(x))) +end + +function Statistics.median( + M::AbstractManifold, + x::AbstractVector, + method::AbstractApproximationMethod; kwargs..., ) y = allocate_result(M, median, x[1]) @@ -693,7 +606,7 @@ function Statistics.median( M::AbstractManifold, x::AbstractVector, w::AbstractVector, - method::AbstractEstimationMethod=default_estimation_method(M, median); + method::AbstractApproximationMethod; kwargs..., ) y = allocate_result(M, median, x[1]) @@ -707,7 +620,7 @@ end y, x::AbstractVector, [w::AbstractWeights,] - method::AbstractEstimationMethod; + method::AbstractApproximationMethod; kwargs..., ) @@ -718,7 +631,7 @@ function Statistics.median!( M::AbstractManifold, q, x::AbstractVector, - method::AbstractEstimationMethod=default_estimation_method(M, median); + method::AbstractApproximationMethod=default_approximation_method(M, median, eltype(x)); kwargs..., ) w = _unit_weights(length(x)) @@ -771,15 +684,19 @@ function Statistics.median!( y, x::AbstractVector, w::AbstractVector, - ::ExtrinsicEstimation; - extrinsic_method::AbstractEstimationMethod=default_estimation_method( - get_embedding(M), - median, - ), + e::ExtrinsicEstimation; + extrinsic_method=nothing, kwargs..., ) + if !isnothing(extrinsic_method) + Base.depwarn( + "The Keyword Argument `extrinsic_method` is deprecated use `ExtrinsicEstimators` field instead", + :median!, + ) + e = ExtrinsicEstimation(extrinsic_method) + end embedded_x = map(p -> embed(M, p), x) - embedded_y = median(get_embedding(M), embedded_x, w, extrinsic_method; kwargs...) + embedded_y = median(get_embedding(M), embedded_x, w, e.extrinsic_estimation; kwargs...) project!(M, y, embedded_y) return y end @@ -905,7 +822,7 @@ simultaneously. See those functions for a description of the arguments. M::AbstractManifold, x::AbstractVector [w::AbstractWeights,] - method::AbstractEstimationMethod; + method::AbstractApproximationMethod; kwargs..., ) -> (mean, var) @@ -918,7 +835,7 @@ function StatsBase.mean_and_var( M::AbstractManifold, x::AbstractVector, w::AbstractWeights, - method::AbstractEstimationMethod=default_estimation_method(M, mean); + method::AbstractApproximationMethod=default_approximation_method(M, mean, eltype(x)); corrected=false, kwargs..., ) @@ -929,7 +846,11 @@ end function StatsBase.mean_and_var( M::AbstractManifold, x::AbstractVector, - method::AbstractEstimationMethod=default_estimation_method(M, mean_and_var); + method::AbstractApproximationMethod=default_approximation_method( + M, + mean_and_var, + eltype(x), + ); corrected=true, kwargs..., ) @@ -937,15 +858,15 @@ function StatsBase.mean_and_var( w = _unit_weights(n) return mean_and_var(M, x, w, method; corrected=corrected, kwargs...) end -function default_estimation_method( +function default_approximation_method( ::EmptyTrait, M::AbstractDecoratorManifold, ::typeof(mean_and_var), ) - return default_estimation_method(M, mean) + return default_approximation_method(M, mean) end -function default_estimation_method(M::AbstractManifold, ::typeof(mean_and_var)) - return default_estimation_method(M, mean) +function default_approximation_method(M::AbstractManifold, ::typeof(mean_and_var)) + return default_approximation_method(M, mean) end @doc raw""" @@ -1077,7 +998,7 @@ Compute the [`mean`](@ref mean(::AbstractManifold, args...)) and the standard de M::AbstractManifold, x::AbstractVector [w::AbstractWeights,] - method::AbstractEstimationMethod; + method::AbstractApproximationMethod; kwargs..., ) -> (mean, var) @@ -1089,8 +1010,8 @@ function StatsBase.mean_and_std(M::AbstractManifold, args...; kwargs...) m, v = mean_and_var(M, args...; kwargs...) return m, sqrt(v) end -function default_estimation_method(M::AbstractManifold, ::typeof(mean_and_std)) - return default_estimation_method(M, mean) +function default_approximation_method(M::AbstractManifold, ::typeof(mean_and_std)) + return default_approximation_method(M, mean) end """ @@ -1154,3 +1075,24 @@ function StatsBase.kurtosis(M::AbstractManifold, x::AbstractVector, args...) w = _unit_weights(length(x)) return kurtosis(M, x, w, args...) end + +# +# decorate default method for a few functions +for mf in [mean, median, cov, var, mean_and_std, mean_and_var] + @eval @trait_function default_approximation_method( + M::AbstractDecoratorManifold, + f::typeof($mf), + ) (no_empty,) + eval( + quote + function default_approximation_method( + ::TraitList{IsEmbeddedSubmanifold}, + M::AbstractDecoratorManifold, + f::typeof($mf), + ) + return default_approximation_method(get_embedding(M), f) + end + end, + ) +end +@trait_function Statistics.mean(M::AbstractDecoratorManifold, x::AbstractVector) diff --git a/test/manifolds/circle.jl b/test/manifolds/circle.jl index 07d3eddb6d..46bfe9d226 100644 --- a/test/manifolds/circle.jl +++ b/test/manifolds/circle.jl @@ -288,4 +288,10 @@ using Manifolds: TFVector, CoTFVector ) end end + @testset "Mixed array dimensions for exp" begin + M = Circle() + p = fill(0.0) + exp!(M, p, p, [1.0], 2.0) + @test p ≈ fill(2.0) + end end diff --git a/test/manifolds/embedded_torus.jl b/test/manifolds/embedded_torus.jl index a5ccf63e36..b8720d10d5 100644 --- a/test/manifolds/embedded_torus.jl +++ b/test/manifolds/embedded_torus.jl @@ -91,7 +91,7 @@ using BoundaryValueDiffEq a2 = [-0.5, 0.3] sol_log = Manifolds.solve_chart_log_bvp(M, p0x, a2, A, (0, 0)) @test sol_log(0.0)[1:2] ≈ p0x - @test sol_log(1.0)[1:2] ≈ a2 + @test sol_log(1.0)[1:2] ≈ a2 atol = 1e-7 # a test randomly failed here on Julia 1.6 once for no clear reason? # so I bumped tolerance considerably bvp_atol = VERSION < v"1.7" ? 2e-3 : 1e-15 diff --git a/test/manifolds/rotations.jl b/test/manifolds/rotations.jl index 3fa56e491b..4a276b8def 100644 --- a/test/manifolds/rotations.jl +++ b/test/manifolds/rotations.jl @@ -286,4 +286,33 @@ include("../utils.jl") @test X isa Matrix{Float64} @test X == fill(0.0, 1, 1) end + + @testset "Specializations" begin + M = Rotations(2) + p = Matrix{Float64}(I, 2, 2) + X = [0.0 3.0; -3.0 0.0] + @test parallel_transport_direction(M, p, X, X) === X + + M = Rotations(3) + p = @SMatrix [ + -0.5908399013383766 -0.6241917041179139 0.5111681988316876 + -0.7261666986267721 0.13535732881097293 -0.6740625485388226 + 0.35155388888753836 -0.7694563730631729 -0.5332417398896261 + ] + X = @SMatrix [ + 0.0 -0.30777760628130063 0.5499897386953444 + 0.30777760628130063 0.0 -0.32059980100053004 + -0.5499897386953444 0.32059980100053004 0.0 + ] + d = @SMatrix [ + 0.0 -0.4821890003925358 -0.3513148535122392 + 0.4821890003925358 0.0 0.37956770358148356 + 0.3513148535122392 -0.37956770358148356 0.0 + ] + @test parallel_transport_direction(M, p, X, d) ≈ [ + 0.0 -0.3258778314599828 0.3903114578816008 + 0.32587783145998306 0.0 -0.49138641089195584 + -0.3903114578816011 0.4913864108919558 0.0 + ] + end end diff --git a/test/statistics.jl b/test/statistics.jl index c4356d3788..160537bc00 100644 --- a/test/statistics.jl +++ b/test/statistics.jl @@ -12,13 +12,14 @@ import ManifoldsBase: base_manifold, get_embedding using Manifolds: - AbstractEstimationMethod, + AbstractApproximationMethod, CyclicProximalPointEstimation, GeodesicInterpolation, GeodesicInterpolationWithinRadius, GradientDescentEstimation, WeiszfeldEstimation -import Manifolds: mean, mean!, median, median!, var, mean_and_var, default_estimation_method +import Manifolds: + mean, mean!, median, median!, var, mean_and_var, default_approximation_method struct TestStatsSphere{N} <: AbstractManifold{ℝ} end TestStatsSphere(N) = TestStatsSphere{N}() @@ -114,7 +115,7 @@ function test_mean(M, x, yexp=nothing, method...; kwargs...) y, x, pweights(ones(n + 1)), - Manifolds.default_estimation_method(M, mean); + Manifolds.default_approximation_method(M, mean); kwargs..., ) end @@ -125,14 +126,14 @@ function test_median( M, x, yexp=nothing; - method::Union{Nothing,AbstractEstimationMethod}=nothing, + method::Union{Nothing,AbstractApproximationMethod}=nothing, kwargs..., ) @testset "median unweighted$(!isnothing(method) ? " ($method)" : "")" begin y = isnothing(method) ? median(M, x; kwargs...) : median(M, x, method; kwargs...) @test is_point(M, y; atol=10^-9) if yexp !== nothing - @test isapprox(M, y, yexp; atol=10^-5) + @test isapprox(M, y, yexp; atol=5 * 10^-5) end end @@ -301,7 +302,7 @@ function test_moments(M, x) end struct TestStatsOverload1 <: AbstractManifold{ℝ} end -struct TestStatsMethod1 <: AbstractEstimationMethod end +struct TestStatsMethod1 <: AbstractApproximationMethod end function mean!( ::TestStatsOverload1, @@ -401,8 +402,8 @@ end @test std(M, x, w) == 2.0 @test std(M, x, w, 2) == 2.0 - @test Manifolds.default_estimation_method(M, mean_and_std) == - Manifolds.default_estimation_method(M, mean) + @test Manifolds.default_approximation_method(M, mean_and_std) == + Manifolds.default_approximation_method(M, mean) @test mean_and_var(M, x, TestStatsMethod1()) == ([5.0], 16) @test mean_and_var(M, x, w, TestStatsMethod1()) == ([5.0], 9) @test mean_and_std(M, x, TestStatsMethod1()) == ([5.0], 4.0) @@ -547,6 +548,9 @@ end test_var(M, x) test_std(M, x) test_moments(M, x) + y = copy(x[1]) + mean!(M, y, x) + @test y == mean(x) end end end @@ -786,7 +790,7 @@ end x = [normalize(randn(rng, 3)) for _ in 1:10] w = pweights([rand(rng) for _ in 1:length(x)]) m = normalize(mean(reduce(hcat, x), w; dims=2)[:, 1]) - mg = mean(S, x, w, ExtrinsicEstimation()) + mg = mean(S, x, w, ExtrinsicEstimation(EfficientEstimator())) @test isapprox(S, m, mg) end @@ -796,12 +800,12 @@ end x = [normalize(randn(rng, 3)) for _ in 1:10] w = pweights([rand(rng) for _ in 1:length(x)]) m = normalize(median(Euclidean(3), x, w)) - mg = median(S, x, w, ExtrinsicEstimation()) + mg = median(S, x, w, ExtrinsicEstimation(CyclicProximalPointEstimation())) @test isapprox(S, m, mg) end @testset "Covariance Default" begin - @test default_estimation_method(TestStatsSphere{2}(), cov) == + @test default_approximation_method(TestStatsSphere{2}(), cov) == GradientDescentEstimation() end diff --git a/test/test_deprecated.jl b/test/test_deprecated.jl index 5791392758..1fb4c592f5 100644 --- a/test/test_deprecated.jl +++ b/test/test_deprecated.jl @@ -1,3 +1,33 @@ -using Manifolds, ManifoldsBase, Test +using Manifolds, ManifoldsBase, Random, Test -@testset "Deprecation tests" begin end +using StatsBase: AbstractWeights, pweights +using Random: GLOBAL_RNG, seed! + +@testset "Deprecation tests" begin + @testset "Depreacte extrinsic_method= keyword" begin + rng = MersenneTwister(47) + S = Sphere(2) + x = [normalize(randn(rng, 3)) for _ in 1:10] + w = pweights([rand(rng) for _ in 1:length(x)]) + mg1 = mean(S, x, w, ExtrinsicEstimation(EfficientEstimator())) + # Statistics 414-418, depcreatce former extrinsic_method keyword + mg2 = mean( + S, + x, + w, + ExtrinsicEstimation(EfficientEstimator()); + extrinsic_method=EfficientEstimator(), + ) + @test isapprox(S, mg1, mg2) + mg3 = median(S, x, w, ExtrinsicEstimation(CyclicProximalPointEstimation())) + # Statistics 692-696, depcreatce former extrinsic_method keyword + mg4 = median( + S, + x, + w, + ExtrinsicEstimation(CyclicProximalPointEstimation()); + extrinsic_method=CyclicProximalPointEstimation(), + ) + @test isapprox(S, mg3, mg4) + end +end