Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix macros not working with broadcasting #6

Merged
merged 1 commit into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/caching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ macro _scache(type, common_cache_dir, ex)
esc_times = escdepth(ex)
ex = unesc_comp(ex)

(typeof(ex) != Expr || ex.head != :call) && (throw(ArgumentError("`@scache[jld]` can " *
!_is_call_expr(ex) && !_is_bc_expr(ex) && (throw(ArgumentError("`@scache[jld]` can " *
"be used only with function calls: passed $(ex)")))

# hyigene
Expand All @@ -125,7 +125,7 @@ macro _scache(type, common_cache_dir, ex)
t = _convert_input(ex, true)

rel_esc(ex) = esc(ex, esc_times+1)
as, ks, vs, res = (_toexpr(t.args), _toexpr(t.kwargs)..., t.res)
as, ks, vs, res, bc = (_toexpr(t.args), _toexpr(t.kwargs)..., t.res, t.broadcast)

# TODO: make a function that interprets `res` to be inserted in `args` or `kwargs`
return quote
Expand All @@ -141,7 +141,8 @@ macro _scache(type, common_cache_dir, ex)
Dict{Symbol,Any}(
pairs($(rel_esc(res)))
)
)
),
broadcast = $(rel_esc(bc))
))

if cached_obj_exists($(type), $(common_cache_dir), _hash)
Expand Down
56 changes: 50 additions & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,43 @@ function throw_n_log(msg::AbstractString; error_type::Type{<:Exception} = ErrorE
throw(error_type(msg))
end

"""
_is_call_expr(expression)

Return whether an expression is a function call or not.
"""
_is_call_expr(::Any) = false
_is_call_expr(ex::Expr) = ex.head == :call

"""
_is_bc_expr(expression)

Return whether an expression is a broadcast or not.
"""
_is_bc_expr(::Any) = false
_is_bc_expr(ex::Expr) = ex.head == :.

"""
_bc2call(expression)

Convert a broadcast expression to a simple call expression.
"""
_bc2call(ex::Any) = ex
function _bc2call(ex::Expr)
if !_is_bc_expr(ex)
return ex
end

@assert typeof(ex.args[2]) == Expr && ex.args[2].head == :tuple "Found unusual " *
"broadcast expression: $(dump(ex))"

res = deepcopy(ex)
res.head = :call
res.args = res.args[2].args

return res
end

## conversion utilities ##

"""
Expand All @@ -105,8 +142,8 @@ x = 10
This function is called inside [`@_scache`](@ref SimpleCaching.@_scache) to generate the
string that will fill the column `COMMAND` in the cache record (if generated).
"""
_strarg(arg::Any, top_level::Bool = false) = Expr(:call, :string, arg)
function _strarg(arg::Expr, top_level::Bool = false)
_strarg(arg::Any, top_level::Bool = false, broadcast::Bool = false) = Expr(:call, :string, arg)
function _strarg(arg::Expr, top_level::Bool = false, broadcast::Bool = false)
if top_level && arg.head == :escape
return _strarg(arg.args[1], true)
elseif top_level && arg.head == :call
Expand All @@ -125,12 +162,16 @@ function _strarg(arg::Expr, top_level::Bool = false)
append!(_kw, filter(x -> typeof(x) == Expr && x.head == :kw, arg.args[2:end]))

return Expr(:call, :string,
"$(arg.args[1])(",
"$(arg.args[1])",
broadcast ? "." : "",
"(",
Expr(:call, :join, _toexpr(_strarg.(_args)), ", "),
length(_kw) > 0 ? "; " : "",
length(_kw) > 0 ? Expr(:call, :join, _toexpr(_strarg.(_kw, true)), ", ") : "",
")"
)
elseif top_level && arg.head == :.
return _strarg(_bc2call(arg), top_level, true)
elseif top_level && arg.head == :kw
return Expr(:call, :string, "$(arg.args[1]) = ", _strarg(arg.args[2]))
elseif top_level && arg.head == :parameters
Expand Down Expand Up @@ -161,8 +202,8 @@ containing the keyword arguments passed.
Note: a dictionary is used for the keyword arguments because otherwise the hash would change
based on the their order.
"""
_convert_input(arg::Any, top_level::Bool = false) = arg
function _convert_input(arg::Expr, top_level::Bool = false)
_convert_input(arg::Any, top_level::Bool = false, broadcast::Bool = false) = arg
function _convert_input(arg::Expr, top_level::Bool = false, broadcast::Bool = false)
_splat2pairs(v::AbstractVector) = length(v) == 0 ? [] : _splat2pairs(v[1])
_splat2pairs(ex::Expr) = Expr(:call, :pairs, ex.args[1])

Expand All @@ -189,8 +230,11 @@ function _convert_input(arg::Expr, top_level::Bool = false)
return (
args = [_convert_input(a, false) for a in _args],
kwargs = Dict{Symbol,Any}(_convert_input.(_kw, true)...),
res = _splat2pairs(_res)
res = _splat2pairs(_res),
broadcast = broadcast
)
elseif top_level && arg.head == :.
return _convert_input(_bc2call(arg), top_level, true)
elseif top_level && arg.head == :kw
return arg.args[1] => _convert_input(arg.args[2])
elseif top_level && arg.head == :parameters
Expand Down
13 changes: 13 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ using Test
const testing_cache_dir = mktempdir(prefix = "SimpleCaching_test_")
const cached_type = "testing_object"

const mat1 = [1 0; 0 1]
const mat2 = [1 1; 0 1]

module A
using SimpleCaching

Expand Down Expand Up @@ -42,6 +45,16 @@ end
@test res3 == res
@test res4 == res

# test broadcasting
res1_normal = mat1 * mat2
res2_normal = mat1 .* mat2
res1_cache = @scache mat1 * mat2
res2_cache = @scache mat1 .* mat2

@test res1_normal == res1_cache
@test res2_normal == res2_cache

# clean
rm(testing_cache_dir; recursive = true)

@testset "no type" begin
Expand Down