-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement Var.plot_vars and Model.plot_vars #230
base: main
Are you sure you want to change the base?
Conversation
I think this is rather a feature, because when your var has an associated distribution the user can varify that it contributes to the model_log_prob / likelihood / prior node. |
Maybe the necessity to build an intermediate model can be removed by using |
I agree 😊 |
I thought a bit more on it, and I now think the behavior is not all that useful. The model nodes we see on the graph above are not model nodes from the "true" model, but from an auxiliary model that is created just to construct the graph for plotting. So connections to these nodes are not diagnostic of the presence or absence of problems in the "true" model that includes child nodes. I also think the potential for confusion is a real danger that would be nice to avoid. I added some commits with the following updates:
|
A more involved example: import liesel.model as lsl
import tensorflow_probability.substrates.jax.distributions as tfd
mu = lsl.Var.new_param(1.0, lsl.Dist(tfd.Normal, loc=0.0, scale=1.0), name="mu")
sigma = lsl.Var.new_param(
1.0,
lsl.Dist(
tfd.InverseGamma,
concentration=lsl.Var.new_value(2.0),
scale=lsl.Var.new_value(0.5),
),
name="sigma",
)
y = lsl.Var.new_obs(1.0, lsl.Dist(tfd.Normal, loc=mu, scale=sigma), name="y")
y.plot_vars() model = lsl.Model([y])
y.plot_vars() y.plot_nodes()
model.plot_nodes() |
All of the below applies to
plot_vars
as well asplot_nodes
.This PR adds methods to the
lsl.Var
and thelsl.Model
objects that wrap the referenced plotting functions. The method onlsl.Model
is inspired by the correspondinglsl.GraphBuilder.plot_vars
method, which has existed for a long time and I find quite convenient.The
lsl.Var.plot_vars
brings more novelty, because it is a way to very quickly plot intermediate sub-models. One tricky bit was to deal with the automatic naming: As soon as you initialize a model, Liesel will automatically assign names to unnamed nodes. Becauselsl.plot_vars
works similar tolsl.GraphBuilder.plot_vars
, this also happens here: A helper model is built, then the plot is created, then the helper model is discarded. The issue is: The previously unnamed nodes keep their new names. I think this is not intended behavior, so thelsl.Var.plot_vars
method resets the automatically assigned names after plotting. It also emits a logging message that notifies users of this behavior, because I think otherwise it could be confusing.lsl.Var.plot_vars
will not work for variables that are already part of a model, because in this case, no intermediate model can be created.In
lsl.Var.plot_nodes
, the three model nodes (log_prob, log_lik, log_prior) will also show up. I think this is ok. This plot is intended for advanced users anyway, and having the three log prob nodes in there allows users to spot potential issues with the likelihood and prior attribution in their subgraph.Example
Note: In this example, the three model nodes are unconnected, because there is no node with a probability in the minimal example model.