From 35d9a919df7d374f75d0f25ca424194840b2351b Mon Sep 17 00:00:00 2001 From: Kostis Anagnostopoulos Date: Sat, 5 Oct 2019 19:19:05 +0300 Subject: [PATCH] ENH(plot): return SVG rendered in JUPYTER, ... + doc: rename in sample code: netop --> pipeline. + enh(build): add `ipython` in test dependencies. + include it in the plot TC. --- graphkit/base.py | 10 +++++++--- graphkit/network.py | 44 ++++++++++++++++++++++++++++++------------- setup.py | 1 + test/test_graphkit.py | 6 +++++- 4 files changed, 44 insertions(+), 17 deletions(-) diff --git a/graphkit/base.py b/graphkit/base.py index 5f425028..36ccca3c 100644 --- a/graphkit/base.py +++ b/graphkit/base.py @@ -171,15 +171,19 @@ def set_execution_method(self, method): assert method in options self._execution_method = method - def plot(self, filename=None, show=False, + def plot(self, filename=None, show=False, jupyter=None, inputs=None, outputs=None, solution=None): """ :param str filename: Write diagram into a file. Common extensions are ``.png .dot .jpg .jpeg .pdf .svg`` call :func:`network.supported_plot_formats()` for more. - :param boolean show: + :param show: If it evaluates to true, opens the diagram in a matplotlib window. + If it equals `-1`, it plots but does not open the Window. + :param jupyter: + If it evaluates to true, return an SVG suitable to render + in *jupyter notebook cells* (`ipython` must be installed). :param inputs: an optional name list, any nodes in there are plotted as a "house" @@ -195,7 +199,7 @@ def plot(self, filename=None, show=False, See :func:`network.plot_graph()` for the plot legend and example code. """ - return self.net.plot(filename, show, inputs, outputs, solution) + return self.net.plot(filename, show, jupyter, inputs, outputs, solution) def __getstate__(self): state = Operation.__getstate__(self) diff --git a/graphkit/network.py b/graphkit/network.py index 140d0b2e..6dc4d48a 100644 --- a/graphkit/network.py +++ b/graphkit/network.py @@ -376,7 +376,7 @@ def _compute_sequential_method(self, named_inputs, outputs): return {k: cache[k] for k in iter(cache) if k in outputs} - def plot(self, filename=None, show=False, + def plot(self, filename=None, show=False, jupyter=None, inputs=None, outputs=None, solution=None): """ Plot a *Graphviz* graph and return it, if no other argument provided. @@ -385,8 +385,12 @@ def plot(self, filename=None, show=False, Write diagram into a file. Common extensions are ``.png .dot .jpg .jpeg .pdf .svg`` call :func:`network.supported_plot_formats()` for more. - :param boolean show: + :param show: If it evaluates to true, opens the diagram in a matplotlib window. + If it equals `-1``, it plots but does not open the Window. + :param jupyter: + If it evaluates to true, return an SVG suitable to render + in *jupyter notebook cells* (`ipython` must be installed). :param inputs: an optional name list, any nodes in there are plotted as a "house" @@ -402,8 +406,8 @@ def plot(self, filename=None, show=False, See :func:`network.plot_graph` for the plot legend and example code. """ - return plot_graph(self.graph, filename, show, self.steps, - inputs, outputs, solution) + return plot_graph(self.graph, filename, show, jupyter, + self.steps, inputs, outputs, solution) def ready_to_schedule_operation(op, has_executed, graph): @@ -461,8 +465,8 @@ def supported_plot_formats(): return [".%s" % f for f in pydot.Dot().formats] -def plot_graph(graph, filename=None, show=False, steps=None, - inputs=None, outputs=None, solution=None): +def plot_graph(graph, filename=None, show=False, jupyter=False, + steps=None, inputs=None, outputs=None, solution=None): """ Plot a *Graphviz* graph/steps and return it, if no other argument provided. @@ -494,9 +498,12 @@ def plot_graph(graph, filename=None, show=False, steps=None, Write diagram into a file. Common extensions are ``.png .dot .jpg .jpeg .pdf .svg`` call :func:`network.supported_plot_formats()` for more. - :param boolean show: + :param show: If it evaluates to true, opens the diagram in a matplotlib window. - If it equals ``-1``, it plots but does not open the Window. + If it equals `-1``, it plots but does not open the Window. + :param jupyter: + If it evaluates to true, return an SVG suitable to render + in *jupyter notebook cells* (`ipython` must be installed). :param steps: a list of nodes & instructions to overlay on the diagram :param inputs: @@ -514,15 +521,18 @@ def plot_graph(graph, filename=None, show=False, steps=None, **Example:** - >>> netop = compose(name="netop")( + >>> from graphkit import compose, operation + >>> from graphkit.modifiers import optional + + >>> pipeline = compose(name="pipeline")( ... operation(name="add", needs=["a", "b1"], provides=["ab1"])(add), ... operation(name="sub", needs=["a", optional("b2")], provides=["ab2"])(lambda a, b=1: a-b), ... operation(name="abb", needs=["ab1", "ab2"], provides=["asked"])(add), ... ) >>> inputs = {'a': 1, 'b1': 2} - >>> solution=netop(inputs) - >>> netop.plot('plot.svg', inputs=inputs, solution=solution, outputs=['asked', 'b1']); + >>> solution=pipeline(inputs) + >>> pipeline.plot('plot.svg', inputs=inputs, solution=solution, outputs=['asked', 'b1']); """ import pydot @@ -596,7 +606,8 @@ def get_node_name(a): penwidth=3, arrowhead="vee") g.add_edge(edge) - # save plot + # Save plot + # if filename: formats = supported_plot_formats() _basename, ext = os.path.splitext(filename) @@ -608,7 +619,14 @@ def get_node_name(a): g.write(filename, format=ext.lower()[1:]) - # display graph via matplotlib + ## Return an SVG renderable in jupyter. + # + if jupyter: + from IPython.display import SVG + g = SVG(data=g.create_svg()) + + ## Display graph via matplotlib + # if show: import matplotlib.pyplot as plt import matplotlib.image as mpimg diff --git a/setup.py b/setup.py index 4ed30ff8..9b3c98f8 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,7 @@ }, tests_require=[ "numpy", + "ipython", # to test jupyter plot. "pydot", # to test plot "matplotlib" # to test plot ], diff --git a/test/test_graphkit.py b/test/test_graphkit.py index c4d8a20f..86bbe42e 100644 --- a/test/test_graphkit.py +++ b/test/test_graphkit.py @@ -139,7 +139,8 @@ def test_plotting(): finally: shutil.rmtree(tdir, ignore_errors=True) - ## Don't open matplotlib window. + ## Try matplotlib Window, but + # without opening a Window. # if sys.version_info < (3, 5): # On PY< 3.5 it fails with: @@ -150,6 +151,9 @@ def test_plotting(): # do not open window in headless travis assert pipeline.plot(show=-1) + ## Try jupyter SVG. + assert "display.SVG" in str(type(pipeline.plot(jupyter=True))) + try: pipeline.plot('bad.format') assert False, "Should had failed writting arbitrary file format!"