From d9cb84a36e00280b71320159b2565f5cd0785cf3 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Sun, 28 Apr 2019 21:11:51 -0700 Subject: [PATCH 1/2] Move macro definition to lazymacro.jl --- src/LazyArrays.jl | 3 ++- src/lazybroadcasting.jl | 46 ----------------------------------------- src/lazymacro.jl | 43 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 47 deletions(-) create mode 100644 src/lazymacro.jl diff --git a/src/LazyArrays.jl b/src/LazyArrays.jl index 78f8f55d..1b77d67f 100644 --- a/src/LazyArrays.jl +++ b/src/LazyArrays.jl @@ -48,7 +48,7 @@ end export Mul, MulArray, MulVector, MulMatrix, InvMatrix, PInvMatrix, Hcat, Vcat, Kron, BroadcastArray, cache, Ldiv, Inv, PInv, Diff, Cumsum, - applied, materialize, ApplyArray, apply, ⋆ + applied, materialize, ApplyArray, apply, ⋆, @~ include("memorylayout.jl") include("cache.jl") @@ -58,5 +58,6 @@ include("lazyconcat.jl") include("linalg/linalg.jl") include("lazysetoperations.jl") include("lazyoperations.jl") +include("lazymacro.jl") end # module diff --git a/src/lazybroadcasting.jl b/src/lazybroadcasting.jl index d3186ce6..befd9e3c 100644 --- a/src/lazybroadcasting.jl +++ b/src/lazybroadcasting.jl @@ -61,52 +61,6 @@ function Base._prod(f, A::BroadcastArray, ::Colon) out end -# Macros for lazy broadcasting, #21 WIP -# based on @dawbarton https://discourse.julialang.org/t/19641/20 -# and @tkf https://github.com/JuliaLang/julia/issues/19198#issuecomment-457967851 -# and @chethega https://github.com/JuliaLang/julia/pull/30939 - -export @~ - -lazy(::Any) = throw(ArgumentError("function `lazy` exists only for its effect on broadcasting, see the macro @~")) -struct LazyCast{T} - value::T -end -Broadcast.broadcasted(::typeof(lazy), x) = LazyCast(x) -Broadcast.materialize(x::LazyCast) = BroadcastArray(x.value) - -""" - @~ expr - -Macro for creating lazy `BroadcastArray`s. -Expects a broadcasting expression, possibly created by the `@.` macro: -``` -julia> @~ A .+ B ./ 2 - -julia> @~ @. A + B / 2 -``` -""" -macro ~(ex) - checkex(ex) - esc( :( $lazy.($ex) ) ) -end - -using MacroTools - -function checkex(ex) - if @capture(ex, (arg__,) = val_ ) - if arg[2]==:dims - throw(ArgumentError("@~ is capturing keyword arguments, try with `; dims = $val` instead of a comma")) - else - throw(ArgumentError("@~ is probably capturing capturing keyword arguments, try with ; or brackets")) - end - end - if @capture(ex, (arg_,rest__) ) - throw(ArgumentError("@~ is capturing more than one expression, try $name($arg) with brackets")) - end - ex -end - BroadcastStyle(::Type{<:BroadcastArray{<:Any,N}}) where N = LazyArrayStyle{N}() BroadcastStyle(L::LazyArrayStyle{N}, ::StaticArrayStyle{N}) where N = L diff --git a/src/lazymacro.jl b/src/lazymacro.jl new file mode 100644 index 00000000..4211ea79 --- /dev/null +++ b/src/lazymacro.jl @@ -0,0 +1,43 @@ +# Macros for lazy broadcasting, +# based on @dawbarton https://discourse.julialang.org/t/19641/20 +# and @tkf https://github.com/JuliaLang/julia/issues/19198#issuecomment-457967851 +# and @chethega https://github.com/JuliaLang/julia/pull/30939 + +using MacroTools + +lazy(::Any) = throw(ArgumentError("function `lazy` exists only for its effect on broadcasting, see the macro @~")) +struct LazyCast{T} + value::T +end +Broadcast.broadcasted(::typeof(lazy), x) = LazyCast(x) +Broadcast.materialize(x::LazyCast) = BroadcastArray(x.value) + +""" + @~ expr + +Macro for creating lazy `BroadcastArray`s. +Expects a broadcasting expression, possibly created by the `@.` macro: +``` +julia> @~ A .+ B ./ 2 + +julia> @~ @. A + B / 2 +``` +""" +macro ~(ex) + checkex(ex) + esc( :( $lazy.($ex) ) ) +end + +function checkex(ex) + if @capture(ex, (arg__,) = val_ ) + if arg[2]==:dims + throw(ArgumentError("@~ is capturing keyword arguments, try with `; dims = $val` instead of a comma")) + else + throw(ArgumentError("@~ is probably capturing capturing keyword arguments, try with ; or brackets")) + end + end + if @capture(ex, (arg_,rest__) ) + throw(ArgumentError("@~ is capturing more than one expression, try $name($arg) with brackets")) + end + ex +end From 37e8474fa2aee777e576f2d222ee677ca6775204 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Sun, 28 Apr 2019 21:14:54 -0700 Subject: [PATCH 2/2] Support creating Applied in @~ --- README.md | 14 +++++-- src/LazyArrays.jl | 2 +- src/lazyapplying.jl | 11 +++++- src/lazybroadcasting.jl | 4 +- src/lazymacro.jl | 86 ++++++++++++++++++++++++++++++++++------- test/macrotests.jl | 60 ++++++++++++++++++++++++++++ test/runtests.jl | 15 +------ 7 files changed, 158 insertions(+), 34 deletions(-) create mode 100644 test/macrotests.jl diff --git a/README.md b/README.md index c3cf1756..50bb6520 100644 --- a/README.md +++ b/README.md @@ -82,15 +82,15 @@ julia> B == A .+ 2 true ``` Such arrays can also be created using the macro `@~` which acts on ordinary -broadcasting expressions: +broadcasting expressions combined with `LazyArray`: ```julia julia> C = rand(1000)'; -julia> D = @~ exp.(C) +julia> D = LazyArray(@~ exp.(C)) -julia> E = @~ @. 2 + log(C) +julia> E = LazyArray(@~ @. 2 + log(C)) -julia> @btime sum(@~ C .* C'; dims=1) # without `@~`, 1.438 ms (5 allocations: 7.64 MiB) +julia> @btime sum(LazyArray(@~ C .* C'); dims=1) # without `@~`, 1.438 ms (5 allocations: 7.64 MiB) 74.425 μs (7 allocations: 8.08 KiB) ``` @@ -139,6 +139,12 @@ julia> @btime 2*(A*b) + 3c; # does not call gemv! 241.659 ns (4 allocations: 512 bytes) ``` +Using `@~` macro, above expression using `Mul` can also be written as + +```julia +d .= @~ 2.0 .* (A * b) .+ 3.0 .* c +``` + ## Inverses We also have lazy inverses `PInv(A)`, designed to work alongside `Mul` to diff --git a/src/LazyArrays.jl b/src/LazyArrays.jl index 1b77d67f..edb2f50d 100644 --- a/src/LazyArrays.jl +++ b/src/LazyArrays.jl @@ -48,7 +48,7 @@ end export Mul, MulArray, MulVector, MulMatrix, InvMatrix, PInvMatrix, Hcat, Vcat, Kron, BroadcastArray, cache, Ldiv, Inv, PInv, Diff, Cumsum, - applied, materialize, ApplyArray, apply, ⋆, @~ + applied, materialize, ApplyArray, apply, ⋆, @~, LazyArray include("memorylayout.jl") include("cache.jl") diff --git a/src/lazyapplying.jl b/src/lazyapplying.jl index 8182687c..ba77be7e 100644 --- a/src/lazyapplying.jl +++ b/src/lazyapplying.jl @@ -77,14 +77,23 @@ eltype(A::Applied{<:MatrixFunctionStyle}) = eltype(first(A.args)) getindex(A::Applied{<:MatrixFunctionStyle}, k::Int, j::Int) = materialize(A)[k,j] +""" + LazyArray(x::Applied) :: ApplyArray + LazyArray(x::Broadcasted) :: BroadcastArray -struct ApplyArray{T, N, App<:Applied} <: AbstractArray{T,N} +Wrap a lazy object that wraps a computation producing an array to an +array. +""" +abstract type LazyArray{T,N} <: AbstractArray{T,N} end + +struct ApplyArray{T, N, App<:Applied} <: LazyArray{T,N} applied::App end const ApplyVector{T, App<:Applied} = ApplyArray{T, 1, App} const ApplyMatrix{T, App<:Applied} = ApplyArray{T, 2, App} +LazyArray(A::Applied) = ApplyArray(A) ApplyArray{T,N}(M::App) where {T,N,App<:Applied} = ApplyArray{T,N,App}(M) ApplyArray{T}(M::Applied) where {T} = ApplyArray{T,ndims(M)}(M) diff --git a/src/lazybroadcasting.jl b/src/lazybroadcasting.jl index befd9e3c..bcea2a1c 100644 --- a/src/lazybroadcasting.jl +++ b/src/lazybroadcasting.jl @@ -3,10 +3,12 @@ LazyArrayStyle(::Val{N}) where N = LazyArrayStyle{N}() LazyArrayStyle{M}(::Val{N}) where {N,M} = LazyArrayStyle{N}() -struct BroadcastArray{T, N, BRD<:Broadcasted} <: AbstractArray{T, N} +struct BroadcastArray{T, N, BRD<:Broadcasted} <: LazyArray{T, N} broadcasted::BRD end +LazyArray(bc::Broadcasted) = BroadcastArray(bc) + BroadcastArray{T,N}(bc::BRD) where {T,N,BRD<:Broadcasted} = BroadcastArray{T,N,BRD}(bc) BroadcastArray{T}(bc::Broadcasted{<:Union{Nothing,BroadcastStyle},<:Tuple{Vararg{Any,N}},<:Any,<:Tuple}) where {T,N} = BroadcastArray{T,N}(bc) diff --git a/src/lazymacro.jl b/src/lazymacro.jl index 4211ea79..e18e3d6e 100644 --- a/src/lazymacro.jl +++ b/src/lazymacro.jl @@ -10,22 +10,59 @@ struct LazyCast{T} value::T end Broadcast.broadcasted(::typeof(lazy), x) = LazyCast(x) -Broadcast.materialize(x::LazyCast) = BroadcastArray(x.value) +Broadcast.materialize(x::LazyCast) = x.value -""" - @~ expr -Macro for creating lazy `BroadcastArray`s. -Expects a broadcasting expression, possibly created by the `@.` macro: -``` -julia> @~ A .+ B ./ 2 +is_call(ex::Expr) = + ex.head == :call && !startswith(String(ex.args[1]), ".") -julia> @~ @. A + B / 2 -``` -""" -macro ~(ex) - checkex(ex) - esc( :( $lazy.($ex) ) ) +is_dotcall(ex::Expr) = + ex.head == :. || (ex.head == :call && startswith(String(ex.args[1]), ".")) +# e.g., `f.(x, y, z)` or `x .+ y .+ z` + +lazy_expr(x) = x +function lazy_expr(ex::Expr) + if is_dotcall(ex) + return bc_expr(ex) + elseif is_call(ex) + return app_expr(ex) + else + # TODO: Maybe better to support `a ? b : c` etc.? But how? + return ex + end +end + +function bc_expr(ex::Expr) + @assert is_dotcall(ex) + return :($(Broadcast.instantiate)($lazy.($(bc_expr_impl(ex))))) +end + +bc_expr_impl(x) = x +function bc_expr_impl(ex::Expr) + # walk down chain of dot calls + if is_dotcall(ex) + return Expr(ex.head, + lazy_expr(ex.args[1]), # function name (`f`, `.+`, etc.) + bc_expr_impl.(ex.args[2:end])...) # arguments + else + return lazy_expr(ex) + end +end + +function app_expr(ex::Expr) + @assert is_call(ex) + # instantiate? + return app_expr_impl(ex) +end + +app_expr_impl(x) = x +function app_expr_impl(ex::Expr) + # walk down chain of calls and lazy-ify them + if is_call(ex) + return :($applied($(app_expr_impl.(ex.args)...))) + else + return lazy_expr(ex) + end end function checkex(ex) @@ -41,3 +78,26 @@ function checkex(ex) end ex end + +""" + @~ expr + +Macro for creating a `Broadcasted` or `Applied` object. Regular calls +like `f(args...)` inside `expr` are replaced with `applied(f, args...)`. +Dotted-calls like `f(args...)` inside `expr` are replaced with +`broadcasted.(f, args...)`. Use `LazyArray(@~ expr)` if you need an +array-based interface. + +``` +julia> @~ A .+ B ./ 2 + +julia> @~ @. A + B / 2 + +julia> @~ A * B + C +``` +""" +macro ~(ex) + checkex(ex) + # Expanding macro here to support, e.g., `@.` + esc(lazy_expr(macroexpand(__module__, ex))) +end diff --git a/test/macrotests.jl b/test/macrotests.jl new file mode 100644 index 00000000..56151d16 --- /dev/null +++ b/test/macrotests.jl @@ -0,0 +1,60 @@ +module MacroTests + +using Test, LazyArrays, MacroTools + +A = randn(6, 6) +B = BroadcastArray(+, A, 2) +C = randn(6, 6) + +expressions_block = quote + exp.(A) + @. exp(A) + # exp(A) + A .+ 2 + @. A + 2 + A + B + @. A + B + A * B + C + # A * B .+ C + A * (B + C) + # A * (B .+ C) + # 2 .* (A * B) .+ 3 .* C +end +testparams = [ + ("$(rmlines(ex))", ex) for ex in expressions_block.args if ex isa Expr +] + +@testset "@~" begin + @testset "$label" for (label, ex) in testparams + desired = @eval $ex + lazy = @eval @~ $ex + @test lazy isa Union{Broadcast.Broadcasted, LazyArrays.Applied} + + @testset ".= @~ $label" begin + actual = zero(desired) + actual .= lazy + @test actual == desired + end + + @testset "materialize(@~ $label)" begin + @test materialize(lazy) == desired + end + + @testset "LazyArray(@~ $label)" begin + actual = LazyArray(lazy) :: LazyArray + @test actual == desired + end + + @testset "materialize(LazyArray(@~ $label))" begin + @test materialize(LazyArray(lazy)) == desired + end + + @testset ".= LazyArray(@~ $label)" begin + actual = zero(desired) + actual .= LazyArray(lazy) + @test actual == desired + end + end +end + +end # module diff --git a/test/runtests.jl b/test/runtests.jl index 9a80ab3e..b3ff0537 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,7 @@ include("multests.jl") include("ldivtests.jl") include("addtests.jl") include("setoptests.jl") +include("macrotests.jl") @testset "concat" begin @@ -225,27 +226,13 @@ end @testset "BroadcastArray" begin A = randn(6,6) - B = BroadcastArray(exp, A) - B′ = @~ exp.(A) - B′′ = @~ @. exp(A) @test Matrix(B) == exp.(A) - @test Matrix(B′) == exp.(A) - @test Matrix(B′′) == exp.(A) C = BroadcastArray(+, A, 2) - C′ = @~ A .+ 2 - C′′ = @~ @. A + 2 @test C == A .+ 2 - @test C′ == A .+ 2 - @test C′′ == A .+ 2 - D = BroadcastArray(+, A, C) - D′ = @~ A + C - D′′ = @~ @. A + C @test D == A + C - @test D′ == A + C - @test D′′ == A + C @test sum(B) ≈ sum(exp, A) @test sum(C) ≈ sum(A .+ 2)