From b89519d29a5c8b889f6a351070d4bc5153e0f5af Mon Sep 17 00:00:00 2001 From: eb8680 Date: Fri, 14 Sep 2018 17:44:55 -0700 Subject: [PATCH] Make trace a subclass of networkx.DiGraph again (#1191) --- docs/requirements.txt | 2 +- pyro/poutine/trace_struct.py | 106 ++----------------- setup.py | 2 +- tests/infer/test_compute_downstream_costs.py | 2 +- 4 files changed, 9 insertions(+), 103 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index bd8afb8089..29bfb18139 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,7 +4,7 @@ numpy>=1.7 contextlib2 cloudpickle>=0.3.1 graphviz>=0.8 -networkx>=2.0.0 +networkx>=2.2rc1 observations>=0.1.4 opt_einsum>=2.2.0 tqdm>=4.25 diff --git a/pyro/poutine/trace_struct.py b/pyro/poutine/trace_struct.py index b446a1a8af..6672d0cfbc 100644 --- a/pyro/poutine/trace_struct.py +++ b/pyro/poutine/trace_struct.py @@ -9,20 +9,7 @@ from pyro.util import warn_if_nan, warn_if_inf -class DiGraph(networkx.DiGraph): - """ - Wrapper of :class:`networkx.DiGraph` that makes ``self.nodes`` a ``collections.OrderedDict``. - """ - node_dict_factory = collections.OrderedDict - - def fresh_copy(self): - """ - Returns a new ``DiGraph`` instance. - """ - return DiGraph() - - -class Trace(object): +class Trace(networkx.DiGraph): """ Execution trace data structure built on top of :class:`networkx.DiGraph`. @@ -76,6 +63,8 @@ class Trace(object): ``'done'``, ``'stop'``, and ``'continuation'`` are only used by Pyro's internals. """ + node_dict_factory = collections.OrderedDict + def __init__(self, *args, **kwargs): """ :param string graph_type: string specifying the kind of trace graph to construct @@ -83,94 +72,12 @@ def __init__(self, *args, **kwargs): Constructor. Currently identical to :meth:`networkx.DiGraph.__init__`, except for storing the graph_type attribute """ - self._graph = DiGraph(*args, **kwargs) graph_type = kwargs.pop("graph_type", "flat") assert graph_type in ("flat", "dense"), \ "{} not a valid graph type".format(graph_type) self.graph_type = graph_type super(Trace, self).__init__(*args, **kwargs) - def __del__(self): - """ - Works around cyclic reference bugs in :class:`networkx.DiGraph` - See ``https://github.com/uber/pyro/issues/798`` - """ - self._graph.__dict__.clear() - - @property - def nodes(self): - """ - Identical to :attr:`networkx.DiGraph.nodes` - """ - return self._graph.nodes - - @property - def edges(self): - """ - Identical to :attr:`networkx.DiGraph.edges` - """ - return self._graph.edges - - @property - def graph(self): - """ - Identical to :attr:`networkx.DiGraph.graph` - """ - return self._graph.graph - - @property - def remove_node(self): - """ - Identical to :meth:`networkx.DiGraph.remove_node` - """ - return self._graph.remove_node - - @property - def add_edge(self): - """ - Identical to :meth:`networkx.DiGraph.add_edge` - """ - return self._graph.add_edge - - @property - def is_directed(self): - """ - Identical to :attr:`networkx.DiGraph.is_directed` - """ - return self._graph.is_directed - - @property - def in_degree(self): - """ - Identical to :meth:`networkx.DiGraph.in_degree` - """ - return self._graph.in_degree - - @property - def successors(self): - """ - Identical to :meth:`networkx.DiGraph.successors` - """ - return self._graph.successors - - def __contains__(self, site_name): - """ - Identical to :meth:`networkx.DiGraph.__contains__` - """ - return site_name in self._graph - - def __iter__(self): - """ - Identical to :meth:`networkx.DiGraph.__iter__` - """ - return iter(self._graph) - - def __len__(self): - """ - Identical to :meth:`networkx.DiGraph.__len__` - """ - return len(self._graph) - def add_node(self, site_name, *args, **kwargs): """ :param string site_name: the name of the site to be added @@ -191,7 +98,7 @@ def add_node(self, site_name, *args, **kwargs): raise RuntimeError("Multiple {} sites named '{}'".format(kwargs['type'], site_name)) # XXX should copy in case site gets mutated, or dont bother? - self._graph.add_node(site_name, *args, **kwargs) + super(Trace, self).add_node(site_name, *args, **kwargs) def copy(self): """ @@ -199,9 +106,8 @@ def copy(self): Identical to :meth:`networkx.DiGraph.copy`, but preserves the type and the self.graph_type attribute """ - trace = Trace() - trace._graph = self._graph.copy() - trace._graph.__class__ = DiGraph + trace = super(Trace, self).copy() + trace.__class__ = Trace trace.graph_type = self.graph_type return trace diff --git a/setup.py b/setup.py index f53e8fd876..efb404960f 100644 --- a/setup.py +++ b/setup.py @@ -79,7 +79,7 @@ # add them to `docs/requirements.txt` 'contextlib2', 'graphviz>=0.8', - 'networkx>=2.0.0', + 'networkx>=2.2rc1', 'numpy>=1.7', 'opt_einsum>=2.2.0', 'six>=1.10.0', diff --git a/tests/infer/test_compute_downstream_costs.py b/tests/infer/test_compute_downstream_costs.py index 662ea99b93..f428c48859 100644 --- a/tests/infer/test_compute_downstream_costs.py +++ b/tests/infer/test_compute_downstream_costs.py @@ -28,7 +28,7 @@ def _brute_force_compute_downstream_costs(model_trace, guide_trace, # guide_trace.nodes[node]['log_prob'])) downstream_guide_cost_nodes[node] = set([node]) - descendants = networkx.descendants(guide_trace._graph, node) + descendants = networkx.descendants(guide_trace, node) for desc in descendants: desc_mft = MultiFrameTensor((stacks[desc],