-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support for multiple nested `produce` calls
#163
Comments
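One candidate to address this would be to replace [the rest of this sentence was lost in extraction; the example below traces the model with Umlaut.jl]. For example:
julia> using DynamicPPL, Distributions, Umlaut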
julia> struct DynamicPPLModelCtx{F}
model::Model{F}
end
julia> function isprimitive(ctx::DynamicPPLModelCtx, f, args...)
f === ctx.model.f && return false
if Base.parentmodule(f) == DynamicPPL
# Trace into `DynamicPPL._evaluate!!`.
f === DynamicPPL._evaluate!! && return false
end
return true
end
isprimitive (generic function with 5 methods)
julia> @model function demo(x)
z ~ Normal()
x ~ Normal(z, 1)
end
demo (generic function with 4 methods)
julia> model = demo(1);
julia> ctx = SamplingContext();
julia> varinfo = VarInfo(model);
julia> t = last(Umlaut.trace(DynamicPPL._evaluate!!, model, varinfo, ctx; ctx=DynamicPPLModelCtx(model)))
Tape{DynamicPPLModelCtx{typeof(demo)}}
inp %1::typeof(DynamicPPL._evaluate!!)
inp %2::Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext}
inp %3::TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}
inp %4::SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}
%5 = make_evaluate_args_and_kwargs(%2, %3, %4)::Tuple{Tuple{Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext}, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}, Int64}, NamedTuple{(), Tuple{}}}
%6 = indexed_iterate(%5, 1)::Tuple{Tuple{Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext}, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}, Int64}, Int64}
%7 = getfield(%6, 1)::Tuple{Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext}, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}, Int64}
%8 = getfield(%6, 2)::Int64
%9 = indexed_iterate(%5, 2, %8)::Tuple{NamedTuple{(), Tuple{}}, Int64}
%10 = getfield(%9, 1)::NamedTuple{(), Tuple{}}
%11 = NamedTuple()::NamedTuple{(), Tuple{}}
%12 = merge(%11, %10)::NamedTuple{(), Tuple{}}
%13 = isempty(%12)::Bool
%14 = getproperty(%2, :f)::typeof(demo)
%15 = check_variable_length(%7, 4, 7)::Nothing
%16 = getindex(%7, 1)::Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext}
%17 = getindex(%7, 2)::TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}
%18 = getindex(%7, 3)::SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}
%19 = getindex(%7, 4)::Int64
const %20 = nothing::Nothing
const %21 = nothing::Nothing
const %22 = nothing::Nothing
const %23 = nothing::Nothing
const %24 = nothing::Nothing
const %25 = nothing::Nothing
const %26 = nothing::Nothing
const %27 = nothing::Nothing
const %28 = nothing::Nothing
const %29 = nothing::Nothing
const %30 = nothing::Nothing
const %31 = nothing::Nothing
const %32 = nothing::Nothing
const %33 = nothing::Nothing
%34 = Normal()::Normal{Float64}
%35 = apply_type(VarName, :z)::UnionAll
%36 = %35()::VarName{:z, Setfield.IdentityLens}
%37 = resolve_varnames(%36, %34)::VarName{:z, Setfield.IdentityLens}
%38 = contextual_isassumption(%18, %37)::Bool
%39 = inargnames(%37, %16)::Bool
%40 = !(%39)::Bool
const %41 = true::Bool
const %42 = nothing::Nothing
%43 = tuple(%18)::Tuple{SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}}
%44 = check_tilde_rhs(%34)::Normal{Float64}
%45 = unwrap_right_vn(%44, %37)::Tuple{Normal{Float64}, VarName{:z, Setfield.IdentityLens}}
%46 = tuple(%17)::Tuple{TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}}
%47 = check_variable_length(%43, 1, 43)::Nothing
%48 = check_variable_length(%45, 2, 45)::Nothing
%49 = getindex(%45, 1)::Normal{Float64}
%50 = getindex(%45, 2)::VarName{:z, Setfield.IdentityLens}
%51 = check_variable_length(%46, 1, 46)::Nothing
%52 = tilde_assume!!(%18, %49, %50, %17)::Tuple{Float64, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}}
%53 = indexed_iterate(%52, 1)::Tuple{Float64, Int64}
%54 = getfield(%53, 1)::Float64
%55 = getfield(%53, 2)::Int64
%56 = indexed_iterate(%52, 2, %55)::Tuple{TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, Int64}
%57 = getfield(%56, 1)::TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}
%58 = Normal(%54, 1)::Normal{Float64}
%59 = apply_type(VarName, :x)::UnionAll
%60 = %59()::VarName{:x, Setfield.IdentityLens}
%61 = resolve_varnames(%60, %58)::VarName{:x, Setfield.IdentityLens}
%62 = contextual_isassumption(%18, %61)::Bool
%63 = inargnames(%61, %16)::Bool
%64 = !(%63)::Bool
%65 = inmissings(%61, %16)::Bool
%66 = ===(%19, missing)::Bool
%67 = inargnames(%61, %16)::Bool
%68 = !(%67)::Bool
%69 = check_tilde_rhs(%58)::Normal{Float64}
%70 = tilde_observe!!(%18, %69, %19, %61, %57)::Tuple{Int64, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}}
%71 = indexed_iterate(%70, 1)::Tuple{Int64, Int64}
%72 = getfield(%71, 1)::Int64
%73 = getfield(%71, 2)::Int64
%74 = indexed_iterate(%70, 2, %73)::Tuple{TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, Int64}
%75 = getfield(%74, 1)::TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}
%76 = tuple(%72, %75)::Tuple{Int64, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}}
This should even be possible to use with Libtask.jl, with some minor changes, simply by compiling the tape. Note that, AFAIK, the main drawback of making this change is that we drop support for control flow (unless we want to perform [the remainder of this comment was truncated in this copy]).
As discussed extensively in TuringLang/Turing.jl#2001, in particular in TuringLang/Turing.jl#2001 (comment), Libtask.jl makes one crucial assumption: every `Instruction` contains at most one `produce` statement. This is because of `Libtask.jl/src/tapedfunction.jl`, lines 73 to 74 (at commit 95e32aa), where [the embedded code snippet was not preserved] is obtained from `Libtask.jl/src/tapedfunction.jl`, lines 44 to 48 (at commit 95e32aa) [snippet not preserved], which is then traversed to construct the tape.
There are many cases in Turing.jl in which this assumption does not hold, e.g. when we use `@submodel`. Moreover, it is very unclear to me how this can be addressed without doing something quite fancy that lets us recurse into the type inference that is performed.
EDIT: Here's an example of what I mean: [the example was not preserved in this copy]
The text was updated successfully, but these errors were encountered: