diff --git a/docs/src/lib/public.md b/docs/src/lib/public.md index 679ae3b..6d5f4cb 100644 --- a/docs/src/lib/public.md +++ b/docs/src/lib/public.md @@ -20,7 +20,6 @@ Pages = ["public.md"] ```@docs Lattice basisvectors(::Lattice) -eachbasisvector(::Lattice) ``` ### Reciprocal lattices @@ -28,7 +27,6 @@ eachbasisvector(::Lattice) ```@docs ReciprocalLattice basisvectors(::ReciprocalLattice) -eachbasisvector(::ReciprocalLattice) reciprocal ``` diff --git a/src/lattice.jl b/src/lattice.jl index a8f7572..0d651bd 100644 --- a/src/lattice.jl +++ b/src/lattice.jl @@ -1,13 +1,14 @@ using StaticArrays: MMatrix, SDiagonal +using StructEquality: @struct_hash_equal_isequal_isapprox -export Lattice, eachbasisvector, basisvectors +export Lattice, basisvectors """ AbstractLattice{T} <: AbstractMatrix{T} Represent the real lattices and the reciprocal lattices. """ -abstract type AbstractLattice{T} <: AbstractMatrix{T} end +abstract type AbstractLattice{T} end """ Lattice(data::AbstractMatrix) @@ -29,7 +30,7 @@ julia> Lattice([ 3.4 6.7 9.1 ``` """ -mutable struct Lattice{T} <: AbstractLattice{T} +@struct_hash_equal_isequal_isapprox struct Lattice{T} <: AbstractLattice{T} data::MMatrix{3,3,T,9} end Lattice(data::AbstractMatrix) = Lattice(MMatrix{3,3}(data)) @@ -113,14 +114,8 @@ end Get the three basis vectors from a lattice. """ -basisvectors(lattice::Lattice) = Tuple(eachcol(lattice)) - -""" - eachbasisvector(lattice::Lattice) - -Iterate over the three basis vectors of a lattice. -""" -eachbasisvector(lattice::Lattice) = eachcol(lattice) +basisvectors(lattice::Lattice) = + lattice[begin:(begin + 2)], lattice[(begin + 3):(begin + 5)], lattice[(begin + 6):end] # See https://github.com/JuliaLang/julia/blob/v1.10.0-beta1/stdlib/LinearAlgebra/src/uniformscaling.jl#L130-L131 Base.one(::Type{Lattice{T}}) where {T} = Lattice(SDiagonal(one(T), one(T), one(T))) @@ -131,42 +126,40 @@ Base.oneunit(::Type{Lattice{T}}) where {T} = Lattice(SDiagonal(oneunit(T), oneunit(T), oneunit(T))) Base.oneunit(lattice::Lattice) = oneunit(typeof(lattice)) -Base.parent(lattice::Lattice) = lattice.data +# See https://github.com/JuliaLang/julia/blob/v1.10.0-beta1/stdlib/LinearAlgebra/src/uniformscaling.jl#L134-L135 +Base.zero(::Type{Lattice{T}}) where {T} = Lattice(zeros(T, 3, 3)) +Base.zero(lattice::Lattice) = zero(typeof(lattice)) + +# Similar to https://github.com/JuliaCollections/IterTools.jl/blob/0ecaa88/src/IterTools.jl#L1028-L1032 +Base.iterate(iter::Lattice, state=1) = iterate(parent(iter), state) + +Base.IteratorSize(::Type{<:Lattice}) = Base.HasShape{2}() + +Base.eltype(::Type{Lattice{T}}) where {T} = T + +Base.length(::Lattice) = 9 Base.size(::Lattice) = (3, 3) +# See https://github.com/rafaqz/DimensionalData.jl/blob/bd28d08/src/array/array.jl#L74 +Base.size(lattice::Lattice, dim) = size(parent(lattice), dim) # Here, `parent(A)` is necessary to avoid `StackOverflowError`. -Base.getindex(lattice::Lattice, i::Int) = getindex(parent(lattice), i) -Base.getindex(lattice::Lattice, I...) = getindex(parent(lattice), I...) +Base.parent(lattice::Lattice) = lattice.data -Base.setindex!(lattice::Lattice, v, i::Int) = setindex!(parent(lattice), v, i) -Base.setindex!(lattice::Lattice, X, I...) = setindex!(parent(lattice), X, I...) +Base.getindex(lattice::Lattice, i...) = getindex(parent(lattice), i...) -Base.IndexStyle(::Type{<:Lattice}) = IndexLinear() +Base.firstindex(::Lattice) = 1 -# Customizing broadcasting -# See https://github.com/JuliaArrays/StaticArraysCore.jl/blob/v1.4.2/src/StaticArraysCore.jl#L397-L398 -# and https://github.com/JuliaLang/julia/blob/v1.10.0-beta1/stdlib/LinearAlgebra/src/structuredbroadcast.jl#L7-L14 -struct LatticeStyle <: Broadcast.AbstractArrayStyle{2} end -LatticeStyle(::Val{2}) = LatticeStyle() -LatticeStyle(::Val{N}) where {N} = Broadcast.DefaultArrayStyle{N}() +Base.lastindex(::Lattice) = 9 -Base.BroadcastStyle(::Type{<:Lattice}) = LatticeStyle() +Base.:*(lattice::Lattice, x::Number) = Lattice(parent(lattice) * x) +Base.:*(x::Number, lattice::Lattice) = lattice * x -Base.similar(::Broadcast.Broadcasted{LatticeStyle}, ::Type{T}) where {T} = - similar(Lattice{T}, 3, 3) -# Override https://github.com/JuliaLang/julia/blob/v1.10.0-beta2/base/abstractarray.jl#L839 -function Base.similar(lattice::Lattice, ::Type{T}, dims::Dims) where {T} - if dims == size(lattice) - return Lattice(similar(Matrix{T}, dims)) - else - throw(ArgumentError("invalid dims `$dims` for `Lattice`!")) - end -end -# Override https://github.com/JuliaLang/julia/blob/v1.10.0-beta1/base/abstractarray.jl#L874 -function Base.similar(::Type{Lattice{T}}, dims::Dims) where {T} - if dims == (3, 3) - return Lattice(similar(Matrix{T}, dims)) - else - throw(ArgumentError("invalid dims `$dims` for `Lattice`!")) - end -end +Base.:/(lattice::Lattice, x::Number) = Lattice(parent(lattice) / x) + +Base.:+(lattice::Lattice) = lattice +Base.:+(lattice::Lattice, x::Number) = Lattice(parent(lattice) .+ x) +Base.:+(x::Number, lattice::Lattice) = lattice + x + +Base.:-(lattice::Lattice) = -one(eltype(lattice)) * lattice +Base.:-(lattice::Lattice, x::Number) = Lattice(parent(lattice) .- x) +Base.:-(x::Number, lattice::Lattice) = -lattice + x diff --git a/src/reciprocal.jl b/src/reciprocal.jl index 8f6b2e4..ba3b875 100644 --- a/src/reciprocal.jl +++ b/src/reciprocal.jl @@ -1,7 +1,5 @@ using LinearAlgebra: I, det, cross -import Base: *, / - export ReciprocalLattice, reciprocal, isreciprocal """ @@ -15,7 +13,7 @@ Construct a `ReciprocalLattice` from a matrix. !!! warning Avoid using this constructor directly. Use `reciprocal` instead. """ -struct ReciprocalLattice{T} <: AbstractLattice{T} +@struct_hash_equal_isequal_isapprox struct ReciprocalLattice{T} <: AbstractLattice{T} data::MMatrix{3,3,T,9} end ReciprocalLattice(data::AbstractMatrix) = ReciprocalLattice(MMatrix{3,3}(data)) @@ -25,14 +23,8 @@ ReciprocalLattice(data::AbstractMatrix) = ReciprocalLattice(MMatrix{3,3}(data)) Get the three basis vectors from a reciprocal lattice. """ -basisvectors(lattice::ReciprocalLattice) = Tuple(eachbasisvector(lattice)) - -""" - eachbasisvector(lattice::ReciprocalLattice) - -Iterate over the three basis vectors of a reciprocal lattice. -""" -eachbasisvector(lattice::ReciprocalLattice) = eachcol(lattice) +basisvectors(lattice::ReciprocalLattice) = + lattice[begin:(begin + 2)], lattice[(begin + 3):(begin + 5)], lattice[(begin + 6):end] """ reciprocal(lattice::Lattice) @@ -47,7 +39,7 @@ function reciprocal(lattice::Lattice) end function reciprocal(lattice::ReciprocalLattice) Ω⁻¹ = det(lattice.data) # Cannot use `cellvolume`, it takes the absolute value! - 𝐚⁻¹, 𝐛⁻¹, 𝐜⁻¹ = eachbasisvector(lattice) + 𝐚⁻¹, 𝐛⁻¹, 𝐜⁻¹ = basisvectors(lattice) return inv(Ω⁻¹) * Lattice(hcat(cross(𝐛⁻¹, 𝐜⁻¹), cross(𝐜⁻¹, 𝐚⁻¹), cross(𝐚⁻¹, 𝐛⁻¹))) end @@ -64,50 +56,32 @@ Base.oneunit(::Type{ReciprocalLattice{T}}) where {T} = ReciprocalLattice(MMatrix{3,3}(SDiagonal(oneunit(T), oneunit(T), oneunit(T)))) Base.oneunit(lattice::ReciprocalLattice) = oneunit(typeof(lattice)) -Base.parent(lattice::ReciprocalLattice) = lattice.data +# Similar to https://github.com/JuliaCollections/IterTools.jl/blob/0ecaa88/src/IterTools.jl#L1028-L1032 +Base.iterate(iter::ReciprocalLattice, state=1) = iterate(parent(iter), state) + +Base.IteratorSize(::Type{<:ReciprocalLattice}) = Base.HasShape{2}() + +Base.eltype(::Type{ReciprocalLattice{T}}) where {T} = T + +Base.length(::ReciprocalLattice) = 9 Base.size(::ReciprocalLattice) = (3, 3) +# See https://github.com/rafaqz/DimensionalData.jl/blob/bd28d08/src/array/array.jl#L74 +Base.size(lattice::ReciprocalLattice, dim) = size(parent(lattice), dim) # Here, `parent(A)` is necessary to avoid `StackOverflowError`. -Base.getindex(lattice::ReciprocalLattice, i::Int) = getindex(parent(lattice), i) -Base.getindex(lattice::ReciprocalLattice, I...) = getindex(parent(lattice), I...) +Base.parent(lattice::ReciprocalLattice) = lattice.data -Base.setindex!(lattice::ReciprocalLattice, v, i::Int) = setindex!(parent(lattice), v, i) -Base.setindex!(lattice::ReciprocalLattice, X, I...) = setindex!(parent(lattice), X, I...) +Base.getindex(lattice::ReciprocalLattice, i...) = getindex(parent(lattice), i...) -Base.IndexStyle(::Type{<:ReciprocalLattice}) = IndexLinear() +Base.firstindex(::ReciprocalLattice) = 1 -# Customizing broadcasting -# See https://github.com/JuliaArrays/StaticArraysCore.jl/blob/v1.4.2/src/StaticArraysCore.jl#L397-L398 -# and https://github.com/JuliaLang/julia/blob/v1.10.0-beta1/stdlib/LinearAlgebra/src/structuredbroadcast.jl#L7-L14 -struct ReciprocalLatticeStyle <: Broadcast.AbstractArrayStyle{2} end -ReciprocalLatticeStyle(::Val{2}) = ReciprocalLatticeStyle() -ReciprocalLatticeStyle(::Val{N}) where {N} = Broadcast.DefaultArrayStyle{N}() +Base.lastindex(::ReciprocalLattice) = 9 -Base.BroadcastStyle(::Type{<:ReciprocalLattice}) = ReciprocalLatticeStyle() +Base.:*(lattice::ReciprocalLattice, x::Number) = ReciprocalLattice(parent(lattice) * x) +Base.:*(x::Number, lattice::ReciprocalLattice) = lattice * x -Base.similar(::Broadcast.Broadcasted{ReciprocalLatticeStyle}, ::Type{T}) where {T} = - similar(ReciprocalLattice{T}, 3, 3) -# Override https://github.com/JuliaLang/julia/blob/v1.10.0-beta2/base/abstractarray.jl#L839 -function Base.similar(lattice::ReciprocalLattice, ::Type{T}, dims::Dims) where {T} - if dims == size(lattice) - return ReciprocalLattice(similar(Matrix{T}, dims)) - else - throw(ArgumentError("invalid dims `$dims` for `Lattice`!")) - end -end -# Override https://github.com/JuliaLang/julia/blob/v1.10.0-beta1/base/abstractarray.jl#L874 -function Base.similar(::Type{ReciprocalLattice{T}}, dims::Dims) where {T} - if dims == (3, 3) - return ReciprocalLattice(similar(Matrix{T}, dims)) - else - throw(ArgumentError("invalid dims `$dims` for `ReciprocalLattice`!")) - end -end +Base.:/(lattice::ReciprocalLattice, x::Number) = ReciprocalLattice(parent(lattice) / x) -for op in (:*, :/) - for S in (:Lattice, :ReciprocalLattice) - for T in (:Lattice, :ReciprocalLattice) - @eval $op(::$S, ::$T) = error("undefined operation!") - end - end -end +Base.:+(lattice::ReciprocalLattice) = lattice + +Base.:-(lattice::ReciprocalLattice) = -one(eltype(lattice)) * lattice diff --git a/test/lattice.jl b/test/lattice.jl index ae23fbe..5681509 100644 --- a/test/lattice.jl +++ b/test/lattice.jl @@ -1,4 +1,4 @@ -using CrystallographyCore: Lattice, basisvectors, eachbasisvector +using CrystallographyCore: Lattice, basisvectors using StaticArrays: MMatrix using Unitful, UnitfulAtomic @@ -13,7 +13,7 @@ using Unitful, UnitfulAtomic @test Lattice(mat) == Lattice(MMatrix{3,3}(mat)) @test basisvectors(Lattice(mat)) == ([1.2, 2.3, 3.4], [4.5, 5.6, 6.7], [7.8, 8.9, 9.1]) - @test all(eachbasisvector(Lattice(mat)) .== basisvectors(Lattice(mat))) + @test all(basisvectors(Lattice(mat)) .== basisvectors(Lattice(mat))) # Rectangular matrix @test_throws DimensionMismatch Lattice([ 1 2