diff --git a/pyro/infer/inspect.py b/pyro/infer/inspect.py index 3b2b03fafb..1c0cf9413b 100644 --- a/pyro/infer/inspect.py +++ b/pyro/infer/inspect.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import itertools +import os from collections import defaultdict from pathlib import Path from types import SimpleNamespace @@ -572,7 +573,7 @@ def render_model( list of tuples for semisupervised models. :param model_kwargs: Dict of keyword arguments to pass to the model, or list of dicts for semisupervised models. - :param str filename: File to save rendered model in. + :param str filename: Name of file or path to file to save rendered model in. :param bool render_distributions: Whether to include RV distribution annotations (and param constraints) in the plot. :param bool render_params: Whether to show params inthe plot. @@ -604,9 +605,9 @@ def render_model( graph = render_graph(graph_spec, render_distributions=render_distributions) if filename is not None: - filename = Path(filename) - suffix = filename.suffix[1:] # remove leading period from suffix - graph.render(filename.stem, view=False, cleanup=True, format=suffix) + suffix = Path(filename).suffix[1:] # remove leading period from suffix + filepath = os.path.splitext(filename)[0] + graph.render(filepath, view=False, cleanup=True, format=suffix) return graph