Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make trace a subclass of networkx.DiGraph again #1191

Merged
merged 5 commits into from
Sep 15, 2018
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 7 additions & 100 deletions pyro/poutine/trace_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,7 @@ def _warn_if_nan(name, value):
# Note that -inf log_prob_sum is fine: it is merely a zero-probability event.


class DiGraph(networkx.DiGraph):
"""
Wrapper of :class:`networkx.DiGraph` that makes ``self.nodes`` a ``collections.OrderedDict``.
"""
node_dict_factory = collections.OrderedDict

def fresh_copy(self):
"""
Returns a new ``DiGraph`` instance.
"""
return DiGraph()


class Trace(object):
class Trace(networkx.DiGraph):
"""
Execution trace data structure built on top of :class:`networkx.DiGraph`.

Expand Down Expand Up @@ -87,101 +74,22 @@ class Trace(object):
``'done'``, ``'stop'``, and ``'continuation'`` are only used by Pyro's internals.
"""

node_dict_factory = collections.OrderedDict

def __init__(self, *args, **kwargs):
"""
:param string graph_type: string specifying the kind of trace graph to construct

Constructor. Currently identical to :meth:`networkx.DiGraph.__init__`,
except for storing the graph_type attribute
"""
self._graph = DiGraph(*args, **kwargs)
# self._graph = DiGraph(*args, **kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove before merging

graph_type = kwargs.pop("graph_type", "flat")
assert graph_type in ("flat", "dense"), \
"{} not a valid graph type".format(graph_type)
self.graph_type = graph_type
super(Trace, self).__init__(*args, **kwargs)

def __del__(self):
"""
Works around cyclic reference bugs in :class:`networkx.DiGraph`
See ``https://github.com/uber/pyro/issues/798``
"""
self._graph.__dict__.clear()

@property
def nodes(self):
"""
Identical to :attr:`networkx.DiGraph.nodes`
"""
return self._graph.nodes

@property
def edges(self):
"""
Identical to :attr:`networkx.DiGraph.edges`
"""
return self._graph.edges

@property
def graph(self):
"""
Identical to :attr:`networkx.DiGraph.graph`
"""
return self._graph.graph

@property
def remove_node(self):
"""
Identical to :meth:`networkx.DiGraph.remove_node`
"""
return self._graph.remove_node

@property
def add_edge(self):
"""
Identical to :meth:`networkx.DiGraph.add_edge`
"""
return self._graph.add_edge

@property
def is_directed(self):
"""
Identical to :attr:`networkx.DiGraph.is_directed`
"""
return self._graph.is_directed

@property
def in_degree(self):
"""
Identical to :meth:`networkx.DiGraph.in_degree`
"""
return self._graph.in_degree

@property
def successors(self):
"""
Identical to :meth:`networkx.DiGraph.successors`
"""
return self._graph.successors

def __contains__(self, site_name):
"""
Identical to :meth:`networkx.DiGraph.__contains__`
"""
return site_name in self._graph

def __iter__(self):
"""
Identical to :meth:`networkx.DiGraph.__iter__`
"""
return iter(self._graph)

def __len__(self):
"""
Identical to :meth:`networkx.DiGraph.__len__`
"""
return len(self._graph)

def add_node(self, site_name, *args, **kwargs):
"""
:param string site_name: the name of the site to be added
Expand All @@ -198,17 +106,16 @@ def add_node(self, site_name, *args, **kwargs):
"site {} already in trace".format(site_name)

# XXX should copy in case site gets mutated, or dont bother?
self._graph.add_node(site_name, *args, **kwargs)
super(Trace, self).add_node(site_name, *args, **kwargs)

def copy(self):
"""
Makes a shallow copy of self with nodes and edges preserved.
Identical to :meth:`networkx.DiGraph.copy`, but preserves the type
and the self.graph_type attribute
"""
trace = Trace()
trace._graph = self._graph.copy()
trace._graph.__class__ = DiGraph
trace = super(Trace, self).copy()
trace.__class__ = Trace
trace.graph_type = self.graph_type
return trace

Expand Down