Skip to content

Commit

Permalink
Merge pull request #31 from MineralsCloud:Broadcast
Browse files Browse the repository at this point in the history
Rewrite broadcasting interface for `AbstractLattice`s
  • Loading branch information
singularitti authored Dec 13, 2023
2 parents 06e3933 + aacf9f3 commit fc05336
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 7 deletions.
7 changes: 7 additions & 0 deletions ext/UnitfulExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,11 @@ uconvert(u::Units, lattice::ReciprocalLattice) =
ustrip(lattice::Lattice) = Lattice(map(ustrip, parent(lattice)))
ustrip(lattice::ReciprocalLattice) = ReciprocalLattice(map(ustrip, parent(lattice)))

# You need this to let the broadcasting work.
Base.:*(lattice::Lattice, unit::Units) = Lattice(parent(lattice) * unit)
Base.:*(unit::Units, lattice::Lattice) = lattice * unit

# You need this to let the broadcasting work.
Base.:/(lattice::Lattice, unit::Units) = Lattice(parent(lattice) / unit)

end
18 changes: 16 additions & 2 deletions src/lattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ Base.:*(x::Number, lattice::Lattice) = lattice * x

# You need this to let the broadcasting work.
Base.:/(lattice::Lattice, x::Number) = Lattice(parent(lattice) / x)
Base.:/(::Number, ::Lattice) =
throw(ArgumentError("you cannot divide a number by a lattice!"))

Base.:+(lattice::Lattice) = lattice
# You need this to let the broadcasting work.
Expand All @@ -176,5 +178,17 @@ Base.convert(::Type{Lattice{T}}, lattice::Lattice{S}) where {S,T} =
Base.ndims(::Type{<:Lattice}) = 2
Base.ndims(::Lattice) = 2

# See https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting
Base.broadcastable(lattice::Lattice) = Ref(lattice)
# See https://github.com/JuliaLang/julia/blob/v1.10.0-rc2/base/broadcast.jl#L741
Base.broadcastable(lattice::Lattice) = lattice

# See https://github.com/JuliaLang/julia/blob/v1.10.0-rc2/base/broadcast.jl#L49
Base.BroadcastStyle(::Type{<:Lattice}) = Broadcast.Style{Lattice}()

# See https://github.com/JuliaLang/julia/blob/v1.10.0-rc2/base/broadcast.jl#L135
Base.BroadcastStyle(::Broadcast.AbstractArrayStyle{0}, b::Broadcast.Style{Lattice}) = b

# See https://github.com/JuliaLang/julia/blob/v1.10.0-rc2/base/broadcast.jl#L1114-L1119
Base.copy(bc::Broadcast.Broadcasted{Broadcast.Style{Lattice}}) = Lattice(x for x in bc) # For uniary and binary functions

Base.broadcasted(::typeof(/), ::Number, ::Lattice) =
throw(ArgumentError("you cannot divide a number by a lattice!"))
27 changes: 22 additions & 5 deletions src/reciprocal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,13 @@ Base.firstindex(::ReciprocalLattice) = 1
Base.lastindex(::ReciprocalLattice) = 9

# You need this to let the broadcasting work.
Base.:*(lattice::ReciprocalLattice, x) = ReciprocalLattice(parent(lattice) * x)
Base.:*(x, lattice::ReciprocalLattice) = lattice * x
Base.:*(lattice::ReciprocalLattice, x::Number) = ReciprocalLattice(parent(lattice) * x)
Base.:*(x::Number, lattice::ReciprocalLattice) = lattice * x

# You need this to let the broadcasting work.
Base.:/(lattice::ReciprocalLattice, x) = ReciprocalLattice(parent(lattice) / x)
Base.:/(lattice::ReciprocalLattice, x::Number) = ReciprocalLattice(parent(lattice) / x)
Base.:/(::Number, ::ReciprocalLattice) =
throw(ArgumentError("you cannot divide a number by a reciprocal lattice!"))

Base.:+(lattice::ReciprocalLattice) = lattice

Expand All @@ -97,5 +99,20 @@ Base.convert(::Type{ReciprocalLattice{T}}, lattice::ReciprocalLattice{S}) where
Base.ndims(::Type{<:ReciprocalLattice}) = 2
Base.ndims(::ReciprocalLattice) = 2

# See https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting
Base.broadcastable(lattice::ReciprocalLattice) = Ref(lattice)
# See https://github.com/JuliaLang/julia/blob/v1.10.0-rc2/base/broadcast.jl#L741
Base.broadcastable(lattice::ReciprocalLattice) = lattice

# See https://github.com/JuliaLang/julia/blob/v1.10.0-rc2/base/broadcast.jl#L49
Base.BroadcastStyle(::Type{<:ReciprocalLattice}) = Broadcast.Style{ReciprocalLattice}()

# See https://github.com/JuliaLang/julia/blob/v1.10.0-rc2/base/broadcast.jl#L135
Base.BroadcastStyle(
::Broadcast.AbstractArrayStyle{0}, b::Broadcast.Style{ReciprocalLattice}
) = b

# See https://github.com/JuliaLang/julia/blob/v1.10.0-rc2/base/broadcast.jl#L1114-L1119
Base.copy(bc::Broadcast.Broadcasted{Broadcast.Style{ReciprocalLattice}}) =
ReciprocalLattice(MMatrix{3,3}(x for x in bc)) # For uniary and binary functions

Base.broadcasted(::typeof(/), ::Number, ::ReciprocalLattice) =
throw(ArgumentError("you cannot divide a number by a reciprocal lattice!"))
28 changes: 28 additions & 0 deletions test/lattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,31 @@ using Unitful, UnitfulAtomic
@test inv(inv(lattice)) == lattice
end
end

@testset "Test broadcasting for lattices" begin
lattice = Lattice([1, 0, 0], [0, 1, 0], [0, 0, 1])
@test lattice .* 4 == 4 .* lattice == Lattice([4, 0, 0], [0, 4, 0], [0, 0, 4])
@test lattice .* 4.0 == 4.0 .* lattice == Lattice([4.0, 0, 0], [0, 4.0, 0], [0, 0, 4.0])
@test lattice .* 4//1 == 4//1 .* lattice == Lattice([4, 0, 0], [0, 4, 0], [0, 0, 4])
@test lattice ./ 4 == Lattice([1//4, 0, 0], [0, 1//4, 0], [0, 0, 1//4])
@test lattice .* u"nm" ==
u"nm" .* lattice ==
Lattice(
[1u"nm", 0u"nm", 0u"nm"], [0u"nm", 1u"nm", 0u"nm"], [0u"nm", 0u"nm", 1u"nm"]
)
@test lattice .* 1u"nm" ==
1u"nm" .* lattice ==
Lattice(
[1u"nm", 0u"nm", 0u"nm"], [0u"nm", 1u"nm", 0u"nm"], [0u"nm", 0u"nm", 1u"nm"]
)
@test_throws ArgumentError 4 / lattice
@test_throws ArgumentError 4.0 ./ lattice
end

@testset "Test broadcasting for reciprocal lattices" begin
a, b, c = 4, 3, 5
lattice = Lattice([a, -b, 0] / 2, [a, b, 0] / 2, [0, 0, c])
@test reciprocal(lattice .* 4) == reciprocal(lattice) ./ 4
@test_throws ArgumentError 4 / reciprocal(lattice)
@test_throws ArgumentError 4.0 ./ reciprocal(lattice)
end

0 comments on commit fc05336

Please sign in to comment.