Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make trace a subclass of networkx.DiGraph again #1191

Merged
merged 5 commits into from
Sep 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ numpy>=1.7
contextlib2
cloudpickle>=0.3.1
graphviz>=0.8
networkx>=2.0.0
networkx>=2.2rc1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, let's just remember to update this to >=2.2 before the next Pyro release.

observations>=0.1.4
106 changes: 6 additions & 100 deletions pyro/poutine/trace_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,7 @@ def _warn_if_nan(name, value):
# Note that -inf log_prob_sum is fine: it is merely a zero-probability event.


class DiGraph(networkx.DiGraph):
"""
Wrapper of :class:`networkx.DiGraph` that makes ``self.nodes`` a ``collections.OrderedDict``.
"""
node_dict_factory = collections.OrderedDict

def fresh_copy(self):
"""
Returns a new ``DiGraph`` instance.
"""
return DiGraph()


class Trace(object):
class Trace(networkx.DiGraph):
"""
Execution trace data structure built on top of :class:`networkx.DiGraph`.

Expand Down Expand Up @@ -87,101 +74,21 @@ class Trace(object):
``'done'``, ``'stop'``, and ``'continuation'`` are only used by Pyro's internals.
"""

node_dict_factory = collections.OrderedDict

def __init__(self, *args, **kwargs):
"""
:param string graph_type: string specifying the kind of trace graph to construct

Constructor. Currently identical to :meth:`networkx.DiGraph.__init__`,
except for storing the graph_type attribute
"""
self._graph = DiGraph(*args, **kwargs)
graph_type = kwargs.pop("graph_type", "flat")
assert graph_type in ("flat", "dense"), \
"{} not a valid graph type".format(graph_type)
self.graph_type = graph_type
super(Trace, self).__init__(*args, **kwargs)

def __del__(self):
"""
Works around cyclic reference bugs in :class:`networkx.DiGraph`
See ``https://github.com/uber/pyro/issues/798``
"""
self._graph.__dict__.clear()

@property
def nodes(self):
"""
Identical to :attr:`networkx.DiGraph.nodes`
"""
return self._graph.nodes

@property
def edges(self):
"""
Identical to :attr:`networkx.DiGraph.edges`
"""
return self._graph.edges

@property
def graph(self):
"""
Identical to :attr:`networkx.DiGraph.graph`
"""
return self._graph.graph

@property
def remove_node(self):
"""
Identical to :meth:`networkx.DiGraph.remove_node`
"""
return self._graph.remove_node

@property
def add_edge(self):
"""
Identical to :meth:`networkx.DiGraph.add_edge`
"""
return self._graph.add_edge

@property
def is_directed(self):
"""
Identical to :attr:`networkx.DiGraph.is_directed`
"""
return self._graph.is_directed

@property
def in_degree(self):
"""
Identical to :meth:`networkx.DiGraph.in_degree`
"""
return self._graph.in_degree

@property
def successors(self):
"""
Identical to :meth:`networkx.DiGraph.successors`
"""
return self._graph.successors

def __contains__(self, site_name):
"""
Identical to :meth:`networkx.DiGraph.__contains__`
"""
return site_name in self._graph

def __iter__(self):
"""
Identical to :meth:`networkx.DiGraph.__iter__`
"""
return iter(self._graph)

def __len__(self):
"""
Identical to :meth:`networkx.DiGraph.__len__`
"""
return len(self._graph)

def add_node(self, site_name, *args, **kwargs):
"""
:param string site_name: the name of the site to be added
Expand All @@ -198,17 +105,16 @@ def add_node(self, site_name, *args, **kwargs):
"site {} already in trace".format(site_name)

# XXX should copy in case site gets mutated, or dont bother?
self._graph.add_node(site_name, *args, **kwargs)
super(Trace, self).add_node(site_name, *args, **kwargs)

def copy(self):
"""
Makes a shallow copy of self with nodes and edges preserved.
Identical to :meth:`networkx.DiGraph.copy`, but preserves the type
and the self.graph_type attribute
"""
trace = Trace()
trace._graph = self._graph.copy()
trace._graph.__class__ = DiGraph
trace = super(Trace, self).copy()
trace.__class__ = Trace
trace.graph_type = self.graph_type
return trace

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
# add them to `docs/requirements.txt`
'contextlib2',
'graphviz>=0.8',
'networkx>=2.0.0',
'networkx>=2.2rc1',
'numpy>=1.7',
'six>=1.10.0',
'torch>=0.4.0',
Expand Down
2 changes: 1 addition & 1 deletion tests/infer/test_compute_downstream_costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _brute_force_compute_downstream_costs(model_trace, guide_trace, #
guide_trace.nodes[node]['log_prob']))
downstream_guide_cost_nodes[node] = set([node])

descendants = networkx.descendants(guide_trace._graph, node)
descendants = networkx.descendants(guide_trace, node)

for desc in descendants:
desc_mft = MultiFrameTensor((stacks[desc],
Expand Down