From 5b2fcb68800875e570d7bb8c78ed00d360b6cfd5 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Mon, 27 Nov 2023 12:01:53 -0500 Subject: [PATCH] Compiler support for optimizing PersistentDict (#51993) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is part of the work to address #51352 by attempting to allow the compiler to perform SRAO on persistent data structures like `PersistentDict` as if they were regular immutable data structures. These sorts of data structures have very complicated internals (with lots of mutation, memory sharing, etc.), but a relatively simple interface. As such, it is unlikely that our compiler will have sufficient power to optimize this interface by analyzing the implementation. We thus need to come up with some other mechanism that gives the compiler license to perform the requisite optimization. One way would be to just hardcode `PersistentDict` into the compiler, optimizing it like any of the other builtin datatypes. However, this is of course very unsatisfying. At the other end of the spectrum would be something like a generic rewrite rule system (e-graphs anyone?) that would let the PersistentDict implementation declare its interface to the compiler and the compiler would use this for optimization (in a perfect world, the actual rewrite would then be checked using some sort of formal methods). I think that would be interesting, but we're very far from even being able to design something like that (at least in Base - experiments with external AbstractInterpreters in this direction are encouraged). This PR tries to come up with a reasonable middle ground, where the compiler gets some knowledge of the protocol hardcoded without having to know about the implementation details of the data structure. The basic ideas is that `Core` provides some magic generic functions that implementations can extend. Semantically, they are not special. They dispatch as usual, and implementations are expected to work properly even in the absence of any compiler optimizations. However, the compiler is semantically permitted to perform structural optimization using these magic generic functions. In the concrete case, this PR introduces the `KeyValue` interface which consists of two generic functions, `get` and `set`. The core optimization is that the compiler is allowed to rewrite any occurrence of `get(set(x, k, v), k)` into `v` without additional legality checks. In particular, the compiler performs no type checks, conversions, etc. The higher level implementation code is expected to do all that. This approach closely matches the general direction we've been taking in external AbstractInterpreters for embedding additional semantics and optimization opportunities into Julia code (although we generally use methods there, rather than full generic functions), so I think we have some evidence that this sort of approach works reasonably well. Nevertheless, this is certainly an experiment and the interface is explicitly declared unstable. ## Current Status This is fully working and implemented, but the optimization currently bails on anything but the simplest cases. Filling all those cases in is not particularly hard, but should be done along with a more invasive refactoring of SROA, so we should figure out the general direction here first and then we can finish all that up in a follow-up cleanup. ## Obligatory benchmark Before: ``` julia> using BenchmarkTools julia> function foo() a = Base.PersistentDict(:a => 1) return a[:a] end foo (generic function with 1 method) julia> @benchmark foo() BenchmarkTools.Trial: 10000 samples with 993 evaluations. Range (min … max): 32.940 ns … 28.754 μs ┊ GC (min … max): 0.00% … 99.76% Time (median): 49.647 ns ┊ GC (median): 0.00% Time (mean ± σ): 57.519 ns ± 333.275 ns ┊ GC (mean ± σ): 10.81% ± 2.22% ▃█▅ ▁▃▅▅▃▁ ▁▃▂ ▂ ▁▂▄▃▅▇███▇▃▁▂▁▁▁▁▁▁▁▁▂▂▅██████▅▂▁▁▁▁▁▁▁▁▁▁▂▃▃▇███▇▆███▆▄▃▃▂▂ ▃ 32.9 ns Histogram: frequency by time 68.6 ns < Memory estimate: 128 bytes, allocs estimate: 4. julia> @code_typed foo() CodeInfo( 1 ─ %1 = invoke Vector{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}}(Base.HashArrayMappedTries.undef::UndefInitializer, 1::Int64)::Vector{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}} │ %2 = %new(Base.HashArrayMappedTries.HAMT{Symbol, Int64}, %1, 0x00000000)::Base.HashArrayMappedTries.HAMT{Symbol, Int64} │ %3 = %new(Base.HashArrayMappedTries.Leaf{Symbol, Int64}, :a, 1)::Base.HashArrayMappedTries.Leaf{Symbol, Int64} │ %4 = Base.getfield(%2, :data)::Vector{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}} │ %5 = $(Expr(:boundscheck, true))::Bool └── goto #5 if not %5 2 ─ %7 = Base.sub_int(1, 1)::Int64 │ %8 = Base.bitcast(UInt64, %7)::UInt64 │ %9 = Base.getfield(%4, :size)::Tuple{Int64} │ %10 = $(Expr(:boundscheck, true))::Bool │ %11 = Base.getfield(%9, 1, %10)::Int64 │ %12 = Base.bitcast(UInt64, %11)::UInt64 │ %13 = Base.ult_int(%8, %12)::Bool └── goto #4 if not %13 3 ─ goto #5 4 ─ %16 = Core.tuple(1)::Tuple{Int64} │ invoke Base.throw_boundserror(%4::Vector{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}}, %16::Tuple{Int64})::Union{} └── unreachable 5 ┄ %19 = Base.getfield(%4, :ref)::MemoryRef{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}} │ %20 = Base.memoryref(%19, 1, false)::MemoryRef{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}} │ Base.memoryrefset!(%20, %3, :not_atomic, false)::MemoryRef{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}} └── goto #6 6 ─ %23 = Base.getfield(%2, :bitmap)::UInt32 │ %24 = Base.or_int(%23, 0x00010000)::UInt32 │ Base.setfield!(%2, :bitmap, %24)::UInt32 └── goto #7 7 ─ %27 = %new(Base.PersistentDict{Symbol, Int64}, %2)::Base.PersistentDict{Symbol, Int64} └── goto #8 8 ─ %29 = invoke Base.getindex(%27::Base.PersistentDict{Symbol, Int64}, :a::Symbol)::Int64 └── return %29 ``` After: ``` julia> using BenchmarkTools julia> function foo() a = Base.PersistentDict(:a => 1) return a[:a] end foo (generic function with 1 method) julia> @benchmark foo() BenchmarkTools.Trial: 10000 samples with 1000 evaluations. Range (min … max): 2.459 ns … 11.320 ns ┊ GC (min … max): 0.00% … 0.00% Time (median): 2.460 ns ┊ GC (median): 0.00% Time (mean ± σ): 2.469 ns ± 0.183 ns ┊ GC (mean ± σ): 0.00% ± 0.00% ▂ █ ▁ █ ▂ █▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁█ █ 2.46 ns Histogram: log(frequency) by time 2.47 ns < Memory estimate: 0 bytes, allocs estimate: 0. julia> @code_typed foo() CodeInfo( 1 ─ return 1 ``` --- base/boot.jl | 3 +- base/compiler/ssair/passes.jl | 82 ++++++++++++++++++++- base/dict.jl | 129 ++++++++++++++++++---------------- base/hamt.jl | 12 +++- base/optimized_generics.jl | 57 +++++++++++++++ test/compiler/irpasses.jl | 10 +++ 6 files changed, 228 insertions(+), 65 deletions(-) create mode 100644 base/optimized_generics.jl diff --git a/base/boot.jl b/base/boot.jl index b84f9b76838e0..218e4e2e533b1 100644 --- a/base/boot.jl +++ b/base/boot.jl @@ -957,7 +957,6 @@ function _hasmethod(@nospecialize(tt)) # this function has a special tfunc return Intrinsics.not_int(ccall(:jl_gf_invoke_lookup, Any, (Any, Any, UInt), tt, nothing, world) === nothing) end - # for backward compat arrayref(inbounds::Bool, A::Array, i::Int...) = Main.Base.getindex(A, i...) const_arrayref(inbounds::Bool, A::Array, i::Int...) = Main.Base.getindex(A, i...) @@ -969,4 +968,6 @@ export arrayref, arrayset, arraysize, const_arrayref # For convenience EnterNode(old::EnterNode, new_dest::Int) = EnterNode(new_dest) +include(Core, "optimized_generics.jl") + ccall(:jl_set_istopmod, Cvoid, (Any, Bool), Core, true) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 8b3ae55fe7429..f280b11532c19 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -6,6 +6,13 @@ function is_known_call(@nospecialize(x), @nospecialize(func), ir::Union{IRCode,I return singleton_type(ft) === func end +function is_known_invoke_or_call(@nospecialize(x), @nospecialize(func), ir::Union{IRCode,IncrementalCompact}) + isinvoke = isexpr(x, :invoke) + (isinvoke || isexpr(x, :call)) || return false + ft = argextype(x.args[isinvoke ? 2 : 1], ir) + return singleton_type(ft) === func +end + struct SSAUse kind::Symbol idx::Int @@ -819,6 +826,76 @@ function lift_svec_ref!(compact::IncrementalCompact, idx::Int, stmt::Expr) return end +function lift_leaves_keyvalue(compact::IncrementalCompact, @nospecialize(key), + leaves::Vector{Any}, 𝕃ₒ::AbstractLattice) + # For every leaf, the lifted value + lifted_leaves = LiftedLeaves() + for i = 1:length(leaves) + leaf = leaves[i] + cache_key = leaf + if isa(leaf, AnySSAValue) + (def, leaf) = walk_to_def(compact, leaf) + if is_known_invoke_or_call(def, Core.OptimizedGenerics.KeyValue.set, compact) + @assert isexpr(def, :invoke) + if length(def.args) in (5, 6) + collection = def.args[end-2] + set_key = def.args[end-1] + set_val_idx = length(def.args) + elseif length(def.args) == 4 + collection = def.args[end-1] + # Key is deleted + # TODO: Model this + return nothing + elseif length(def.args) == 3 + collection = def.args[end] + # The whole collection is deleted + # TODO: Model this + return nothing + else + return nothing + end + if set_key === key || (egal_tfunc(𝕃ₒ, argextype(key, compact), argextype(set_key, compact)) == Const(true)) + lift_arg!(compact, leaf, cache_key, def, set_val_idx, lifted_leaves) + continue + end + # TODO: Continue walking the chain + return nothing + end + end + return nothing + end + return lifted_leaves +end + +function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr, 𝕃ₒ::AbstractLattice) + collection = stmt.args[end-1] + key = stmt.args[end] + + leaves, visited_philikes = collect_leaves(compact, collection, Any, 𝕃ₒ, phi_or_ifelse_predecessors) + isempty(leaves) && return + + lifted_leaves = lift_leaves_keyvalue(compact, key, leaves, 𝕃ₒ) + lifted_leaves === nothing && return + + result_t = Union{} + for v in values(lifted_leaves) + v === nothing && return + result_t = tmerge(𝕃ₒ, result_t, argextype(v.val, compact)) + end + + lifted_val = perform_lifting!(compact, + visited_philikes, key, result_t, lifted_leaves, collection, nothing) + + compact[idx] = lifted_val === nothing ? nothing : Expr(:call, Core.tuple, lifted_val.val) + if lifted_val !== nothing + if !⊑(𝕃ₒ, compact[SSAValue(idx)][:type], result_t) + compact[SSAValue(idx)][:flag] |= IR_FLAG_REFINED + end + end + + return +end + # TODO: We could do the whole lifing machinery here, but really all # we want to do is clean this up when it got inserted by inlining, # which always targets simple `svec` call or `_compute_sparams`, @@ -1004,7 +1081,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing) for ((_, idx), stmt) in compact # check whether this statement is `getfield` / `setfield!` (or other "interesting" statement) isa(stmt, Expr) || continue - is_setfield = is_isdefined = is_finalizer = false + is_setfield = is_isdefined = is_finalizer = is_keyvalue_get = false field_ordering = :unspecified if is_known_call(stmt, setfield!, compact) 4 <= length(stmt.args) <= 5 || continue @@ -1094,6 +1171,9 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing) lift_comparison!(isa, compact, idx, stmt, 𝕃ₒ) elseif is_known_call(stmt, Core.ifelse, compact) fold_ifelse!(compact, idx, stmt) + elseif is_known_invoke_or_call(stmt, Core.OptimizedGenerics.KeyValue.get, compact) + 2 == (length(stmt.args) - (isexpr(stmt, :invoke) ? 2 : 1)) || continue + lift_keyvalue_get!(compact, idx, stmt, 𝕃ₒ) elseif isexpr(stmt, :new) refine_new_effects!(𝕃ₒ, compact, idx, stmt) end diff --git a/base/dict.jl b/base/dict.jl index 83180f5c0ee1b..768f8215946b8 100644 --- a/base/dict.jl +++ b/base/dict.jl @@ -887,10 +887,35 @@ _similar_for(c::AbstractDict, ::Type{T}, itr, isz, len) where {T} = include("hamt.jl") using .HashArrayMappedTries +using Core.OptimizedGenerics: KeyValue const HAMT = HashArrayMappedTries struct PersistentDict{K,V} <: AbstractDict{K,V} trie::HAMT.HAMT{K,V} + # Serves as a marker for an empty initialization + @noinline function KeyValue.set(::Type{PersistentDict{K, V}}) where {K, V} + new{K, V}(HAMT.HAMT{K,V}()) + end + @noinline function KeyValue.set(::Type{PersistentDict{K, V}}, ::Nothing, key, val) where {K, V} + new{K, V}(HAMT.HAMT{K, V}(key => val)) + end + @noinline function KeyValue.set(dict::PersistentDict{K, V}, key, val) where {K, V} + trie = dict.trie + h = HAMT.HashState(key) + found, present, trie, i, bi, top, hs = HAMT.path(trie, key, h, #=persistent=# true) + HAMT.insert!(found, present, trie, i, bi, hs, val) + return new{K, V}(top) + end + @noinline function KeyValue.set(dict::PersistentDict{K, V}, key) where {K, V} + trie = dict.trie + h = HAMT.HashState(key) + found, present, trie, i, bi, top, _ = HAMT.path(trie, key, h, #=persistent=# true) + if found && present + deleteat!(trie.data, i) + HAMT.unset!(trie, bi) + end + return new{K, V}(top) + end end """ @@ -925,19 +950,27 @@ Base.PersistentDict{Symbol, Int64} with 1 entry: """ PersistentDict -PersistentDict{K,V}() where {K,V} = PersistentDict(HAMT.HAMT{K,V}()) -PersistentDict{K,V}(KV::Pair) where {K,V} = PersistentDict(HAMT.HAMT{K,V}(KV)) -PersistentDict(KV::Pair{K,V}) where {K,V} = PersistentDict(HAMT.HAMT{K,V}(KV)) +PersistentDict{K,V}() where {K, V} = KeyValue.set(PersistentDict{K,V}) +function PersistentDict{K,V}(KV::Pair) where {K,V} + KeyValue.set( + PersistentDict{K, V}, + nothing, + KV...) +end +function PersistentDict(KV::Pair{K,V}) where {K,V} + KeyValue.set( + PersistentDict{K, V}, + nothing, + KV...) +end PersistentDict(dict::PersistentDict, pair::Pair) = PersistentDict(dict, pair...) PersistentDict{K,V}(dict::PersistentDict{K,V}, pair::Pair) where {K,V} = PersistentDict(dict, pair...) + + function PersistentDict(dict::PersistentDict{K,V}, key, val) where {K,V} key = convert(K, key) val = convert(V, val) - trie = dict.trie - h = HAMT.HashState(key) - found, present, trie, i, bi, top, hs = HAMT.path(trie, key, h, #=persistent=# true) - HAMT.insert!(found, present, trie, i, bi, hs, val) - return PersistentDict(top) + return KeyValue.set(dict, key, val) end function PersistentDict{K,V}(KV::Pair, rest::Pair...) where {K,V} @@ -959,84 +992,60 @@ end eltype(::PersistentDict{K,V}) where {K,V} = Pair{K,V} function in(key_val::Pair{K,V}, dict::PersistentDict{K,V}, valcmp=(==)) where {K,V} - trie = dict.trie - if HAMT.islevel_empty(trie) - return false - end - key, val = key_val - - h = HAMT.HashState(key) - found, present, trie, i, _, _, _ = HAMT.path(trie, key, h) - if found && present - leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V} - return valcmp(val, leaf.val) && return true - end - return false + found = KeyValue.get(dict, key) + found === nothing && return false + return valcmp(val, only(found)) end function haskey(dict::PersistentDict{K}, key::K) where K - trie = dict.trie - h = HAMT.HashState(key) - found, present, _, _, _, _, _ = HAMT.path(trie, key, h) - return found && present + return KeyValue.get(dict, key) !== nothing end function getindex(dict::PersistentDict{K,V}, key::K) where {K,V} - trie = dict.trie - if HAMT.islevel_empty(trie) - throw(KeyError(key)) - end - h = HAMT.HashState(key) - found, present, trie, i, _, _, _ = HAMT.path(trie, key, h) - if found && present - leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V} - return leaf.val - end - throw(KeyError(key)) + found = KeyValue.get(dict, key) + found === nothing && throw(KeyError(key)) + return only(found) end function get(dict::PersistentDict{K,V}, key::K, default) where {K,V} - trie = dict.trie - if HAMT.islevel_empty(trie) - return default - end - h = HAMT.HashState(key) - found, present, trie, i, _, _, _ = HAMT.path(trie, key, h) - if found && present - leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V} - return leaf.val - end - return default + found = KeyValue.get(dict, key) + found === nothing && return default + return only(found) end -function get(default::Callable, dict::PersistentDict{K,V}, key::K) where {K,V} +@noinline function KeyValue.get(dict::PersistentDict{K, V}, key) where {K, V} trie = dict.trie if HAMT.islevel_empty(trie) - return default + return nothing end h = HAMT.HashState(key) found, present, trie, i, _, _, _ = HAMT.path(trie, key, h) if found && present leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V} - return leaf.val + return (leaf.val,) end - return default() + return nothing end -iterate(dict::PersistentDict, state=nothing) = HAMT.iterate(dict.trie, state) +@noinline function KeyValue.get(default, dict::PersistentDict, key) + found = KeyValue.get(dict, key) + found === nothing && return default() + return only(found) +end + +function get(default::Callable, dict::PersistentDict{K,V}, key::K) where {K,V} + found = KeyValue.get(dict, key) + found === nothing && return default() + return only(found) +end function delete(dict::PersistentDict{K}, key::K) where K - trie = dict.trie - h = HAMT.HashState(key) - found, present, trie, i, bi, top, _ = HAMT.path(trie, key, h, #=persistent=# true) - if found && present - deleteat!(trie.data, i) - HAMT.unset!(trie, bi) - end - return PersistentDict(top) + return KeyValue.set(dict, key) end +iterate(dict::PersistentDict, state=nothing) = HAMT.iterate(dict.trie, state) + length(dict::PersistentDict) = HAMT.length(dict.trie) isempty(dict::PersistentDict) = HAMT.isempty(dict.trie) empty(::PersistentDict, ::Type{K}, ::Type{V}) where {K, V} = PersistentDict{K, V}() diff --git a/base/hamt.jl b/base/hamt.jl index e940f4e00b1d5..fc298b8b7a338 100644 --- a/base/hamt.jl +++ b/base/hamt.jl @@ -65,12 +65,18 @@ mutable struct HAMT{K, V} HAMT{K,V}(data, bitmap) where {K,V} = new{K,V}(data, bitmap) HAMT{K, V}() where {K, V} = new{K,V}(Vector{Union{Leaf{K, V}, HAMT{K, V}}}(undef, 0), zero(BITMAP)) end -function HAMT{K,V}((k,v)::Pair) where {K, V} - k = convert(K, k) - v = convert(V, v) + +@Base.assume_effects :nothrow function init_hamt(K, V, k, v) # For a single element we can't have a hash-collision trie = HAMT{K,V}(Vector{Union{Leaf{K, V}, HAMT{K, V}}}(undef, 1), zero(BITMAP)) trie.data[1] = Leaf{K,V}(k,v) + return trie +end + +function HAMT{K,V}((k,v)::Pair) where {K, V} + k = convert(K, k) + v = convert(V, v) + trie = init_hamt(K, V, k, v) bi = BitmapIndex(HashState(k)) set!(trie, bi) return trie diff --git a/base/optimized_generics.jl b/base/optimized_generics.jl new file mode 100644 index 0000000000000..86b54a294564d --- /dev/null +++ b/base/optimized_generics.jl @@ -0,0 +1,57 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +module OptimizedGenerics + +# This file defines interfaces that are recognized and optimized by the compiler +# They are intended to be used by data structure implementations that wish to +# opt into some level of compiler optimizations. These interfaces are +# EXPERIMENTAL and currently intended for use by Base only. They are subject +# to change or removal without notice. It is undefined behavior to add methods +# to these generics that do not conform to the specified interface. +# +# The intended way to use these generics is that data structures will provide +# appropriate implementations for a generic. In the absence of compiler +# optimizations, these behave like regular methods. However, the compiler is +# semantically allowed to perform certain structural optimizations on +# appropriate combinations of these intrinsics without proving correctness. + +# Compiler-recognized generics for immutable key-value stores (dicts, etc.) +""" + module KeyValue + +Implements a key-value like interface where the compiler has liberty to perform +the following transformations. The core optimization semantically allowed for +the compiler is: + + get(set(x, key, val), key) -> (val,) + +where the compiler will recursively look through `x`. Keys are compared by +egality. + +Implementations must observe the following constraints: + +1. It is undefined behavior for `get` not to return the exact (by egality) val + stored for a given `key`. +""" +module KeyValue + """ + set(collection, [key [, val]]) + set(T, collection, key, val) + + Set the `key` in `collection` to `val`. If `val` is omitted, deletes the + value from the collection. If `key` is omitted as well, deletes all elements + of the collection. + """ + function set end + + """ + get(collection, key) + + Retrieve the value corresponding to `key` in `collection` as a single + element tuple or `nothing` if no value corresponding to the key was found. + `key`s are compared by egal. + """ + function get end +end + +end diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 204d0400ea701..fe5263059b0ba 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -1616,3 +1616,13 @@ let m = Meta.@lower 1 + 1 end # JET.test_opt(Core.Compiler.cfg_simplify!, (Core.Compiler.IRCode,)) + +# Test support for Core.OptimizedGenerics.KeyValue protocol +function persistent_dict_elim() + a = Base.PersistentDict(:a => 1) + return a[:a] +end +# Ideally we would be able to fully eliminate this, +# but currently this would require an extra round of constprop +@test_broken fully_eliminated(persistent_dict_elim) +@test code_typed(persistent_dict_elim)[1][1].code[end] == Core.ReturnNode(1)