From 733f1634ee60f7a043e21553c1b4345123a29e60 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 15 Jun 2018 13:50:33 -0700 Subject: [PATCH 1/5] make trace a subclass of networkx.DiGraph again --- pyro/poutine/trace_struct.py | 107 +++-------------------------------- 1 file changed, 7 insertions(+), 100 deletions(-) diff --git a/pyro/poutine/trace_struct.py b/pyro/poutine/trace_struct.py index 91aae08244..e2597fe206 100644 --- a/pyro/poutine/trace_struct.py +++ b/pyro/poutine/trace_struct.py @@ -20,20 +20,7 @@ def _warn_if_nan(name, value): # Note that -inf log_prob_sum is fine: it is merely a zero-probability event. -class DiGraph(networkx.DiGraph): - """ - Wrapper of :class:`networkx.DiGraph` that makes ``self.nodes`` a ``collections.OrderedDict``. - """ - node_dict_factory = collections.OrderedDict - - def fresh_copy(self): - """ - Returns a new ``DiGraph`` instance. - """ - return DiGraph() - - -class Trace(object): +class Trace(networkx.DiGraph): """ Execution trace data structure built on top of :class:`networkx.DiGraph`. @@ -87,6 +74,8 @@ class Trace(object): ``'done'``, ``'stop'``, and ``'continuation'`` are only used by Pyro's internals. """ + node_dict_factory = collections.OrderedDict + def __init__(self, *args, **kwargs): """ :param string graph_type: string specifying the kind of trace graph to construct @@ -94,94 +83,13 @@ def __init__(self, *args, **kwargs): Constructor. Currently identical to :meth:`networkx.DiGraph.__init__`, except for storing the graph_type attribute """ - self._graph = DiGraph(*args, **kwargs) + # self._graph = DiGraph(*args, **kwargs) graph_type = kwargs.pop("graph_type", "flat") assert graph_type in ("flat", "dense"), \ "{} not a valid graph type".format(graph_type) self.graph_type = graph_type super(Trace, self).__init__(*args, **kwargs) - def __del__(self): - """ - Works around cyclic reference bugs in :class:`networkx.DiGraph` - See ``https://github.com/uber/pyro/issues/798`` - """ - self._graph.__dict__.clear() - - @property - def nodes(self): - """ - Identical to :attr:`networkx.DiGraph.nodes` - """ - return self._graph.nodes - - @property - def edges(self): - """ - Identical to :attr:`networkx.DiGraph.edges` - """ - return self._graph.edges - - @property - def graph(self): - """ - Identical to :attr:`networkx.DiGraph.graph` - """ - return self._graph.graph - - @property - def remove_node(self): - """ - Identical to :meth:`networkx.DiGraph.remove_node` - """ - return self._graph.remove_node - - @property - def add_edge(self): - """ - Identical to :meth:`networkx.DiGraph.add_edge` - """ - return self._graph.add_edge - - @property - def is_directed(self): - """ - Identical to :attr:`networkx.DiGraph.is_directed` - """ - return self._graph.is_directed - - @property - def in_degree(self): - """ - Identical to :meth:`networkx.DiGraph.in_degree` - """ - return self._graph.in_degree - - @property - def successors(self): - """ - Identical to :meth:`networkx.DiGraph.successors` - """ - return self._graph.successors - - def __contains__(self, site_name): - """ - Identical to :meth:`networkx.DiGraph.__contains__` - """ - return site_name in self._graph - - def __iter__(self): - """ - Identical to :meth:`networkx.DiGraph.__iter__` - """ - return iter(self._graph) - - def __len__(self): - """ - Identical to :meth:`networkx.DiGraph.__len__` - """ - return len(self._graph) - def add_node(self, site_name, *args, **kwargs): """ :param string site_name: the name of the site to be added @@ -198,7 +106,7 @@ def add_node(self, site_name, *args, **kwargs): "site {} already in trace".format(site_name) # XXX should copy in case site gets mutated, or dont bother? - self._graph.add_node(site_name, *args, **kwargs) + super(Trace, self).add_node(site_name, *args, **kwargs) def copy(self): """ @@ -206,9 +114,8 @@ def copy(self): Identical to :meth:`networkx.DiGraph.copy`, but preserves the type and the self.graph_type attribute """ - trace = Trace() - trace._graph = self._graph.copy() - trace._graph.__class__ = DiGraph + trace = super(Trace, self).copy() + trace.__class__ = Trace trace.graph_type = self.graph_type return trace From 0fe52ca3b9103abd47fc2e94a798c263a7427421 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 19 Jun 2018 18:46:32 -0700 Subject: [PATCH 2/5] update networkx versions --- docs/requirements.txt | 2 +- pyro/poutine/trace_struct.py | 1 - setup.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 5faa392d10..4f5587d6b1 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,5 +4,5 @@ numpy>=1.7 contextlib2 cloudpickle>=0.3.1 graphviz>=0.8 -networkx>=2.0.0 +networkx>=2.2 observations>=0.1.4 diff --git a/pyro/poutine/trace_struct.py b/pyro/poutine/trace_struct.py index e2597fe206..e594ab5293 100644 --- a/pyro/poutine/trace_struct.py +++ b/pyro/poutine/trace_struct.py @@ -83,7 +83,6 @@ def __init__(self, *args, **kwargs): Constructor. Currently identical to :meth:`networkx.DiGraph.__init__`, except for storing the graph_type attribute """ - # self._graph = DiGraph(*args, **kwargs) graph_type = kwargs.pop("graph_type", "flat") assert graph_type in ("flat", "dense"), \ "{} not a valid graph type".format(graph_type) diff --git a/setup.py b/setup.py index 94b7e433c4..b86dc80542 100644 --- a/setup.py +++ b/setup.py @@ -76,7 +76,7 @@ # add them to `docs/requirements.txt` 'contextlib2', 'graphviz>=0.8', - 'networkx>=2.0.0', + 'networkx>=2.2', 'numpy>=1.7', 'six>=1.10.0', 'torch>=0.4.0', From d2c539eed1f8c947c1899c791137e97adeee42ae Mon Sep 17 00:00:00 2001 From: eb8680 Date: Fri, 14 Sep 2018 11:28:14 -0700 Subject: [PATCH 3/5] 2.2 to 2.2rc1 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index b86dc80542..d95eb25f25 100644 --- a/setup.py +++ b/setup.py @@ -76,7 +76,7 @@ # add them to `docs/requirements.txt` 'contextlib2', 'graphviz>=0.8', - 'networkx>=2.2', + 'networkx>=2.2rc1', 'numpy>=1.7', 'six>=1.10.0', 'torch>=0.4.0', From 2b825b04897b9047df1ffd2a2b977cc8f6dd8169 Mon Sep 17 00:00:00 2001 From: eb8680 Date: Fri, 14 Sep 2018 11:28:46 -0700 Subject: [PATCH 4/5] 2.2 to 2.2rc1 --- docs/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 4f5587d6b1..42626dcfcb 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,5 +4,5 @@ numpy>=1.7 contextlib2 cloudpickle>=0.3.1 graphviz>=0.8 -networkx>=2.2 +networkx>=2.2rc1 observations>=0.1.4 From d7ec27a7ee0c5e380b21215f2a611331f73fa393 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 14 Sep 2018 13:42:59 -0700 Subject: [PATCH 5/5] remove ._graph --- tests/infer/test_compute_downstream_costs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/infer/test_compute_downstream_costs.py b/tests/infer/test_compute_downstream_costs.py index 662ea99b93..f428c48859 100644 --- a/tests/infer/test_compute_downstream_costs.py +++ b/tests/infer/test_compute_downstream_costs.py @@ -28,7 +28,7 @@ def _brute_force_compute_downstream_costs(model_trace, guide_trace, # guide_trace.nodes[node]['log_prob'])) downstream_guide_cost_nodes[node] = set([node]) - descendants = networkx.descendants(guide_trace._graph, node) + descendants = networkx.descendants(guide_trace, node) for desc in descendants: desc_mft = MultiFrameTensor((stacks[desc],