Skip to content

Commit

Permalink
enh(plot.TC): expose supported writers and TC on them
Browse files Browse the repository at this point in the history
  • Loading branch information
ankostis committed Sep 29, 2019
1 parent 4307bb8 commit c5bcb3b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 15 deletions.
32 changes: 18 additions & 14 deletions graphkit/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,17 @@ def _compute_sequential_method(self, named_inputs, outputs):
return {k: cache[k] for k in iter(cache) if k in outputs}


@staticmethod
def supported_plot_writers():
return {
".png": lambda gplot: gplot.create_png(),
".dot": lambda gplot: gplot.to_string(),
".jpg": lambda gplot: gplot.create_jpeg(),
".jpeg": lambda gplot: gplot.create_jpeg(),
".pdf": lambda gplot: gplot.create_pdf(),
".svg": lambda gplot: gplot.create_svg(),
}

def plot(self, filename=None, show=False):
"""
Plot the graph.
Expand Down Expand Up @@ -422,23 +433,16 @@ def get_node_name(a):

# save plot
if filename:
supported_plot_formaters = {
".png": g.create_png,
".dot": g.to_string,
".jpg": g.create_jpeg,
".jpeg": g.create_jpeg,
".pdf": g.create_pdf,
".svg": g.create_svg,
}
_basename, ext = os.path.splitext(filename)
plot_formater = supported_plot_formaters.get(ext.lower())
if not plot_formater:
raise Exception(
writers = Network.supported_plot_writers()
plot_writer = Network.supported_plot_writers().get(ext.lower())
if not plot_writer:
raise ValueError(
"Unknown file format for saving graph: %s"
" File extensions must be one of: .png .dot .jpg .jpeg .pdf .svg"
% ext)
" File extensions must be one of: %s"
% (ext, ' '.join(writers)))
with open(filename, "wb") as fh:
fh.write(plot_formater())
fh.write(plot_writer(g))

# display graph via matplotlib
if show:
Expand Down
11 changes: 10 additions & 1 deletion test/test_graphkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,14 +327,23 @@ def test_plotting():
sum_op3 = operation(name='sum_op3', needs=['sum1', 'c'], provides='sum3')(add)
net1 = compose(name='my network 1')(sum_op1, sum_op2, sum_op3)

for ext in ".png .dot .jpg .jpeg .pdf .svg".split():
for ext in network.Network.supported_plot_writers():
tdir = tempfile.mkdtemp(suffix=ext)
png_file = osp.join(tdir, "workflow.png")
net1.net.plot(png_file)
try:
assert osp.exists(png_file)
finally:
shutil.rmtree(tdir, ignore_errors=True)
try:
net1.net.plot('bad.format')
assert False, "Should had failed writting arbitrary file format!"
except ValueError as ex:
assert "Unknown file format" in str(ex)

## Check help msg lists all siupported formats
for ext in network.Network.supported_plot_writers():
assert ext in str(ex)


####################################
Expand Down

0 comments on commit c5bcb3b

Please sign in to comment.