diff --git a/pyro/infer/inspect.py b/pyro/infer/inspect.py index dffb6fd6fd..9bb5ca2520 100644 --- a/pyro/infer/inspect.py +++ b/pyro/infer/inspect.py @@ -549,6 +549,12 @@ def render_model( :returns: A model graph. :rtype: graphviz.Digraph """ + assert model_args is None or isinstance( + model_args, tuple + ), "model_args must be None or tuple" + assert model_kwargs is None or isinstance( + model_kwargs, dict + ), "model_kwargs must be None or dict" relations = get_model_relations(model, model_args, model_kwargs) graph_spec = generate_graph_specification(relations, render_params=render_params) graph = render_graph(graph_spec, render_distributions=render_distributions)