Support for multiple nested produce #163

Open
torfjelde opened this issue Jun 11, 2023 · 1 comment

@torfjelde · Member

torfjelde commented Jun 11, 2023

As discussed extensively in TuringLang/Turing.jl#2001, in particular TuringLang/Turing.jl#2001 (comment), Libtask.jl makes one crucial assumption: every Instruction contains at most one produce statement.

This is because

ir = _infer(f, args_type)
binding_values, slots, tape = translate!(RawTape(), ir)

where
function _infer(f, args_type)
    # `code_typed` returns a vector: [Pair{Core.CodeInfo, DataType}]
    ir0 = code_typed(f, Tuple{args_type...}, optimize=false)[1][1]
    return ir0
end

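which is then traversed to construct the tape. Because code_typed is called with optimize=false, nothing is inlined: a call such as f(x) inside g remains a single statement, and hence a single Instruction, no matter how many produce calls f makes internally.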
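To make this concrete, the sketch below inspects the same kind of IR that _infer returns, using only Base Julia. It assumes a hypothetical stand-in p(x) = x in place of produce, so that it runs outside a TapedTask. The unoptimized typed IR of f contains three separate calls to p, whereas the IR of g(x) = f(x) contains only a single call to f.

```julia
# `p` is a hypothetical stand-in for `produce`, so this runs outside a TapedTask.
p(x) = x
f(x) = (p(x); p(2x); p(3x); return nothing)
g(x) = f(x)

# Same query as `Libtask._infer`: unoptimized typed IR, so nothing is inlined.
ir_f = code_typed(f, Tuple{Int}; optimize=false)[1][1]
ir_g = code_typed(g, Tuple{Int}; optimize=false)[1][1]

println(ir_f)  # three separate calls to `p`
println(ir_g)  # a single call to `f`; the calls to `p` are hidden inside it
```

A tape built from ir_g therefore has no separate instructions at which it could stop between the individual calls.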

There are many cases in Turing.jl where this assumption does not hold, e.g. when we use @submodel.

Moreover, it's very unclear to me how this can be addressed without doing something quite fancy that lets us recurse into the type inference being performed.

EDIT: Here's an example of what I mean:

julia> using Libtask

julia> f(x) = (produce(x); produce(2x); produce(3x); return nothing)
f (generic function with 1 method)

julia> g(x) = f(x)
g (generic function with 1 method)

julia> task = Libtask.TapedTask(f, 1);

julia> consume(task), consume(task), consume(task)
(1, 2, 3)

julia> task = Libtask.TapedTask(g, 1);  # tracing of nested call

julia> consume(task)   # goes through all the `produce` calls before even calling the `callback` (which is `Libtask.producer`)
counter=1
tf=TapedFunction:
* .func => g
* .ir   =>
------------------
CodeInfo(
1 ─ %1 = Main.f(x)::Core.Const(nothing)
└──      return %1
)
------------------

ErrorException("There is a produced value which is not consumed.")Union{Ptr{Nothing}, Base.InterpreterIP}[Ptr{Nothing} @0x00007fa8d200eeeb, Ptr{Nothing} @0x00007fa8a0a30f29, Ptr{Nothing} @0x00007fa8a0a36844, Ptr{Nothing} @0x00007fa8a0a36865, Ptr{Nothing} @0x00007fa8e6b44f1d, Ptr{Nothing} @0x00007fa8a0a366e3, Ptr{Nothing} @0x00007fa8a0a36802, Ptr{Nothing} @0x00007fa8e6b44f1d, Ptr{Nothing} @0x00007fa8a0a35f25, Ptr{Nothing} @0x00007fa8a0a361dd, Ptr{Nothing} @0x00007fa8a0a36512, Ptr{Nothing} @0x00007fa8a0a3652f, Ptr{Nothing} @0x00007fa8e6b44f1d, Ptr{Nothing} @0x00007fa8e6b6656f]
ERROR: There is a produced value which is not consumed.
Stacktrace:
 [1] consume(ttask::TapedTask{typeof(g), Tuple{Int64}})
   @ Libtask ~/.julia/packages/Libtask/h7Kal/src/tapedtask.jl:153
 [2] top-level scope
   @ REPL[9]:1
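A toy model may help explain why this error appears. The sketch below is not Libtask's implementation; toy_produce, run_tape and the two hand-written tapes are hypothetical. Each instruction runs atomically and the runner can only pause between instructions, so it can hand at most one produced value to the consumer per instruction. The tape for f has one instruction per produce call, while the tape for g is the single nested call f(1), which produces all three values before the runner regains control.

```julia
# Hypothetical toy, not Libtask's actual implementation.
const PRODUCED = Any[]                 # values produced by the current instruction
toy_produce(x) = (push!(PRODUCED, x); nothing)

f(x) = (toy_produce(x); toy_produce(2x); toy_produce(3x); return nothing)

# Hand-written tapes: each instruction is a zero-argument closure.
tape_f = Function[() -> toy_produce(1), () -> toy_produce(2), () -> toy_produce(3)]
tape_g = Function[() -> f(1)]          # the nested call is a single instruction

# Run the tape, "pausing" (here: collecting) after every instruction.
function run_tape(tape)
    consumed = Any[]
    for instr in tape
        empty!(PRODUCED)
        instr()
        # Only one value can be handed to the consumer per pause.
        length(PRODUCED) > 1 && error("There is a produced value which is not consumed.")
        append!(consumed, PRODUCED)
    end
    return consumed
end

println(run_tape(tape_f))              # Any[1, 2, 3]
try
    run_tape(tape_g)                   # three values from one instruction
catch err
    println(err)                       # the ErrorException raised above
end
```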
@torfjelde · Member · Author

One candidate to address this would be to replace Libtask._infer with last(Umlaut.trace) and then build the instruction tape from the resulting trace.

For example:

julia> using DynamicPPL, Distributions, Umlaut

julia> struct DynamicPPLModelCtx{F}
           model::Model{F}
       end

julia> function isprimitive(ctx::DynamicPPLModelCtx, f, args...)
           f === ctx.model.f && return false
           if Base.parentmodule(f) == DynamicPPL
               # Trace into `DynamicPPL._evaluate!!`.
               f === DynamicPPL._evaluate!! && return false
           end
           return true
       end
isprimitive (generic function with 5 methods)

julia> @model function demo(x)
           z ~ Normal()
           x ~ Normal(z, 1)
       end
demo (generic function with 4 methods)

julia> model = demo(1);

julia> ctx = SamplingContext();

julia> varinfo = VarInfo(model);

julia> t = last(Umlaut.trace(DynamicPPL._evaluate!!, model, varinfo, ctx; ctx=DynamicPPLModelCtx(model)))
Tape{DynamicPPLModelCtx{typeof(demo)}}
  inp %1::typeof(DynamicPPL._evaluate!!)
  inp %2::Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext}
  inp %3::TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}
  inp %4::SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}
  %5 = make_evaluate_args_and_kwargs(%2, %3, %4)::Tuple{Tuple{Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext}, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}, Int64}, NamedTuple{(), Tuple{}}} 
  %6 = indexed_iterate(%5, 1)::Tuple{Tuple{Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext}, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}, Int64}, Int64} 
  %7 = getfield(%6, 1)::Tuple{Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext}, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}, Int64} 
  %8 = getfield(%6, 2)::Int64 
  %9 = indexed_iterate(%5, 2, %8)::Tuple{NamedTuple{(), Tuple{}}, Int64} 
  %10 = getfield(%9, 1)::NamedTuple{(), Tuple{}} 
  %11 = NamedTuple()::NamedTuple{(), Tuple{}} 
  %12 = merge(%11, %10)::NamedTuple{(), Tuple{}} 
  %13 = isempty(%12)::Bool 
  %14 = getproperty(%2, :f)::typeof(demo) 
  %15 = check_variable_length(%7, 4, 7)::Nothing 
  %16 = getindex(%7, 1)::Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext} 
  %17 = getindex(%7, 2)::TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64} 
  %18 = getindex(%7, 3)::SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG} 
  %19 = getindex(%7, 4)::Int64 
  const %20 = nothing::Nothing
  const %21 = nothing::Nothing
  const %22 = nothing::Nothing
  const %23 = nothing::Nothing
  const %24 = nothing::Nothing
  const %25 = nothing::Nothing
  const %26 = nothing::Nothing
  const %27 = nothing::Nothing
  const %28 = nothing::Nothing
  const %29 = nothing::Nothing
  const %30 = nothing::Nothing
  const %31 = nothing::Nothing
  const %32 = nothing::Nothing
  const %33 = nothing::Nothing
  %34 = Normal()::Normal{Float64} 
  %35 = apply_type(VarName, :z)::UnionAll 
  %36 = %35()::VarName{:z, Setfield.IdentityLens} 
  %37 = resolve_varnames(%36, %34)::VarName{:z, Setfield.IdentityLens} 
  %38 = contextual_isassumption(%18, %37)::Bool 
  %39 = inargnames(%37, %16)::Bool 
  %40 = !(%39)::Bool 
  const %41 = true::Bool
  const %42 = nothing::Nothing
  %43 = tuple(%18)::Tuple{SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}} 
  %44 = check_tilde_rhs(%34)::Normal{Float64} 
  %45 = unwrap_right_vn(%44, %37)::Tuple{Normal{Float64}, VarName{:z, Setfield.IdentityLens}} 
  %46 = tuple(%17)::Tuple{TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}} 
  %47 = check_variable_length(%43, 1, 43)::Nothing 
  %48 = check_variable_length(%45, 2, 45)::Nothing 
  %49 = getindex(%45, 1)::Normal{Float64} 
  %50 = getindex(%45, 2)::VarName{:z, Setfield.IdentityLens} 
  %51 = check_variable_length(%46, 1, 46)::Nothing 
  %52 = tilde_assume!!(%18, %49, %50, %17)::Tuple{Float64, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}} 
  %53 = indexed_iterate(%52, 1)::Tuple{Float64, Int64} 
  %54 = getfield(%53, 1)::Float64 
  %55 = getfield(%53, 2)::Int64 
  %56 = indexed_iterate(%52, 2, %55)::Tuple{TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, Int64} 
  %57 = getfield(%56, 1)::TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64} 
  %58 = Normal(%54, 1)::Normal{Float64} 
  %59 = apply_type(VarName, :x)::UnionAll 
  %60 = %59()::VarName{:x, Setfield.IdentityLens} 
  %61 = resolve_varnames(%60, %58)::VarName{:x, Setfield.IdentityLens} 
  %62 = contextual_isassumption(%18, %61)::Bool 
  %63 = inargnames(%61, %16)::Bool 
  %64 = !(%63)::Bool 
  %65 = inmissings(%61, %16)::Bool 
  %66 = ===(%19, missing)::Bool 
  %67 = inargnames(%61, %16)::Bool 
  %68 = !(%67)::Bool 
  %69 = check_tilde_rhs(%58)::Normal{Float64} 
  %70 = tilde_observe!!(%18, %69, %19, %61, %57)::Tuple{Int64, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}} 
  %71 = indexed_iterate(%70, 1)::Tuple{Int64, Int64} 
  %72 = getfield(%71, 1)::Int64 
  %73 = getfield(%71, 2)::Int64 
  %74 = indexed_iterate(%70, 2, %73)::Tuple{TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, Int64} 
  %75 = getfield(%74, 1)::TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64} 
  %76 = tuple(%72, %75)::Tuple{Int64, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}} 

This should even be usable with Libtask.jl after some minor changes: just compile the tape t and pass the result to Libtask._infer.
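The reason this helps is that the trace recurses into every call that isprimitive rejects, so nested produce calls end up as separate top-level instructions. The toy below mimics only that recursion; the BODIES table (standing in for real IR), toy_isprimitive and flatten_trace are hypothetical and are not Umlaut's API.

```julia
# Hypothetical toy "IR": each function body is a list of callees.
const BODIES = Dict{Symbol,Vector{Symbol}}(
    :g => [:f],                              # g(x) = f(x)
    :f => [:produce, :produce, :produce],    # f produces three times
)

# Toy analogue of `isprimitive`: anything without a known body is recorded as-is.
toy_isprimitive(callee::Symbol) = !haskey(BODIES, callee)

# Recurse into non-primitive calls, producing a flat list of primitive calls.
function flatten_trace(fname::Symbol)
    tape = Symbol[]
    for callee in BODIES[fname]
        if toy_isprimitive(callee)
            push!(tape, callee)
        else
            append!(tape, flatten_trace(callee))
        end
    end
    return tape
end

println(BODIES[:g])          # [:f]: the `produce` calls are hidden
println(flatten_trace(:g))   # [:produce, :produce, :produce]
```

In the flattened tape every produce is its own instruction, which is exactly the property Libtask.jl assumes.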

Note that isprimitive would have to be fine-tuned to also support @submodel, but that should be easy to do.

AFAIK the main drawback of this change is that we would drop support for control flow (unless we are willing to re-run the trace on every call).
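To illustrate the control-flow drawback: a trace records only the branch taken for the inputs used while tracing, so the resulting tape is only valid for inputs that follow the same path. In this hypothetical sketch, branchy and trace_branchy are made-up names, and the "trace" is modelled as a list of recorded operations.

```julia
# Hypothetical example: `rec` collects the operations executed, like a trace would.
function branchy(x, rec)
    if x > 0
        push!(rec, :positive_branch)
        return 2x
    else
        push!(rec, :negative_branch)
        return -x
    end
end

# "Trace" one call and return the recorded operations.
function trace_branchy(x)
    rec = Symbol[]
    branchy(x, rec)
    return rec
end

tape = trace_branchy(1)      # recorded with x = 1
println(tape)                # [:positive_branch]
# Replaying `tape` for x = -1 would be wrong, since that input takes the other branch:
println(trace_branchy(-1))   # [:negative_branch]
```

Re-tracing on every call avoids this, at the cost of paying the tracing overhead each time.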
