Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Var.plot_vars and Model.plot_vars #230

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open

Conversation

jobrachem
Copy link
Contributor

All of the below applies to plot_vars as well as plot_nodes.

This PR adds methods to the lsl.Var and the lsl.Model objects that wrap the referenced plotting functions. The method on lsl.Model is inspired by the corresponding lsl.GraphBuilder.plot_vars method, which has existed for a long time and I find quite convenient.

The lsl.Var.plot_vars brings more novelty, because it is a way to very quickly plot intermediate sub-models. One tricky bit was to deal with the automatic naming: As soon as you initialize a model, Liesel will automatically assign names to unnamed nodes. Because lsl.plot_vars works similar to lsl.GraphBuilder.plot_vars, this also happens here: A helper model is built, then the plot is created, then the helper model is discarded. The issue is: The previously unnamed nodes keep their new names. I think this is not intended behavior, so the lsl.Var.plot_vars method resets the automatically assigned names after plotting. It also emits a logging message that notifies users of this behavior, because I think otherwise it could be confusing.

lsl.Var.plot_vars will not work for variables that are already part of a model, because in this case, no intermediate model can be created.

In lsl.Var.plot_nodes, the three model nodes (log_prob, log_lik, log_prior) will also show up. I think this is ok. This plot is intended for advanced users anyway, and having the three log prob nodes in there allows users to spot potential issues with the likelihood and prior attribution in their subgraph.

Example

>>> import liesel.model as lsl
>>> a = lsl.Var.new_value(1.0)
>>> b = lsl.Var.new_calc(lambda a: a + 1.0, a=a, name="b")
>>> a.name
''
>>> b.plot_vars()
liesel.model.nodes - INFO - Unnamed variables were temporarily named for plotting. The automatically assigned names are: ['v0']. The names are reset after plotting.

image

>>> b.plot_nodes()
liesel.model.nodes - INFO - Unnamed variables were temporarily named for plotting. The automatically assigned names are: ['v0']. The names are reset after plotting.

image

Note: In this example, the three model nodes are unconnected, because there is no node with a probability in the minimal example model.

>>> a.name
''

@wiep
Copy link
Contributor

wiep commented Dec 10, 2024

In lsl.Var.plot_nodes, the three model nodes (log_prob, log_lik, log_prior) will also show up. I think this is ok. This plot is intended for advanced users anyway, and having the three log prob nodes in there allows users to spot potential issues with the likelihood and prior attribution in their subgraph.

I think this is rather a feature, because when your var has an associated distribution the user can varify that it contributes to the model_log_prob / likelihood / prior node.

@jobrachem
Copy link
Contributor Author

Maybe the necessity to build an intermediate model can be removed by using lsl.Model._build_graph directly. This would be very nice, because it would mean that you can also simply plot the subgraph of a variable that is part of a model.

@jobrachem
Copy link
Contributor Author

In lsl.Var.plot_nodes, the three model nodes (log_prob, log_lik, log_prior) will also show up. I think this is ok. This plot is intended for advanced users anyway, and having the three log prob nodes in there allows users to spot potential issues with the likelihood and prior attribution in their subgraph.

I think this is rather a feature, because when your var has an associated distribution the user can varify that it contributes to the model_log_prob / likelihood / prior node.

I agree 😊

@jobrachem
Copy link
Contributor Author

jobrachem commented Dec 20, 2024

In lsl.Var.plot_nodes, the three model nodes (log_prob, log_lik, log_prior) will also show up. I think this is ok. This plot is intended for advanced users anyway, and having the three log prob nodes in there allows users to spot potential issues with the likelihood and prior attribution in their subgraph.

I think this is rather a feature, because when your var has an associated distribution the user can varify that it contributes to the model_log_prob / likelihood / prior node.

I agree 😊

I thought a bit more on it, and I now think the behavior is not all that useful. The model nodes we see on the graph above are not model nodes from the "true" model, but from an auxiliary model that is created just to construct the graph for plotting. So connections to these nodes are not diagnostic of the presence or absence of problems in the "true" model that includes child nodes.

I also think the potential for confusion is a real danger that would be nice to avoid.

I added some commits with the following updates:

  1. Calling Var.plot_vars() and Var.plot_nodes() will now also work as expected when the Var is already part of a model. In the process, I added the methods Model.var_subgraph() and Model.node_subgraph.
  2. The model nodes are not part of the plot produced by Var.plot_nodes anymore.

@jobrachem
Copy link
Contributor Author

jobrachem commented Dec 20, 2024

A more involved example:

import liesel.model as lsl
import tensorflow_probability.substrates.jax.distributions as tfd

mu = lsl.Var.new_param(1.0, lsl.Dist(tfd.Normal, loc=0.0, scale=1.0), name="mu")
sigma = lsl.Var.new_param(
    1.0,
    lsl.Dist(
        tfd.InverseGamma,
        concentration=lsl.Var.new_value(2.0),
        scale=lsl.Var.new_value(0.5),
    ),
    name="sigma",
)
y = lsl.Var.new_obs(1.0, lsl.Dist(tfd.Normal, loc=mu, scale=sigma), name="y")

y.plot_vars()

grafik

model = lsl.Model([y])
y.plot_vars()

grafik

y.plot_nodes()

grafik

sigma.plot_vars()

grafik

sigma.plot_nodes()

grafik

model.plot_nodes()

grafik

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
2 participants