Skip to content

Commit

Permalink
Make trace a subclass of networkx.DiGraph again (#1191)
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 authored and fritzo committed Sep 15, 2018
1 parent 6b1dd37 commit b89519d
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 103 deletions.
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ numpy>=1.7
contextlib2
cloudpickle>=0.3.1
graphviz>=0.8
networkx>=2.0.0
networkx>=2.2rc1
observations>=0.1.4
opt_einsum>=2.2.0
tqdm>=4.25
106 changes: 6 additions & 100 deletions pyro/poutine/trace_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,7 @@
from pyro.util import warn_if_nan, warn_if_inf


class DiGraph(networkx.DiGraph):
"""
Wrapper of :class:`networkx.DiGraph` that makes ``self.nodes`` a ``collections.OrderedDict``.
"""
node_dict_factory = collections.OrderedDict

def fresh_copy(self):
"""
Returns a new ``DiGraph`` instance.
"""
return DiGraph()


class Trace(object):
class Trace(networkx.DiGraph):
"""
Execution trace data structure built on top of :class:`networkx.DiGraph`.
Expand Down Expand Up @@ -76,101 +63,21 @@ class Trace(object):
``'done'``, ``'stop'``, and ``'continuation'`` are only used by Pyro's internals.
"""

node_dict_factory = collections.OrderedDict

def __init__(self, *args, **kwargs):
"""
:param string graph_type: string specifying the kind of trace graph to construct
Constructor. Currently identical to :meth:`networkx.DiGraph.__init__`,
except for storing the graph_type attribute
"""
self._graph = DiGraph(*args, **kwargs)
graph_type = kwargs.pop("graph_type", "flat")
assert graph_type in ("flat", "dense"), \
"{} not a valid graph type".format(graph_type)
self.graph_type = graph_type
super(Trace, self).__init__(*args, **kwargs)

def __del__(self):
"""
Works around cyclic reference bugs in :class:`networkx.DiGraph`
See ``https://github.com/uber/pyro/issues/798``
"""
self._graph.__dict__.clear()

@property
def nodes(self):
"""
Identical to :attr:`networkx.DiGraph.nodes`
"""
return self._graph.nodes

@property
def edges(self):
"""
Identical to :attr:`networkx.DiGraph.edges`
"""
return self._graph.edges

@property
def graph(self):
"""
Identical to :attr:`networkx.DiGraph.graph`
"""
return self._graph.graph

@property
def remove_node(self):
"""
Identical to :meth:`networkx.DiGraph.remove_node`
"""
return self._graph.remove_node

@property
def add_edge(self):
"""
Identical to :meth:`networkx.DiGraph.add_edge`
"""
return self._graph.add_edge

@property
def is_directed(self):
"""
Identical to :attr:`networkx.DiGraph.is_directed`
"""
return self._graph.is_directed

@property
def in_degree(self):
"""
Identical to :meth:`networkx.DiGraph.in_degree`
"""
return self._graph.in_degree

@property
def successors(self):
"""
Identical to :meth:`networkx.DiGraph.successors`
"""
return self._graph.successors

def __contains__(self, site_name):
"""
Identical to :meth:`networkx.DiGraph.__contains__`
"""
return site_name in self._graph

def __iter__(self):
"""
Identical to :meth:`networkx.DiGraph.__iter__`
"""
return iter(self._graph)

def __len__(self):
"""
Identical to :meth:`networkx.DiGraph.__len__`
"""
return len(self._graph)

def add_node(self, site_name, *args, **kwargs):
"""
:param string site_name: the name of the site to be added
Expand All @@ -191,17 +98,16 @@ def add_node(self, site_name, *args, **kwargs):
raise RuntimeError("Multiple {} sites named '{}'".format(kwargs['type'], site_name))

# XXX should copy in case site gets mutated, or dont bother?
self._graph.add_node(site_name, *args, **kwargs)
super(Trace, self).add_node(site_name, *args, **kwargs)

def copy(self):
"""
Makes a shallow copy of self with nodes and edges preserved.
Identical to :meth:`networkx.DiGraph.copy`, but preserves the type
and the self.graph_type attribute
"""
trace = Trace()
trace._graph = self._graph.copy()
trace._graph.__class__ = DiGraph
trace = super(Trace, self).copy()
trace.__class__ = Trace
trace.graph_type = self.graph_type
return trace

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
# add them to `docs/requirements.txt`
'contextlib2',
'graphviz>=0.8',
'networkx>=2.0.0',
'networkx>=2.2rc1',
'numpy>=1.7',
'opt_einsum>=2.2.0',
'six>=1.10.0',
Expand Down
2 changes: 1 addition & 1 deletion tests/infer/test_compute_downstream_costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _brute_force_compute_downstream_costs(model_trace, guide_trace, #
guide_trace.nodes[node]['log_prob']))
downstream_guide_cost_nodes[node] = set([node])

descendants = networkx.descendants(guide_trace._graph, node)
descendants = networkx.descendants(guide_trace, node)

for desc in descendants:
desc_mft = MultiFrameTensor((stacks[desc],
Expand Down

0 comments on commit b89519d

Please sign in to comment.