Skip to content

Commit

Permalink
ENH(plot): return SVG rendered in JUPYTER, ...
Browse files Browse the repository at this point in the history
+ doc: rename in sample code: netop --> pipeline.
+ enh(build): add `ipython` in test dependencies.
+ include it in the plot TC.
  • Loading branch information
ankostis committed Oct 5, 2019
1 parent 7d389c3 commit 35d9a91
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 17 deletions.
10 changes: 7 additions & 3 deletions graphkit/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,19 @@ def set_execution_method(self, method):
assert method in options
self._execution_method = method

def plot(self, filename=None, show=False,
def plot(self, filename=None, show=False, jupyter=None,
inputs=None, outputs=None, solution=None):
"""
:param str filename:
Write diagram into a file.
Common extensions are ``.png .dot .jpg .jpeg .pdf .svg``
call :func:`network.supported_plot_formats()` for more.
:param boolean show:
:param show:
If it evaluates to true, opens the diagram in a matplotlib window.
If it equals `-1`, it plots but does not open the Window.
:param jupyter:
If it evaluates to true, return an SVG suitable to render
in *jupyter notebook cells* (`ipython` must be installed).
:param inputs:
an optional name list, any nodes in there are plotted
as a "house"
Expand All @@ -195,7 +199,7 @@ def plot(self, filename=None, show=False,
See :func:`network.plot_graph()` for the plot legend and example code.
"""
return self.net.plot(filename, show, inputs, outputs, solution)
return self.net.plot(filename, show, jupyter, inputs, outputs, solution)

def __getstate__(self):
state = Operation.__getstate__(self)
Expand Down
44 changes: 31 additions & 13 deletions graphkit/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def _compute_sequential_method(self, named_inputs, outputs):
return {k: cache[k] for k in iter(cache) if k in outputs}


def plot(self, filename=None, show=False,
def plot(self, filename=None, show=False, jupyter=None,
inputs=None, outputs=None, solution=None):
"""
Plot a *Graphviz* graph and return it, if no other argument provided.
Expand All @@ -385,8 +385,12 @@ def plot(self, filename=None, show=False,
Write diagram into a file.
Common extensions are ``.png .dot .jpg .jpeg .pdf .svg``
call :func:`network.supported_plot_formats()` for more.
:param boolean show:
:param show:
If it evaluates to true, opens the diagram in a matplotlib window.
If it equals `-1``, it plots but does not open the Window.
:param jupyter:
If it evaluates to true, return an SVG suitable to render
in *jupyter notebook cells* (`ipython` must be installed).
:param inputs:
an optional name list, any nodes in there are plotted
as a "house"
Expand All @@ -402,8 +406,8 @@ def plot(self, filename=None, show=False,
See :func:`network.plot_graph` for the plot legend and example code.
"""
return plot_graph(self.graph, filename, show, self.steps,
inputs, outputs, solution)
return plot_graph(self.graph, filename, show, jupyter,
self.steps, inputs, outputs, solution)


def ready_to_schedule_operation(op, has_executed, graph):
Expand Down Expand Up @@ -461,8 +465,8 @@ def supported_plot_formats():
return [".%s" % f for f in pydot.Dot().formats]


def plot_graph(graph, filename=None, show=False, steps=None,
inputs=None, outputs=None, solution=None):
def plot_graph(graph, filename=None, show=False, jupyter=False,
steps=None, inputs=None, outputs=None, solution=None):
"""
Plot a *Graphviz* graph/steps and return it, if no other argument provided.
Expand Down Expand Up @@ -494,9 +498,12 @@ def plot_graph(graph, filename=None, show=False, steps=None,
Write diagram into a file.
Common extensions are ``.png .dot .jpg .jpeg .pdf .svg``
call :func:`network.supported_plot_formats()` for more.
:param boolean show:
:param show:
If it evaluates to true, opens the diagram in a matplotlib window.
If it equals ``-1``, it plots but does not open the Window.
If it equals `-1``, it plots but does not open the Window.
:param jupyter:
If it evaluates to true, return an SVG suitable to render
in *jupyter notebook cells* (`ipython` must be installed).
:param steps:
a list of nodes & instructions to overlay on the diagram
:param inputs:
Expand All @@ -514,15 +521,18 @@ def plot_graph(graph, filename=None, show=False, steps=None,
**Example:**
>>> netop = compose(name="netop")(
>>> from graphkit import compose, operation
>>> from graphkit.modifiers import optional
>>> pipeline = compose(name="pipeline")(
... operation(name="add", needs=["a", "b1"], provides=["ab1"])(add),
... operation(name="sub", needs=["a", optional("b2")], provides=["ab2"])(lambda a, b=1: a-b),
... operation(name="abb", needs=["ab1", "ab2"], provides=["asked"])(add),
... )
>>> inputs = {'a': 1, 'b1': 2}
>>> solution=netop(inputs)
>>> netop.plot('plot.svg', inputs=inputs, solution=solution, outputs=['asked', 'b1']);
>>> solution=pipeline(inputs)
>>> pipeline.plot('plot.svg', inputs=inputs, solution=solution, outputs=['asked', 'b1']);
"""
import pydot
Expand Down Expand Up @@ -596,7 +606,8 @@ def get_node_name(a):
penwidth=3, arrowhead="vee")
g.add_edge(edge)

# save plot
# Save plot
#
if filename:
formats = supported_plot_formats()
_basename, ext = os.path.splitext(filename)
Expand All @@ -608,7 +619,14 @@ def get_node_name(a):

g.write(filename, format=ext.lower()[1:])

# display graph via matplotlib
## Return an SVG renderable in jupyter.
#
if jupyter:
from IPython.display import SVG
g = SVG(data=g.create_svg())

## Display graph via matplotlib
#
if show:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
},
tests_require=[
"numpy",
"ipython", # to test jupyter plot.
"pydot", # to test plot
"matplotlib" # to test plot
],
Expand Down
6 changes: 5 additions & 1 deletion test/test_graphkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def test_plotting():
finally:
shutil.rmtree(tdir, ignore_errors=True)

## Don't open matplotlib window.
## Try matplotlib Window, but
# without opening a Window.
#
if sys.version_info < (3, 5):
# On PY< 3.5 it fails with:
Expand All @@ -150,6 +151,9 @@ def test_plotting():
# do not open window in headless travis
assert pipeline.plot(show=-1)

## Try jupyter SVG.
assert "display.SVG" in str(type(pipeline.plot(jupyter=True)))

try:
pipeline.plot('bad.format')
assert False, "Should had failed writting arbitrary file format!"
Expand Down

0 comments on commit 35d9a91

Please sign in to comment.