Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve a[...] .= handling of arrays of arrays and dicts of arrays #17568

Merged
merged 3 commits into from
Jul 27, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module Broadcast
using Base.Cartesian
using Base: promote_op, promote_eltype, promote_eltype_op, @get!, _msk_end, unsafe_bitgetindex, linearindices, tail, OneTo, to_shape
import Base: .+, .-, .*, ./, .\, .//, .==, .<, .!=, .<=, .÷, .%, .<<, .>>, .^
export broadcast, broadcast!, bitbroadcast
export broadcast, broadcast!, bitbroadcast, dotview
export broadcast_getindex, broadcast_setindex!

## Broadcasting utilities ##
Expand Down Expand Up @@ -440,4 +440,24 @@ for (sigA, sigB) in ((BitArray, BitArray),
end
end

############################################################

# x[...] .= f.(y...) ---> broadcast!(f, dotview(x, ...), y...).
# The dotview function defaults to getindex, but we override it in
# a few cases to get the expected in-place behavior without affecting
# explicit calls to view. (All of this can go away if slices
# are changed to generate views by default.)

dotview(args...) = getindex(args...)
dotview(A::AbstractArray, args...) = view(A, args...)
dotview{T<:AbstractArray}(A::AbstractArray{T}, args...) = getindex(A, args...)
Copy link
Member

@andreasnoack andreasnoack Feb 1, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stevengj Why did you decide to default to getindex and not view in this case? @alanedelman showed me this example today

s = zeros(2,2) 
A = fill(s, 4, 4)
A[1:3,1:3] .= [ones(2,2)]

which doesn't set A because of this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, it should be view also. I was thinking that view wasn't necessary, since A[i,j] returns a mutable object.

# avoid splatting penalty in common cases:
for nargs = 0:5
args = Symbol[Symbol("x",i) for i = 1:nargs]
eval(Expr(:(=), Expr(:call, :dotview, args...),
Expr(:call, :getindex, args...)))
eval(Expr(:(=), Expr(:call, :dotview, :(A::AbstractArray), args...),
Expr(:call, :view, :A, args...)))
end

end # module
2 changes: 1 addition & 1 deletion doc/manual/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ calls do not allocate new arrays over and over again for the results
except that, as above, the ``broadcast!`` loop is fused with any nested
"dot" calls. For example, ``X .= sin.(Y)`` is equivalent to
``broadcast!(sin, X, Y)``, overwriting ``X`` with ``sin.(Y)`` in-place.
If the left-hand side is a ``getindex`` expression, e.g.
If the left-hand side is an array-indexing expression, e.g.
``X[2:end] .= sin.(Y)``, then it translates to ``broadcast!`` on a ``view``,
e.g. ``broadcast!(sin, view(X, 2:endof(X)), Y)``, so that the left-hand
side is updated in-place.
Expand Down
2 changes: 1 addition & 1 deletion src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1549,7 +1549,7 @@
(let* ((ex (partially-expand-ref expr))
(stmts (butlast (cdr ex)))
(refex (last (cdr ex)))
(nuref `(call (top view) ,(caddr refex) ,@(cdddr refex))))
(nuref `(call (top dotview) ,(caddr refex) ,@(cdddr refex))))
`(block ,@stmts ,nuref))
expr))

Expand Down
10 changes: 10 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,16 @@ let x = [1:4;], y = x
x[2:end] .= 1:3
@test y === x == [0,1,2,3]
end
let a = [[4, 5], [6, 7]]
a[1] .= 3
@test a == [[3, 3], [6, 7]]
end
let d = Dict(:foo => [1,3,7], (3,4) => [5,9])
d[:foo] .+= 2
@test d[:foo] == [3,5,9]
d[3,4] .-= 1
@test d[3,4] == [4,8]
end

# PR 16988
@test Base.promote_op(+, Bool) === Int
Expand Down