Skip to content

Commit

Permalink
Refactor common parts
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Aug 26, 2024
1 parent 5a53ef2 commit 712a201
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions src/fillbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,33 +118,36 @@ _isfill(f::Number) = true
_isfill(f::Ref) = true
_isfill(::Any) = false

_broadcast_maybecopy(bc::Broadcast.Broadcasted{<:AbstractFillStyle}) = copy(bc)
_broadcast_maybecopy(bc::Broadcast.Broadcasted) = Broadcast.broadcasted(bc.f, map(_broadcast_maybecopy, bc.args)...)
_broadcast_maybecopy(x) = x
function _copy_fill(bc)
v = _getindex_value(bc)
if _iszeros(bc)
return Zeros(typeof(v), axes(bc))
elseif _isones(bc)
return Ones(typeof(v), axes(bc))
end
return Fill(v, axes(bc))
end

# recursively copy the purely fill components
function _preprocess_fill(bc::Broadcast.Broadcasted{<:AbstractFillStyle})
_isfill(bc) ? _copy_fill(bc) : Broadcast.broadcasted(bc.f, map(_preprocess_fill, bc.args)...)
end
_preprocess_fill(bc::Broadcast.Broadcasted) = Broadcast.broadcasted(bc.f, map(_preprocess_fill, bc.args)...)

Check warning on line 135 in src/fillbroadcast.jl

View check run for this annotation

Codecov / codecov/patch

src/fillbroadcast.jl#L135

Added line #L135 was not covered by tests
_preprocess_fill(x) = x

function _fallback_copy(bc)
# treat the fill components
bc2 = Base.broadcasted(bc.f, map(_broadcast_maybecopy, bc.args)...)
# copy the purely fill components
bc2 = Base.broadcasted(bc.f, map(_preprocess_fill, bc.args)...)
# fallback style
S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{ndims(bc)}}
copy(convert(S, bc2))
end

function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle})
if _iszeros(bc)
return Zeros(typeof(_getindex_value(bc)), axes(bc))
elseif _isones(bc)
return Ones(typeof(_getindex_value(bc)), axes(bc))
elseif _isfill(bc)
return Fill(_getindex_value(bc), axes(bc))
else
_fallback_copy(bc)
end
_isfill(bc) ? _copy_fill(bc) : _fallback_copy(bc)
end
# make the zero-dimensional case consistent with Base
function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{0}})
_fallback_copy(bc)
end
Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{0}}) = _fallback_copy(bc)

# some cases that preserve 0d
function broadcast_preserving_0d(f, As...)
Expand Down

0 comments on commit 712a201

Please sign in to comment.