From 46ffc870470ab7b33f82a1303a4d4c5ac35963ac Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Fri, 22 Jul 2016 17:35:57 -0400 Subject: [PATCH 1/3] improve #17546 for a[...] .= handling of arrays of arrays and dictionaries of arrays --- base/broadcast.jl | 26 +++++++++++++++++++++++++- src/julia-syntax.scm | 2 +- test/broadcast.jl | 10 ++++++++++ 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index 62e96bc9ee4b6..d087344eb5003 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -5,7 +5,7 @@ module Broadcast using Base.Cartesian using Base: promote_op, promote_eltype, promote_eltype_op, @get!, _msk_end, unsafe_bitgetindex, linearindices, tail, OneTo, to_shape import Base: .+, .-, .*, ./, .\, .//, .==, .<, .!=, .<=, .รท, .%, .<<, .>>, .^ -export broadcast, broadcast!, bitbroadcast +export broadcast, broadcast!, bitbroadcast, dotview export broadcast_getindex, broadcast_setindex! ## Broadcasting utilities ## @@ -440,4 +440,28 @@ for (sigA, sigB) in ((BitArray, BitArray), end end +############################################################ + +# x[...] .= f.(y...) ---> broadcast!(f, dotview(x, ...), y...). +# The dotview function defaults to view, but we override it in +# a few cases to get the expected in-place behavior without affecting +# explicit calls to view. (All of this can go away if slices +# are changed to generate views by default.) + +dotview(args...) = view(args...) +# avoid splatting penalty in common cases: +for nargs = 0:5 + args = Symbol[Symbol("x",i) for i = 1:nargs] + eval(Expr(:(=), Expr(:call, :dotview, args...), Expr(:call, :view, args...))) +end + +# for a[i...] .= ... where a is an array-of-arrays, just pass a[i...] directly +# to broadcast! +dotview{T<:AbstractArray,N,I<:Integer}(a::AbstractArray{T,N}, i::Vararg{I,N}) = + a[i...] + +# dict[k] .= ... should work if dict[k] is an array +dotview(a::Associative, k) = a[k] +dotview(a::Associative, k1, k2, ks...) = a[tuple(k1,k2,ks...)] + end # module diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index e98029ebe1487..a8e3e704d114b 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -1549,7 +1549,7 @@ (let* ((ex (partially-expand-ref expr)) (stmts (butlast (cdr ex))) (refex (last (cdr ex))) - (nuref `(call (top view) ,(caddr refex) ,@(cdddr refex)))) + (nuref `(call (top dotview) ,(caddr refex) ,@(cdddr refex)))) `(block ,@stmts ,nuref)) expr)) diff --git a/test/broadcast.jl b/test/broadcast.jl index 52aeacc859041..70b2b39beb3fc 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -274,6 +274,16 @@ let x = [1:4;], y = x x[2:end] .= 1:3 @test y === x == [0,1,2,3] end +let a = [[4, 5], [6, 7]] + a[1] .= 3 + @test a == [[3, 3], [6, 7]] +end +let d = Dict(:foo => [1,3,7], (3,4) => [5,9]) + d[:foo] .+= 2 + @test d[:foo] == [3,5,9] + d[3,4] .-= 1 + @test d[3,4] == [4,8] +end # PR 16988 @test Base.promote_op(+, Bool) === Int From df4a58da2c29b39eab2e2fdfb762bdab42708f9c Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Tue, 26 Jul 2016 20:59:00 -0400 Subject: [PATCH 2/3] default dotview to getindex, using view only for AbstractArray --- base/broadcast.jl | 19 +++++++------------ doc/manual/functions.rst | 2 +- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index d087344eb5003..651103c216034 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -443,25 +443,20 @@ end ############################################################ # x[...] .= f.(y...) ---> broadcast!(f, dotview(x, ...), y...). -# The dotview function defaults to view, but we override it in +# The dotview function defaults to getindex, but we override it in # a few cases to get the expected in-place behavior without affecting # explicit calls to view. (All of this can go away if slices # are changed to generate views by default.) -dotview(args...) = view(args...) +dotview(args...) = getindex(args...) +dotview(A::AbstractArray, args...) = view(A, args...) # avoid splatting penalty in common cases: for nargs = 0:5 args = Symbol[Symbol("x",i) for i = 1:nargs] - eval(Expr(:(=), Expr(:call, :dotview, args...), Expr(:call, :view, args...))) + eval(Expr(:(=), Expr(:call, :dotview, args...), + Expr(:call, :getindex, args...))) + eval(Expr(:(=), Expr(:call, :dotview, :(A::AbstractArray), args...), + Expr(:call, :view, :A, args...))) end -# for a[i...] .= ... where a is an array-of-arrays, just pass a[i...] directly -# to broadcast! -dotview{T<:AbstractArray,N,I<:Integer}(a::AbstractArray{T,N}, i::Vararg{I,N}) = - a[i...] - -# dict[k] .= ... should work if dict[k] is an array -dotview(a::Associative, k) = a[k] -dotview(a::Associative, k1, k2, ks...) = a[tuple(k1,k2,ks...)] - end # module diff --git a/doc/manual/functions.rst b/doc/manual/functions.rst index 6f43aec2fecad..b950a3a942b67 100644 --- a/doc/manual/functions.rst +++ b/doc/manual/functions.rst @@ -660,7 +660,7 @@ calls do not allocate new arrays over and over again for the results except that, as above, the ``broadcast!`` loop is fused with any nested "dot" calls. For example, ``X .= sin.(Y)`` is equivalent to ``broadcast!(sin, X, Y)``, overwriting ``X`` with ``sin.(Y)`` in-place. -If the left-hand side is a ``getindex`` expression, e.g. +If the left-hand side is an array-indexing expression, e.g. ``X[2:end] .= sin.(Y)``, then it translates to ``broadcast!`` on a ``view``, e.g. ``broadcast!(sin, view(X, 2:endof(X)), Y)``, so that the left-hand side is updated in-place. From 0eb3e3b83d58cc630b13ac863f2f4391bee0e4e1 Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Tue, 26 Jul 2016 21:53:20 -0400 Subject: [PATCH 3/3] re-fix array-of-arrays case --- base/broadcast.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/base/broadcast.jl b/base/broadcast.jl index 651103c216034..cddf772de5e99 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -450,6 +450,7 @@ end dotview(args...) = getindex(args...) dotview(A::AbstractArray, args...) = view(A, args...) +dotview{T<:AbstractArray}(A::AbstractArray{T}, args...) = getindex(A, args...) # avoid splatting penalty in common cases: for nargs = 0:5 args = Symbol[Symbol("x",i) for i = 1:nargs]