diff --git a/graphkit/network.py b/graphkit/network.py index 82b7d128..4dbfcd14 100644 --- a/graphkit/network.py +++ b/graphkit/network.py @@ -524,8 +524,6 @@ def plot_graph(graph, filename=None, show=False, steps=None, """ import pydot - import matplotlib.pyplot as plt - import matplotlib.image as mpimg assert graph is not None @@ -610,6 +608,9 @@ def get_node_name(a): # display graph via matplotlib if show: + import matplotlib.pyplot as plt + import matplotlib.image as mpimg + png = g.create_png() sio = io.BytesIO(png) img = mpimg.imread(sio)