Skip to content

Commit

Permalink
Merge pull request #280 from JuliaLang/yyc/vectorize
Browse files Browse the repository at this point in the history
Compat for at-vectorize_(1|2)arg deprecation
  • Loading branch information
yuyichao authored Sep 8, 2016
2 parents 24793ed + fbf0173 commit e8e69d1
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 9 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,17 @@ Currently, the `@compat` macro supports the following syntaxes:

* `@__DIR__` has been added [#18380](https://github.com/JuliaLang/julia/pull/18380)

* `@vectorize_1arg` and `@vectorize_2arg` are deprecated on Julia 0.6 in favor
of the broadcast syntax [#17302](https://github.com/JuliaLang/julia/pull/17302).
`Compat.@dep_vectorize_1arg` and `Compat.@dep_vectorize_2arg` are provided
so that packages can still provide the deprecated definitions
without causing a depwarn in the package itself before all the users
are upgraded.

Packages are expected to use this until all users of the deprecated
vectorized function have migrated. These macros will be dropped when
the support for `0.6` is dropped from `Compat`.

## Other changes

* `remotecall`, `remotecall_fetch`, `remotecall_wait`, and `remote_do` have the function to be executed remotely as the first argument in Julia 0.5. Loading `Compat` defines the same methods in older versions of Julia. [#13338](https://github.com/JuliaLang/julia/pull/13338)
Expand Down
195 changes: 187 additions & 8 deletions src/Compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ using Base.Meta
"""Get just the function part of a function declaration."""
withincurly(ex) = isexpr(ex, :curly) ? ex.args[1] : ex

if VERSION < v"0.4.0-dev+2254"
immutable Val{T} end
export Val
end

if VERSION < v"0.4.0-dev+1419"
export UInt, UInt8, UInt16, UInt32, UInt64, UInt128
const UInt = Uint
Expand Down Expand Up @@ -431,6 +436,143 @@ end

istopsymbol(ex, mod, sym) = ex in (sym, Expr(:(.), mod, Expr(:quote, sym)))

if VERSION < v"0.5.0-dev+4002"
typealias Array0D{T} Array{T,0}
@inline broadcast_getindex(arg, idx) = arg[(idx - 1) % length(arg) + 1]
# Optimize for single element
@inline broadcast_getindex(arg::Number, idx) = arg
@inline broadcast_getindex(arg::Array0D, idx) = arg[1]

# If we know from syntax level that we don't need wrapping
@inline broadcast_getindex_naive(arg, idx) = arg[idx]
@inline broadcast_getindex_naive(arg::Number, idx) = arg
@inline broadcast_getindex_naive(arg::Array0D, idx) = arg[1]

# For vararg support
@inline getindex_vararg(idx) = ()
@inline getindex_vararg(idx, arg1) = (broadcast_getindex(arg1, idx),)
@inline getindex_vararg(idx, arg1, arg2) =
(broadcast_getindex(arg1, idx), broadcast_getindex(arg2, idx))
@inline getindex_vararg(idx, arg1, arg2, arg3, args...) =
(broadcast_getindex(arg1, idx), broadcast_getindex(arg2, idx),
broadcast_getindex(arg3, idx), getindex_vararg(idx, args...)...)

@inline getindex_naive_vararg(idx) = ()
@inline getindex_naive_vararg(idx, arg1) =
(broadcast_getindex_naive(arg1, idx),)
@inline getindex_naive_vararg(idx, arg1, arg2) =
(broadcast_getindex_naive(arg1, idx),
broadcast_getindex_naive(arg2, idx))
@inline getindex_naive_vararg(idx, arg1, arg2, arg3, args...) =
(broadcast_getindex_naive(arg1, idx),
broadcast_getindex_naive(arg2, idx),
broadcast_getindex_naive(arg3, idx),
getindex_naive_vararg(idx, args...)...)

# Decide if the result should be scalar or array
# `size() === ()` is not good enough since broadcasting on
# a scalar should return a scalar where as broadcasting on a 0-dim
# array should return a 0-dim array.
@inline should_return_array(::Val{true}, args...) = Val{true}()
@inline should_return_array(::Val{false}) = Val{false}()
@inline should_return_array(::Val{false}, arg1) = Val{false}()
@inline should_return_array(::Val{false}, arg1::AbstractArray) = Val{true}()
@inline should_return_array(::Val{false}, arg1::AbstractArray,
arg2::AbstractArray) = Val{true}()
@inline should_return_array(::Val{false}, arg1,
arg2::AbstractArray) = Val{true}()
@inline should_return_array(::Val{false}, arg1::AbstractArray,
arg2) = Val{true}()
@inline should_return_array(::Val{false}, arg1, arg2) = Val{false}()
@inline should_return_array(::Val{false}, arg1, arg2, args...) =
should_return_array(should_return_array(Val{false}(), arg1, arg2),
args...)

@inline broadcast_return(res1d, shp, ret_ary::Val{false}) = res1d[1]
@inline broadcast_return(res1d, shp, ret_ary::Val{true}) = reshape(res1d, shp)

@inline need_full_getindex(shp) = false
@inline need_full_getindex(shp, arg1::Number) = false
@inline need_full_getindex(shp, arg1::Array0D) = false
@inline need_full_getindex(shp, arg1) = shp != size(arg1)
@inline need_full_getindex(shp, arg1, arg2) =
need_full_getindex(shp, arg1) || need_full_getindex(shp, arg2)
@inline need_full_getindex(shp, arg1, arg2, arg3, args...) =
need_full_getindex(shp, arg1, arg2) || need_full_getindex(shp, arg3) ||
need_full_getindex(shp, args...)

function rewrite_broadcast(f, args)
nargs = length(args)
# This actually allows multiple splatting...,
# which is now allowed on master.
# The previous version that simply calls broadcast so removing that
# will be breaking. Oh, well....
is_vararg = Bool[isexpr(args[i], :...) for i in 1:nargs]
names = [gensym("broadcast") for i in 1:nargs]
new_args = [is_vararg[i] ? Expr(:..., names[i]) : names[i]
for i in 1:nargs]
# Optimize for common case where we know the index doesn't need
# any wrapping
naive_getidx_for = function (i, idxvar)
if is_vararg[i]
Expr(:..., :($Compat.getindex_naive_vararg($idxvar,
$(names[i])...)))
else
:($Compat.broadcast_getindex_naive($(names[i]), $idxvar))
end
end
always_naive = nargs == 1 && !is_vararg[1]
getidx_for = if always_naive
naive_getidx_for
else
function (i, idxvar)
if is_vararg[i]
Expr(:..., :($Compat.getindex_vararg($idxvar,
$(names[i])...)))
else
:($Compat.broadcast_getindex($(names[i]), $idxvar))
end
end
end
@gensym allidx
@gensym newshape
@gensym res1d
@gensym idx
@gensym ret_ary

res1d_expr = quote
$res1d = [$f($([naive_getidx_for(i, idx) for i in 1:nargs]...))
for $idx in $allidx]
end
if !always_naive
res1d_expr = quote
if $Compat.need_full_getindex($newshape, $(new_args...))
$res1d = [$f($([getidx_for(i, idx) for i in 1:nargs]...))
for $idx in $allidx]
else
$res1d_expr
end
end
end

return quote
# The `local` makes sure type inference can infer the type even
# in global scope as long as the input is type stable
local $(names...)
$([:($(names[i]) = $(is_vararg[i] ? args[i].args[1] : args[i]))
for i in 1:nargs]...)
local $newshape = $(Base.Broadcast).broadcast_shape($(new_args...))
# `eachindex` is not generic enough
local $allidx = 1:prod($newshape)
local $ret_ary = $Compat.should_return_array(Val{false}(),
$(new_args...))
local $res1d
$res1d_expr
$Compat.broadcast_return($res1d, $newshape, $ret_ary)
end
end
end

function _compat(ex::Expr)
if ex.head === :call
f = ex.args[1]
Expand Down Expand Up @@ -549,11 +691,12 @@ function _compat(ex::Expr)
return Expr(ex.head, _compat(ex.args[1]), QuoteNode(ex.args[2].args[1].args[1]))
elseif isexpr(ex.args[2], :tuple)
# f.(arg1, arg2...) -> broadcast(f, arg1, arg2...)
return Expr(:call, :broadcast, _compat(ex.args[1]), map(_compat, ex.args[2].args)...)
return rewrite_broadcast(_compat(ex.args[1]),
map(_compat, ex.args[2].args))
elseif !isa(ex.args[2], QuoteNode) &&
!(isexpr(ex.args[2], :quote) && isa(ex.args[2].args[1], Symbol))
# f.(arg) -> broadcast(f, arg)
return Expr(:call, :broadcast, _compat(ex.args[1]), _compat(ex.args[2]))
return rewrite_broadcast(_compat(ex.args[1]), [_compat(ex.args[2])])
end
elseif ex.head === :import
if VERSION < v"0.5.0-dev+4340" && length(ex.args) == 2 && ex.args[1] === :Base && ex.args[2] === :show
Expand Down Expand Up @@ -668,11 +811,6 @@ if VERSION < v"0.4.0-dev+4502"
export keytype, valtype
end

if VERSION < v"0.4.0-dev+2254"
immutable Val{T} end
export Val
end

if VERSION < v"0.4.0-dev+2840"
Base.qr(A, ::Type{Val{true}}; thin::Bool=true) =
Base.qr(A, pivot=true, thin=thin)
Expand Down Expand Up @@ -1464,8 +1602,49 @@ if VERSION < v"0.6.0-dev.374"
end

if VERSION < v"0.6.0-dev.528"
macro __DIR__() Base.source_dir() end
macro __DIR__()
Base.source_dir()
end
export @__DIR__
end

# PR #17302
# Provide a non-deprecated version of `@vectorize_(1|2)arg` macro which defines
# deprecated version of the function so that the depwarns can be fixed without
# breaking users.
# Packages are expected to use this to maintain the old API until all users
# of the deprecated vectorized function have migrated.
# These macros should raise a depwarn when the `0.5` support is dropped from
# `Compat` and be dropped when the support for `0.6` is dropped from `Compat`.
# Modified based on the version copied from 0.6 Base.
macro dep_vectorize_1arg(S, f)
S = esc(S)
f = esc(f)
T = esc(:T)
x = esc(:x)
AbsArr = esc(:AbstractArray)
## Depwarn to be enabled when 0.5 support is dropped.
# depwarn("Implicit vectorized function is deprecated in favor of compact broadcast syntax.",
# Symbol("@dep_vectorize_1arg"))
:(@deprecate $f{$T<:$S}($x::$AbsArr{$T}) @compat($f.($x)))
end

macro dep_vectorize_2arg(S, f)
S = esc(S)
f = esc(f)
T1 = esc(:T1)
T2 = esc(:T2)
x = esc(:x)
y = esc(:y)
AbsArr = esc(:AbstractArray)
## Depwarn to be enabled when 0.5 support is dropped.
# depwarn("Implicit vectorized function is deprecated in favor of compact broadcast syntax.",
# Symbol("@dep_vectorize_2arg"))
quote
@deprecate $f{$T1<:$S}($x::$S, $y::$AbsArr{$T1}) @compat($f.($x,$y))
@deprecate $f{$T1<:$S}($x::$AbsArr{$T1}, $y::$S) @compat($f.($x,$y))
@deprecate $f{$T1<:$S,$T2<:$S}($x::$AbsArr{$T1}, $y::$AbsArr{$T2}) @compat($f.($x,$y))
end
end

end # module
49 changes: 48 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1115,7 +1115,7 @@ for (Fun, func) in [(:AndFun, :&),
(:DotMulFun, :.*),
(:RDivFun, :/),
(:DotRDivFun, :./),
(:LDivFun, :\),
(:LDivFun, :\ ),
(:IDivFun, :div),
(:DotIDivFun, @compat(Symbol(""))),
(:ModFun, :mod),
Expand Down Expand Up @@ -1182,6 +1182,36 @@ let x = rand(3), y = rand(3)
@test @compat(sin.(cos.(x))) == map(x -> sin(cos(x)), x)
@test @compat(atan2.(sin.(y),x)) == broadcast(atan2,map(sin,y),x)
end
let x0 = Array(Float64), v, v0
x0[1] = rand()
v0 = @compat sin.(x0)
@test isa(v0, Array{Float64,0})
v = @compat sin.(x0[1])
@test isa(v, Float64)
@test v == v0[1] == sin(x0[1])
end
let x = rand(2, 2), v
v = @compat sin.(x)
@test isa(v, Array{Float64,2})
@test v == [sin(x[1, 1]) sin(x[1, 2]);
sin(x[2, 1]) sin(x[2, 2])]
end
let x1 = [1, 2, 3], x2 = ([3, 4, 5],), v
v = @compat atan2.(x1, x2...)
@test isa(v, Vector{Float64})
@test v == [atan2(1, 3), atan2(2, 4), atan2(3, 5)]
end
# Do the following in global scope to make sure inference is able to handle it
@test @compat(sin.([1, 2])) == [sin(1), sin(2)]
@test isa(@compat(sin.([1, 2])), Vector{Float64})
@test @compat(atan2.(1, [2, 3])) == [atan2(1, 2), atan2(1, 3)]
@test isa(@compat(atan2.(1, [2, 3])), Vector{Float64})
@test @compat(atan2.([1, 2], [2, 3])) == [atan2(1, 2), atan2(2, 3)]
@test isa(@compat(atan2.([1, 2], [2, 3])), Vector{Float64})
# And make sure it is actually inferrable
f15032(a) = @compat sin.(a)
@inferred f15032([1, 2, 3])
@inferred f15032([1.0, 2.0, 3.0])

if VERSION v"0.4.0-dev+3732"
@test Symbol("foo") === :foo
Expand Down Expand Up @@ -1383,3 +1413,20 @@ let filename = tempname()
end

@test @__DIR__() == dirname(@__FILE__)

# PR #17302
# To be removed when 0.5/0.6 support is dropped.
f17302(a::Number) = a
f17302(a::Number, b::Number) = a + b
Compat.@dep_vectorize_1arg Real f17302
Compat.@dep_vectorize_2arg Real f17302
@test_throws MethodError f17302([1im])
@test_throws MethodError f17302([1im], [1im])
mktemp() do fname, f
redirect_stderr(f) do
@test f17302([1.0]) == [1.0]
@test f17302(1.0, [1]) == [2.0]
@test f17302([1.0], 1) == [2.0]
@test f17302([1.0], [1]) == [2.0]
end
end

0 comments on commit e8e69d1

Please sign in to comment.