diff --git a/ext/UnitfulExt.jl b/ext/UnitfulExt.jl index 8200e1e..c03929b 100644 --- a/ext/UnitfulExt.jl +++ b/ext/UnitfulExt.jl @@ -37,4 +37,11 @@ uconvert(u::Units, lattice::ReciprocalLattice) = ustrip(lattice::Lattice) = Lattice(map(ustrip, parent(lattice))) ustrip(lattice::ReciprocalLattice) = ReciprocalLattice(map(ustrip, parent(lattice))) +# You need this to let the broadcasting work. +Base.:*(lattice::Lattice, unit::Units) = Lattice(parent(lattice) * unit) +Base.:*(unit::Units, lattice::Lattice) = lattice * unit + +# You need this to let the broadcasting work. +Base.:/(lattice::Lattice, unit::Units) = Lattice(parent(lattice) / unit) + end diff --git a/src/lattice.jl b/src/lattice.jl index 33e6c35..8e24362 100644 --- a/src/lattice.jl +++ b/src/lattice.jl @@ -157,6 +157,8 @@ Base.:*(x::Number, lattice::Lattice) = lattice * x # You need this to let the broadcasting work. Base.:/(lattice::Lattice, x::Number) = Lattice(parent(lattice) / x) +Base.:/(::Number, ::Lattice) = + throw(ArgumentError("you cannot divide a number by a lattice!")) Base.:+(lattice::Lattice) = lattice # You need this to let the broadcasting work. @@ -176,5 +178,17 @@ Base.convert(::Type{Lattice{T}}, lattice::Lattice{S}) where {S,T} = Base.ndims(::Type{<:Lattice}) = 2 Base.ndims(::Lattice) = 2 -# See https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting -Base.broadcastable(lattice::Lattice) = Ref(lattice) +# See https://github.com/JuliaLang/julia/blob/v1.10.0-rc2/base/broadcast.jl#L741 +Base.broadcastable(lattice::Lattice) = lattice + +# See https://github.com/JuliaLang/julia/blob/v1.10.0-rc2/base/broadcast.jl#L49 +Base.BroadcastStyle(::Type{<:Lattice}) = Broadcast.Style{Lattice}() + +# See https://github.com/JuliaLang/julia/blob/v1.10.0-rc2/base/broadcast.jl#L135 +Base.BroadcastStyle(::Broadcast.AbstractArrayStyle{0}, b::Broadcast.Style{Lattice}) = b + +# See https://github.com/JuliaLang/julia/blob/v1.10.0-rc2/base/broadcast.jl#L1114-L1119 +Base.copy(bc::Broadcast.Broadcasted{Broadcast.Style{Lattice}}) = Lattice(x for x in bc) # For uniary and binary functions + +Base.broadcasted(::typeof(/), ::Number, ::Lattice) = + throw(ArgumentError("you cannot divide a number by a lattice!")) diff --git a/src/reciprocal.jl b/src/reciprocal.jl index 4cf8297..72882b2 100644 --- a/src/reciprocal.jl +++ b/src/reciprocal.jl @@ -78,11 +78,13 @@ Base.firstindex(::ReciprocalLattice) = 1 Base.lastindex(::ReciprocalLattice) = 9 # You need this to let the broadcasting work. -Base.:*(lattice::ReciprocalLattice, x) = ReciprocalLattice(parent(lattice) * x) -Base.:*(x, lattice::ReciprocalLattice) = lattice * x +Base.:*(lattice::ReciprocalLattice, x::Number) = ReciprocalLattice(parent(lattice) * x) +Base.:*(x::Number, lattice::ReciprocalLattice) = lattice * x # You need this to let the broadcasting work. -Base.:/(lattice::ReciprocalLattice, x) = ReciprocalLattice(parent(lattice) / x) +Base.:/(lattice::ReciprocalLattice, x::Number) = ReciprocalLattice(parent(lattice) / x) +Base.:/(::Number, ::ReciprocalLattice) = + throw(ArgumentError("you cannot divide a number by a reciprocal lattice!")) Base.:+(lattice::ReciprocalLattice) = lattice @@ -97,5 +99,20 @@ Base.convert(::Type{ReciprocalLattice{T}}, lattice::ReciprocalLattice{S}) where Base.ndims(::Type{<:ReciprocalLattice}) = 2 Base.ndims(::ReciprocalLattice) = 2 -# See https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting -Base.broadcastable(lattice::ReciprocalLattice) = Ref(lattice) +# See https://github.com/JuliaLang/julia/blob/v1.10.0-rc2/base/broadcast.jl#L741 +Base.broadcastable(lattice::ReciprocalLattice) = lattice + +# See https://github.com/JuliaLang/julia/blob/v1.10.0-rc2/base/broadcast.jl#L49 +Base.BroadcastStyle(::Type{<:ReciprocalLattice}) = Broadcast.Style{ReciprocalLattice}() + +# See https://github.com/JuliaLang/julia/blob/v1.10.0-rc2/base/broadcast.jl#L135 +Base.BroadcastStyle( + ::Broadcast.AbstractArrayStyle{0}, b::Broadcast.Style{ReciprocalLattice} +) = b + +# See https://github.com/JuliaLang/julia/blob/v1.10.0-rc2/base/broadcast.jl#L1114-L1119 +Base.copy(bc::Broadcast.Broadcasted{Broadcast.Style{ReciprocalLattice}}) = + ReciprocalLattice(MMatrix{3,3}(x for x in bc)) # For uniary and binary functions + +Base.broadcasted(::typeof(/), ::Number, ::ReciprocalLattice) = + throw(ArgumentError("you cannot divide a number by a reciprocal lattice!")) diff --git a/test/lattice.jl b/test/lattice.jl index a74cbc5..b442851 100644 --- a/test/lattice.jl +++ b/test/lattice.jl @@ -94,3 +94,31 @@ using Unitful, UnitfulAtomic @test inv(inv(lattice)) == lattice end end + +@testset "Test broadcasting for lattices" begin + lattice = Lattice([1, 0, 0], [0, 1, 0], [0, 0, 1]) + @test lattice .* 4 == 4 .* lattice == Lattice([4, 0, 0], [0, 4, 0], [0, 0, 4]) + @test lattice .* 4.0 == 4.0 .* lattice == Lattice([4.0, 0, 0], [0, 4.0, 0], [0, 0, 4.0]) + @test lattice .* 4//1 == 4//1 .* lattice == Lattice([4, 0, 0], [0, 4, 0], [0, 0, 4]) + @test lattice ./ 4 == Lattice([1//4, 0, 0], [0, 1//4, 0], [0, 0, 1//4]) + @test lattice .* u"nm" == + u"nm" .* lattice == + Lattice( + [1u"nm", 0u"nm", 0u"nm"], [0u"nm", 1u"nm", 0u"nm"], [0u"nm", 0u"nm", 1u"nm"] + ) + @test lattice .* 1u"nm" == + 1u"nm" .* lattice == + Lattice( + [1u"nm", 0u"nm", 0u"nm"], [0u"nm", 1u"nm", 0u"nm"], [0u"nm", 0u"nm", 1u"nm"] + ) + @test_throws ArgumentError 4 / lattice + @test_throws ArgumentError 4.0 ./ lattice +end + +@testset "Test broadcasting for reciprocal lattices" begin + a, b, c = 4, 3, 5 + lattice = Lattice([a, -b, 0] / 2, [a, b, 0] / 2, [0, 0, c]) + @test reciprocal(lattice .* 4) == reciprocal(lattice) ./ 4 + @test_throws ArgumentError 4 / reciprocal(lattice) + @test_throws ArgumentError 4.0 ./ reciprocal(lattice) +end