diff --git a/Project.toml b/Project.toml index 3e844484..34a1ef3f 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ Future = "9fa8497b-333b-5362-9e8d-4d0656e87820" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] @@ -19,6 +20,7 @@ CompositionsBase = "0.1" ConstructionBase = "1.2" MacroTools = "0.4.4, 0.5" Requires = "0.5, 1.0" +Static = "0.3" StaticNumbers = "0.3" julia = "1.3" diff --git a/src/optics.jl b/src/optics.jl index 9d532096..06b455dd 100644 --- a/src/optics.jl +++ b/src/optics.jl @@ -1,11 +1,12 @@ export @optic export set, modify export ∘, opcompose, var"⨟" -export Elements, Recursive, If, Properties +export Elements, Recursive, Query, If, Properties export setproperties export constructorof using ConstructionBase using CompositionsBase +using Static using Base: getproperty using Base @@ -134,6 +135,10 @@ else using Base: Returns end +struct Changed end +struct Unchanged end + + @inline function _set(obj, optic, val, ::ModifyBased) modify(Returns(val), obj, optic) end @@ -318,6 +323,185 @@ function _modify(f, obj, r::Recursive, ::ModifyBased) end end +""" + modify_stateful(f, (obj,state), optic) => Tuple{NewValue,NewState} + +Here `f` has signature `f(::Value, ::State) => Tuple{NewValue,NewState}`. +""" +function modify_stateful end + +@generated function modify_stateful(f::F, (obj, state)::T, optic::Properties) where {T,F} + _modify_stateful_inner(T) +end + +# Separated for testing object/state combinations without restarts +function _modify_stateful_inner(::Type{<:Tuple{O,S}}) where {O,S} + modifications = [] + vals = Expr(:tuple) + fns = fieldnames(O) + for (i, fn) in enumerate(fns) + v = Symbol("val$i") + st = if S <: ContextState + if O <: Tuple + :(ContextState(state.vals, obj, StaticInt{$(QuoteNode(fn))}())) + else + :(ContextState(state.vals, obj, StaticSymbol{$(QuoteNode(fn))}())) + end + else + :state + end + ms = :(($v, state) = f(getfield(props, $(QuoteNode(fn))), $st)) + push!(modifications, ms) + push!(vals.args, v) + end + patch = O <: Tuple ? vals : :(NamedTuple{$fns}($vals)) + start = :(props = getproperties(obj)) + rest = MacroTools.@q begin + patch = $patch + new_obj = maybesetproperties(state, obj, patch) + return (new_obj, state) + end + Expr(:block, start, modifications..., rest) +end + +maybesetproperties(state, obj, patch) = setproperties(obj, patch) +maybesetstate(state, obj, patch) = state + +abstract type AbstractQuery end + +""" + Query(select, descend, optic) + Query(; select=Any, descend=x -> true, optic=Properties()) + +Query an object recursively, choosing fields where `select` +returns `true`, and descending when `descend` returns `true`. + +```jldoctest +julia> using Accessors + +julia> q = Query(; select=x -> x isa Int, descend=x -> x isa Tuple) +Query{var"#5#7", var"#6#8", Properties}(var"#5#7"(), var"#6#8"(), Properties()) + +julia> obj = (7, (a=17.0, b=2.0f0), ("3", 4, 5.0), ((x=19, a=6.0,)), [1]) +(7, (a = 17.0, b = 2.0f0), ("3", 4, 5.0), (x = 19, a = 6.0), [1]) + +julia> q(obj) +(7, 4) +``` +$EXPERIMENTAL +""" +struct Query{Select,Descend,Optic<:Union{ComposedOptic,Properties}} <: AbstractQuery + select_condition::Select + descent_condition::Descend + optic::Optic +end +Query(select, descend=x -> true) = Query(select, descend, Properties()) +Query(; select=Any, descend=x -> true, optic=Properties()) = Query(select, descend, optic) + +OpticStyle(::Type{<:AbstractQuery}) = SetBased() + +struct ContextState{V,O,FN} + vals::V + obj::O + fn::FN +end +struct GetAllState{V} + vals::V +end +struct SetAllState{C,V,I} + change::C + vals::V + itr::I +end + +const GetStates = Union{GetAllState,ContextState} + +@inline pop(x) = first(x), Base.tail(x) +@inline push(x, val) = (x..., val) +@inline push(x::GetAllState, val) = GetAllState(push(x.vals, val)) +@inline push(x::ContextState, val) = ContextState(push(x.vals, val), nothing, nothing) + +(q::Query)(obj) = _getall(obj, q) + +function _getall(obj, q::Q) where Q<:Query + initial_state = GetAllState(()) + _, final_state = let q=q + modify_stateful((obj, initial_state), q) do o, s + new_state = push(s, outer(q.optic, o, s)) + o, new_state + end + end + final_state.vals +end + +function set(obj, q::Q, vals) where Q<:Query + initial_state = SetAllState(Unchanged(), vals, 1) + final_obj, _ = let obj=obj, q=q, initial_state=initial_state + modify_stateful((obj, initial_state), q) do o, s + new_output = outer(q.optic, o, s) + new_state = SetAllState(Changed(), s.vals, s.itr + 1) + new_output, new_state + end + end + return final_obj +end + +function context(f::F, obj, q::Q) where {F,Q<:Query} + initial_state = ContextState((), nothing, nothing) + _, final_state = let f=f + modify_stateful((obj, initial_state), q) do o, s + new_state = push(s, f(s.obj, known(s.fn))) + o, new_state + end + end + return final_state.vals +end + +modify(f, obj, q::Query) = set(obj, q, map(f, q(obj))) + +@inline function modify_stateful(f::F, (obj, state), q::Q) where {F,Q<:Query} + let f=f, q=q + modify_stateful((obj, state), inner(q.optic)) do o, s + if (q::Q).select_condition(o) + (f::F)(o, s) + elseif (q::Q).descent_condition(o) + ds = descent_state(s) + o, ns = modify_stateful(f::F, (o, ds), q::Q) + o, merge_state(ds, ns) + else + o, s + end + end + end +end + +maybesetproperties(state::GetStates, obj, patch) = obj +maybesetproperties(state::SetAllState, obj, patch) = + maybesetproperties(state.change, state, obj, patch) +maybesetproperties(::Changed, state::SetAllState, obj, patch) = setproperties(obj, patch) +maybesetproperties(::Unchanged, state::SetAllState, obj, patch) = obj + +descent_state(state::SetAllState) = SetAllState(Unchanged(), state.vals, state.itr) +descent_state(state) = state + +merge_state(s1::SetAllState, s2) = SetAllState(anychanged(s1, s2), s2.vals, s2.itr) +merge_state(s1, s2) = s2 + +anychanged(s1, s2) = anychanged(s1.change, s2.change) +anychanged(::Unchanged, ::Unchanged) = Unchanged() +anychanged(::Unchanged, ::Changed) = Changed() +anychanged(::Changed, ::Unchanged) = Changed() +anychanged(::Changed, ::Changed) = Changed() + +inner(optic) = optic +inner(optic::ComposedOptic) = optic.inner + +outer(optic, o, state::GetStates) = o +outer(optic::ComposedOptic, o, state::GetStates) = optic.outer(o) +outer(optic::ComposedOptic, o, state::SetAllState) = set(o, optic.outer, state.vals[state.itr]) +outer(optic, o, state::SetAllState) = state.vals[state.itr] + + ################################################################################ ##### Lenses ################################################################################ diff --git a/src/setindex.jl b/src/setindex.jl index 919c38e3..7c40308d 100644 --- a/src/setindex.jl +++ b/src/setindex.jl @@ -2,6 +2,10 @@ Base.@propagate_inbounds function setindex(args...) Base.setindex(args...) end +Base.@propagate_inbounds function setindex(xs::NamedTuple{K}, v, i::Int) where K + Base.setindex(xs, v, K[i]) +end + @inline setindex(::Base.RefValue, val) = Ref(val) Base.@propagate_inbounds function setindex(xs::AbstractArray, v, I...) diff --git a/src/sugar.jl b/src/sugar.jl index 3d5494dd..2edee6a5 100644 --- a/src/sugar.jl +++ b/src/sugar.jl @@ -1,4 +1,4 @@ -export @set, @optic, @reset, @modify +export @set, @optic, @reset, @modify, @getall, @setall using MacroTools """ @@ -84,13 +84,106 @@ end This function can be used to create a customized variant of [`@modify`](@ref). See also [`opticmacro`](@ref), [`setmacro`](@ref). """ - function modifymacro(optictransform, f, obj_optic) f = esc(f) obj, optic = parse_obj_optic(obj_optic) :(($modify)($f, $obj, $(optictransform)($optic))) end +""" + @getall [x for x in obs if f(x)] + +Get each `x` in `obj` that matches the condition `f`. + +This can be combined with other optics, e.g. + +```julia +julia> using Accessors + +julia> obj = ("1", 2, 3, (a=4, b="5")) +("1", 2, 3, (a = 4, b = "5")) + +julia> @getall (x for x in obj if x isa Number && iseven(x)) +(2, 4) +``` +""" +macro getall(ex) + getallmacro(ex) +end +macro getall(ex, descend) + getallmacro(ex; descend=descend) +end + +function getallmacro(ex; descend=true) + # Wrap descend in an anonoymous function + descend = :(descend -> $descend) + if @capture(ex, (lens_ for var_ in obj_ if select_)) + select = _select(select, var) + optic =_optics(lens) + :(Query($select, $descend, $optic)($(esc(obj)))) + elseif @capture(ex, [lens_ for var_ in obj_ if select_]) + select = _select(select, var) + optic =_optics(lens) + :([Query($select, $descend, $optic)($(esc(obj)))...]) + elseif @capture(ex, (lens_ for var_ in obj_)) + select = _ -> false + optic = _optics(lens) + :(Query($select, $descend, $optic)($(esc(obj)))) + elseif @capture(ex, [lens_ for var_ in obj_]) + select = _ -> false + optic = _optics(lens) + :([Query($select, $descend, $optic)($(esc(obj)))...]) + else + error("@getall must be passed a generator or array comprehension") + end +end + +# Turn this into an anonoymous function so it +# doesn't matter which argument val is in +_select(select, val) = :($(esc(val)) -> $(esc(select))) +function _optics(ex) + obj, optic = parse_obj_optic(ex) + :($optic ∘ Properties()) +end + +""" + @setall [x for x in obs if f(x)] = values + +Set each `x` in `obj` matching the condition `f` +to values from the `Tuple` or vector `values`. + +# Example + +Used combination with lenses to set the `b` field of the +second item of all `Tuple`: + +```jldoctest +julia> using Accessors + +julia> obj = ("x", (1, (a = missing, b = :y), (2, (a = missing, b = :b)))) +("x", (1, (a = missing, b = :y), (2, (a = missing, b = :b)))) + +julia> @setall (x[2].b for x in obj if x isa Tuple) = (:x, :a) +("x", (1, (a = missing, b = :x), (2, (a = missing, b = :b)))) +``` +""" +macro setall(ex) + setallmacro(ex) +end + +function setallmacro(ex) + if @capture(ex, ((lens_ for var_ in obj_ if select_) = vals_)) + select = _select(select, var) + optic =_optics(lens) + :(set($(esc(obj)), Query(; select=$select, optic=$optic), $(esc(vals)))) + elseif @capture(ex, ((lens_ for var_ in obj_) = vals_)) + optic = _optics(lens) + :(set($(esc(obj)), Query(; optic=$optic), $(esc(vals)))) + else + error("@setall must be passed a generator") + end +end + foldtree(op, init, x) = op(init, x) foldtree(op, init, ex::Expr) = op(foldl((acc, x) -> foldtree(op, acc, x), ex.args; init=init), ex) @@ -310,4 +403,3 @@ function show_composition_order(io::IO, optic::ComposedOptic) show_composition_order(io, optic.inner) print(io, ")") end - diff --git a/test/test_queries.jl b/test/test_queries.jl new file mode 100644 index 00000000..bde5ce33 --- /dev/null +++ b/test/test_queries.jl @@ -0,0 +1,60 @@ +using Accessors, Test, BenchmarkTools, Static +using Accessors: setall, getall, context +obj = (7, (a=17.0, b=2.0f0), ("3", 4, 5.0), ((x=19, a=6.0,)), [1]) +vals = (1.0, 2.0, 3.0, 4.0) +# Fields is the default +q = Query(; + select=x -> x isa NamedTuple, + descend=x -> x isa Tuple, + optic = (Accessors.@optic _.a) ∘ Accessors.Properties() + # optic = Accessors.Properties() +) +q(obj) + +@code_native q(obj) +@code_warntype q(obj) + +@benchmark $q($obj) +@test q(obj) == (17.0, 6.0) + +# using ProfileView, Cthulhu +# @descend getall(obj, q) +# f(obj, q) = for i in 1:10000000 getall(obj, q) end +# @profview f(obj, q) + +missings_obj = (a=missing, b=1, c=(d=missing, e=(f=missing, g=2))) +@test Query(ismissing)(missings_obj) === (missing, missing, missing) +@benchmark Query(ismissing)($missings_obj) + +# Need a wrapper so we don't have to pass in the starting iterator +set(obj, q, vals) +@benchmark set($obj, $q, $vals) +# using ProfileView +# @profview for i in 1:1000000 setall(obj, q, vals) end +@code_native set(obj, q, vals) +@code_warntype set(obj, q, vals) + +# @btime Accessors.set($obj, $slowlens, $vals) +@test set(obj, q, vals) == + (7, (a=1.0, b=2.0f0), ("3", 4, 5.0), ((x=19, a=2.0,)), [1]) + +@btime modify(x -> 10x, $obj, $q) + +# Context +q = Query(; + select=x -> x isa Int, + descend=x -> x isa NamedTuple, + optic = Accessors.Properties() +) +obj2 = (1.0, :a, (b=2, c=2)) +@test context((o, fn) -> fn, obj2, q) == (:b, :c) +@test context((o, fn) -> typeof(o), obj2, q) == (typeof(obj2[3]), typeof(obj2[3])) +@btime context((o, fn) -> fn, $obj2, $q) + +# Macros +@test (@getall (x for x in missings_obj if x isa Number)) == (1, 2) +expected = (a=missing, b=5, c=(d=missing, e=(f=missing, g=6))) +@test (@setall (x for x in missings_obj if x isa Number) = (5, 6)) === expected +@test (@getall (x[2].g for x in missings_obj if x isa NamedTuple)) == (2,) +@test (@setall (x[2].g for x in missings_obj if x isa NamedTuple) = 5) === + (a=missing, b=1, c=(d=missing, e=(f=missing, g=5)))