Skip to content

Commit

Permalink
Merge pull request #6 from ferdiu/dev
Browse files Browse the repository at this point in the history
Fix macros not working with broadcasting
  • Loading branch information
ferdiu authored Aug 15, 2023
2 parents 63927d6 + a8b7ee8 commit 2814d1d
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 9 deletions.
7 changes: 4 additions & 3 deletions src/caching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ macro _scache(type, common_cache_dir, ex)
esc_times = escdepth(ex)
ex = unesc_comp(ex)

(typeof(ex) != Expr || ex.head != :call) && (throw(ArgumentError("`@scache[jld]` can " *
!_is_call_expr(ex) && !_is_bc_expr(ex) && (throw(ArgumentError("`@scache[jld]` can " *
"be used only with function calls: passed $(ex)")))

# hyigene
Expand All @@ -125,7 +125,7 @@ macro _scache(type, common_cache_dir, ex)
t = _convert_input(ex, true)

rel_esc(ex) = esc(ex, esc_times+1)
as, ks, vs, res = (_toexpr(t.args), _toexpr(t.kwargs)..., t.res)
as, ks, vs, res, bc = (_toexpr(t.args), _toexpr(t.kwargs)..., t.res, t.broadcast)

# TODO: make a function that interprets `res` to be inserted in `args` or `kwargs`
return quote
Expand All @@ -141,7 +141,8 @@ macro _scache(type, common_cache_dir, ex)
Dict{Symbol,Any}(
pairs($(rel_esc(res)))
)
)
),
broadcast = $(rel_esc(bc))
))

if cached_obj_exists($(type), $(common_cache_dir), _hash)
Expand Down
56 changes: 50 additions & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,43 @@ function throw_n_log(msg::AbstractString; error_type::Type{<:Exception} = ErrorE
throw(error_type(msg))
end

"""
_is_call_expr(expression)
Return whether an expression is a function call or not.
"""
_is_call_expr(::Any) = false
_is_call_expr(ex::Expr) = ex.head == :call

"""
_is_bc_expr(expression)
Return whether an expression is a broadcast or not.
"""
_is_bc_expr(::Any) = false
_is_bc_expr(ex::Expr) = ex.head == :.

"""
_bc2call(expression)
Convert a broadcast expression to a simple call expression.
"""
_bc2call(ex::Any) = ex
function _bc2call(ex::Expr)
if !_is_bc_expr(ex)
return ex
end

@assert typeof(ex.args[2]) == Expr && ex.args[2].head == :tuple "Found unusual " *
"broadcast expression: $(dump(ex))"

res = deepcopy(ex)
res.head = :call
res.args = res.args[2].args

return res
end

## conversion utilities ##

"""
Expand All @@ -105,8 +142,8 @@ x = 10
This function is called inside [`@_scache`](@ref SimpleCaching.@_scache) to generate the
string that will fill the column `COMMAND` in the cache record (if generated).
"""
_strarg(arg::Any, top_level::Bool = false) = Expr(:call, :string, arg)
function _strarg(arg::Expr, top_level::Bool = false)
_strarg(arg::Any, top_level::Bool = false, broadcast::Bool = false) = Expr(:call, :string, arg)
function _strarg(arg::Expr, top_level::Bool = false, broadcast::Bool = false)
if top_level && arg.head == :escape
return _strarg(arg.args[1], true)
elseif top_level && arg.head == :call
Expand All @@ -125,12 +162,16 @@ function _strarg(arg::Expr, top_level::Bool = false)
append!(_kw, filter(x -> typeof(x) == Expr && x.head == :kw, arg.args[2:end]))

return Expr(:call, :string,
"$(arg.args[1])(",
"$(arg.args[1])",
broadcast ? "." : "",
"(",
Expr(:call, :join, _toexpr(_strarg.(_args)), ", "),
length(_kw) > 0 ? "; " : "",
length(_kw) > 0 ? Expr(:call, :join, _toexpr(_strarg.(_kw, true)), ", ") : "",
")"
)
elseif top_level && arg.head == :.
return _strarg(_bc2call(arg), top_level, true)
elseif top_level && arg.head == :kw
return Expr(:call, :string, "$(arg.args[1]) = ", _strarg(arg.args[2]))
elseif top_level && arg.head == :parameters
Expand Down Expand Up @@ -161,8 +202,8 @@ containing the keyword arguments passed.
Note: a dictionary is used for the keyword arguments because otherwise the hash would change
based on the their order.
"""
_convert_input(arg::Any, top_level::Bool = false) = arg
function _convert_input(arg::Expr, top_level::Bool = false)
_convert_input(arg::Any, top_level::Bool = false, broadcast::Bool = false) = arg
function _convert_input(arg::Expr, top_level::Bool = false, broadcast::Bool = false)
_splat2pairs(v::AbstractVector) = length(v) == 0 ? [] : _splat2pairs(v[1])
_splat2pairs(ex::Expr) = Expr(:call, :pairs, ex.args[1])

Expand All @@ -189,8 +230,11 @@ function _convert_input(arg::Expr, top_level::Bool = false)
return (
args = [_convert_input(a, false) for a in _args],
kwargs = Dict{Symbol,Any}(_convert_input.(_kw, true)...),
res = _splat2pairs(_res)
res = _splat2pairs(_res),
broadcast = broadcast
)
elseif top_level && arg.head == :.
return _convert_input(_bc2call(arg), top_level, true)
elseif top_level && arg.head == :kw
return arg.args[1] => _convert_input(arg.args[2])
elseif top_level && arg.head == :parameters
Expand Down
13 changes: 13 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ using Test
const testing_cache_dir = mktempdir(prefix = "SimpleCaching_test_")
const cached_type = "testing_object"

const mat1 = [1 0; 0 1]
const mat2 = [1 1; 0 1]

module A
using SimpleCaching

Expand Down Expand Up @@ -42,6 +45,16 @@ end
@test res3 == res
@test res4 == res

# test broadcasting
res1_normal = mat1 * mat2
res2_normal = mat1 .* mat2
res1_cache = @scache mat1 * mat2
res2_cache = @scache mat1 .* mat2

@test res1_normal == res1_cache
@test res2_normal == res2_cache

# clean
rm(testing_cache_dir; recursive = true)

@testset "no type" begin
Expand Down

0 comments on commit 2814d1d

Please sign in to comment.