-
Notifications
You must be signed in to change notification settings - Fork 1
/
plotNetwork.py
384 lines (300 loc) · 17.7 KB
/
plotNetwork.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
import sys
import networkx as nx
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import copy
from .utils import *
class plotNetwork:
    """Draw a static network plot of a NetworkX graph with matplotlib.

    The graph passed to the constructor is deep-copied and never modified;
    every call to ``build`` filters and draws a fresh working copy, so the
    parameters can be changed with ``set_params`` and the plot rebuilt.

    Invalid input is reported by printing an error message and calling
    ``sys.exit()`` (i.e. ``SystemExit`` is raised).

    Parameters
    ----------
    g : networkx.Graph
        The graph to plot. ``build`` reads a ``weight`` attribute from every
        edge and, when present, a ``Label`` attribute from every node.
    """

    usage = """Produces a static spring-embedded network from a NetworkX graph.

        Initial_Parameters
        ----------
        g : NetworkX graph.

        Methods
        -------
        set_params : Set parameters -
            imageFileName: The image file name to save to (default: 'networkPlot.jpg')
            edgeLabels: Setting to 'True' labels all edges with the similarity score (default: True)
            saveImage: Setting to 'True' will save the image to file (default: True)
            layout: Set the NetworkX layout type ('circular', 'kamada_kawai', 'random', 'spring', 'spectral') (default: 'spring')
            transparent: Setting to 'True' will make the background transparent (default: False)
            dpi: The number of Dots Per Inch (DPI) for the image (default: 200)
            figSize: The figure size as a tuple (width,height) (default: (30,20))
            node_cmap: The CMAP colour palette to use for nodes (default: 'brg')
            colorScale: The scale to use for colouring the nodes ("linear", "reverse_linear", "log", "reverse_log", "square", "reverse_square", "area", "reverse_area", "volume", "reverse_volume", "ordinal", "reverse_ordinal") (default: 'linear')
            node_color_column: The Peak Table column to use for node colours (default: None sets to black)
            sizeScale: The node size scale to apply ("linear", "reverse_linear", "log", "reverse_log", "square", "reverse_square", "area", "reverse_area", "volume", "reverse_volume", "ordinal", "reverse_ordinal") (default: 'reverse_linear')
            size_range: The node size scale range to apply. Tuple of length 2. Minimum size to maximum size (default: (150,2000))
            sizing_column: The node sizing column to use (default: sizes all nodes to 1)
            alpha: Node opacity value (default: 0.5)
            nodeLabels: Setting to 'True' will label the nodes (default: True)
            fontSize: The font size set for each node (default: 15)
            keepSingletons: Setting to 'True' will keep any single nodes not connected by edges in the NetworkX graph) (default: True)
            filter_column: Column from Peak Table to filter on (default: no filtering)
            threshold: Value to filter on (default: no filtering)
            operator: The comparison operator to use when filtering (default: '>')
            sign: The sign of the score to filter on ('pos', 'neg' or 'both') (default: 'pos')

        help : Print this help text

        build : Generates and displays the NetworkX graph.
    """

    # Scale names accepted for both colorScale and sizeScale.
    __SCALES = (
        "linear", "reverse_linear", "log", "reverse_log",
        "square", "reverse_square", "area", "reverse_area",
        "volume", "reverse_volume", "ordinal", "reverse_ordinal",
    )
    # Scales for which the column values do not have to be numeric.
    __ORDINAL_SCALES = ("ordinal", "reverse_ordinal")
    __LAYOUTS = ("circular", "kamada_kawai", "random", "spring", "spectral")
    __OPERATORS = ("<", ">", "<=", ">=")
    __SIGNS = ("pos", "neg", "both")

    def __init__(self, g):
        """Validate ``g`` and keep a private deep copy of it."""
        self.__g = copy.deepcopy(self.__checkData(g))
        self.set_params()

    def help(self):
        """Print the usage text."""
        print(plotNetwork.usage)

    def set_params(self, imageFileName='networkPlot.jpg', edgeLabels=True,
                   saveImage=True, layout='spring', transparent=False,
                   dpi=200, figSize=(30, 20), node_cmap='brg',
                   colorScale='linear', node_color_column='none',
                   sizeScale='reverse_linear', size_range=(150, 2000),
                   sizing_column='none', alpha=0.5, nodeLabels=True,
                   fontSize=15, keepSingletons=True, filter_column='none',
                   threshold=0.01, operator='>', sign='pos'):
        """Validate and store the plot parameters.

        Every parameter that is not passed is reset to its default. See
        ``plotNetwork.usage`` for the meaning of each parameter. Invalid
        values print an error message and raise ``SystemExit``.
        """
        (imageFileName, edgeLabels, saveImage, layout, transparent, dpi,
         figSize, node_cmap, colorScale, node_color_column, sizeScale,
         size_range, sizing_column, alpha, nodeLabels, fontSize,
         keepSingletons, filter_column, threshold, operator,
         sign) = self.__paramCheck(
            imageFileName, edgeLabels, saveImage, layout, transparent, dpi,
            figSize, node_cmap, colorScale, node_color_column, sizeScale,
            size_range, sizing_column, alpha, nodeLabels, fontSize,
            keepSingletons, filter_column, threshold, operator, sign)

        self.__imageFileName = imageFileName
        self.__edgeLabels = edgeLabels
        self.__saveImage = saveImage
        self.__layout = layout
        self.__transparent = transparent
        self.__dpi = dpi
        self.__figSize = figSize
        self.__node_cmap = node_cmap
        self.__colorScale = colorScale
        self.__node_color_column = node_color_column
        self.__sizeScale = sizeScale
        self.__size_range = size_range
        self.__sizing_column = sizing_column
        self.__alpha = alpha
        self.__nodeLabels = nodeLabels
        self.__fontSize = fontSize
        self.__keepSingletons = keepSingletons
        self.__filter_column = filter_column
        self.__filter_threshold = threshold
        self.__operator = operator
        self.__sign = sign

    def build(self):
        """Filter the graph, draw it, and optionally save the image.

        Works on a copy of the stored graph so that repeated calls (for
        example after changing the filter with ``set_params``) always start
        from the full, unfiltered graph.

        Prints an error and raises ``SystemExit`` when the filters remove
        every node.
        """
        g = self.__g.copy()

        self.__remove_edges_by_sign(g)
        self.__remove_filtered_nodes(g)
        if not self.__keepSingletons:
            self.__remove_singletons(g)

        if g.number_of_nodes() == 0:
            self.__fail("Error: All nodes have been removed. Please change the filter parameters.")

        node_size = self.__node_sizes(g)
        pos = self.__layout_positions(g)
        node_color = self.__node_colors(g)

        # Label each node with its 'Label' attribute, falling back to the
        # node identifier so labels can never be attached to the wrong node.
        labels = {node: data.get('Label', node) for node, data in g.nodes(data=True)}

        # The figure is created only once all validation has passed so that
        # a failed build does not leave an empty figure behind.
        plt.subplots(figsize=self.__figSize)
        nx.draw(g, pos=pos, labels=labels, node_size=node_size,
                font_size=self.__fontSize, node_color=node_color,
                alpha=self.__alpha, with_labels=self.__nodeLabels)

        if self.__edgeLabels:
            # Edge weights are shown rounded to two decimal places.
            edge_labels = {
                (source, target): float("{0:.2f}".format(g.edges[source, target]['weight']))
                for source, target in g.edges()
            }
            nx.draw_networkx_edge_labels(g, pos=pos, edge_labels=edge_labels)

        if self.__saveImage:
            plt.savefig(self.__imageFileName, dpi=self.__dpi, transparent=self.__transparent)

        plt.show()

    def __remove_edges_by_sign(self, g):
        """Drop edges whose weight sign does not match the ``sign`` setting."""
        if self.__sign == "pos":
            unwanted = [(s, t) for s, t in g.edges() if g.edges[s, t]['weight'] < 0]
        elif self.__sign == "neg":
            unwanted = [(s, t) for s, t in g.edges() if g.edges[s, t]['weight'] >= 0]
        else:
            # "both": keep every edge.
            unwanted = []
        g.remove_edges_from(unwanted)

    def __remove_filtered_nodes(self, g):
        """Remove nodes whose filter-column value satisfies the comparison.

        NaN values are treated as 0 before comparing with the threshold.
        """
        if self.__filter_column == 'none':
            return

        comparisons = {
            ">": lambda value, limit: value > limit,
            "<": lambda value, limit: value < limit,
            "<=": lambda value, limit: value <= limit,
            ">=": lambda value, limit: value >= limit,
        }
        matches = comparisons[self.__operator]
        limit = float(self.__filter_threshold)

        doomed = []
        for node in g.nodes():
            value = float(g.nodes[node][self.__filter_column])
            if np.isnan(value):
                value = 0
            if matches(value, limit):
                doomed.append(node)
        g.remove_nodes_from(doomed)

    @staticmethod
    def __remove_singletons(g):
        """Remove nodes that are not an endpoint of any edge."""
        # Plain Python sets keep the node objects intact; converting the ids
        # through a NumPy array would coerce mixed-type ids to strings.
        connected = {node for edge in g.edges() for node in edge}
        g.remove_nodes_from([node for node in g.nodes() if node not in connected])

    def __node_sizes(self, g):
        """Return the node sizes scaled into ``size_range``."""
        if self.__sizing_column == 'none':
            size_attr = np.ones(g.number_of_nodes())
        else:
            size_attr = np.array(
                list(nx.get_node_attributes(g, str(self.__sizing_column)).values()))
            if self.__sizeScale not in self.__ORDINAL_SCALES:
                # Numeric scales need floats; missing values are sized as 0.
                size_attr = np.array(pd.Series(size_attr, dtype=float).fillna(0).values)
        return transform(size_attr, self.__sizeScale,
                         self.__size_range[0], self.__size_range[1])

    def __layout_positions(self, g):
        """Compute node positions with the configured NetworkX layout."""
        layouts = {
            "circular": nx.circular_layout,
            "kamada_kawai": nx.kamada_kawai_layout,
            "random": nx.random_layout,
            "spring": nx.spring_layout,
            "spectral": nx.spectral_layout,
        }
        return layouts[self.__layout](g)

    def __node_colors(self, g):
        """Return the colour(s) to draw the nodes with.

        Numeric column values are mapped through the colour map; values
        that are already colours are used directly; anything else is only
        accepted for the ordinal colour scales.
        """
        if self.__node_color_column == 'none':
            return "#000000"

        # plt.get_cmap is used because matplotlib.cm.get_cmap was removed
        # in Matplotlib 3.9.
        cmap = plt.get_cmap(self.__node_cmap)
        values = np.array(
            list(nx.get_node_attributes(g, str(self.__node_color_column)).values()))

        try:
            numeric = np.array([float(value) for value in values])
        except (ValueError, TypeError):
            numeric = None

        if numeric is not None:
            return self.__to_hex(numeric, cmap)
        if matplotlib.colors.is_color_like(values[0]):
            return values
        if self.__colorScale not in self.__ORDINAL_SCALES:
            self.__fail("Error: Node colour column is not valid. While colorScale is not ordinal or reverse_ordinal, choose a column containing colour values, floats or integer values.")
        return self.__to_hex(values, cmap)

    def __to_hex(self, values, cmap):
        """Map ``values`` through ``cmap`` and return a list of hex colours."""
        return [matplotlib.colors.rgb2hex(rgb)
                for rgb in self.__get_colors(values, cmap)[:, :3]]

    @staticmethod
    def __fail(message):
        """Print ``message`` and terminate via ``sys.exit()``."""
        print(message)
        sys.exit()

    @staticmethod
    def __is_number(value):
        """Return True when ``value`` is an int or a float."""
        return isinstance(value, (int, float))

    @staticmethod
    def __is_unit_interval(value):
        """Return True when ``value`` is a number between 0 and 1 inclusive."""
        return isinstance(value, (int, float)) and 0 <= value <= 1

    def __checkData(self, g):
        """Return ``g`` unchanged, or exit if it is not a NetworkX graph."""
        if not isinstance(g, nx.classes.graph.Graph):
            self.__fail("Error: A NetworkX graph was not entered. Please check your data.")
        return g

    def __paramCheck(self, imageFileName, edgeLabels, saveImage, layout,
                     transparent, dpi, figSize, node_cmap, colorScale,
                     node_color_column, sizeScale, size_range, sizing_column,
                     alpha, nodeLabels, fontSize, keepSingletons,
                     filter_column, filter_threshold, operator, sign):
        """Validate every plot parameter and return them in the same order.

        ``colorScale``, ``sizeScale`` and ``sign`` are returned in lower
        case so that the rest of the class can compare them directly. Any
        invalid value prints an error message and raises ``SystemExit``.
        """
        g = self.__g
        fail = self.__fail
        is_number = self.__is_number

        def require_bool(value, label):
            if not isinstance(value, bool):
                fail("Error: {} is not valid. Choose either \"True\" or \"False\".".format(label))

        # Selectable columns are the attributes of the first node (all nodes
        # are presumed to share them); an empty graph offers only 'none'.
        first_node = next(iter(g.nodes()), None)
        node_columns = list(g.nodes[first_node].keys()) if g.number_of_nodes() else []
        col_list = node_columns + ['none']
        col_choices = ', '.join(str(col) for col in col_list)

        cmap_list = list(matplotlib.cm.cmaps_listed) + list(matplotlib.cm.datad)
        cmap_list = cmap_list + [cmap + '_r' for cmap in cmap_list]

        scale_choices = ', '.join('"{}"'.format(scale) for scale in self.__SCALES)

        if not isinstance(imageFileName, str):
            fail("Error: Image file name is not valid. Choose a string value.")

        require_bool(edgeLabels, "Edge labels")
        require_bool(saveImage, "Save image")

        if layout not in self.__LAYOUTS:
            fail("Error: Layout program not valid. Choose either \"circular\", \"kamada_kawai\", \"random\", \"spring\", \"spectral\".")

        require_bool(transparent, "The transparent value")

        if not is_number(dpi):
            fail("Error: Dpi is not valid. Choose a float or integer value.")

        if not isinstance(figSize, tuple) or len(figSize) != 2:
            fail("Error: Figure size is not valid. Choose a tuple of length 2.")
        if not all(is_number(length) for length in figSize):
            fail("Error: Figure size items not valid. Choose a float or integer value.")

        if not isinstance(node_cmap, str):
            fail("Error: Node CMAP is not valid. Choose a string value.")
        if node_cmap not in cmap_list:
            fail("Error: Node CMAP is not valid. Choose one of the following: {}.".format(', '.join(cmap_list)))

        if not isinstance(colorScale, str) or colorScale.lower() not in self.__SCALES:
            fail("Error: Color scale type not valid. Choose either {}.".format(scale_choices))
        colorScale = colorScale.lower()

        if node_color_column not in col_list:
            fail("Error: Node color column not valid. Choose one of {}.".format(col_choices))
        if node_color_column != 'none' and colorScale not in self.__ORDINAL_SCALES:
            node_color_values = np.array(
                list(nx.get_node_attributes(g, str(node_color_column)).values()))
            try:
                float(node_color_values[0])
            except (ValueError, TypeError):
                if not matplotlib.colors.is_color_like(node_color_values[0]):
                    fail("Error: Node colour column is not valid. While colorScale is not ordinal or reverse_ordinal, choose a column containing colour values, floats or integer values.")

        if not isinstance(sizeScale, str) or sizeScale.lower() not in self.__SCALES:
            fail("Error: Size scale type not valid. Choose either {}.".format(scale_choices))
        sizeScale = sizeScale.lower()

        if not isinstance(size_range, tuple) or len(size_range) != 2:
            fail("Error: Size range is not valid. Choose a tuple of length 2.")
        if not all(is_number(size) for size in size_range):
            fail("Error: Size values not valid. Choose a float or integer value.")

        if sizing_column not in col_list:
            fail("Error: Sizing column not valid. Choose one of {}.".format(col_choices))
        if sizing_column != 'none' and sizeScale not in self.__ORDINAL_SCALES:
            for node in g.nodes():
                try:
                    float(g.nodes[node][sizing_column])
                except (ValueError, TypeError):
                    fail("Error: Sizing column contains invalid values. While sizeScale is not ordinal or reverse_ordinal, choose a sizing column containing float or integer values.")

        # Alpha must be numeric AND within [0, 1]; checking the range only
        # for non-floats would let values such as 5.0 through.
        if not self.__is_unit_interval(alpha):
            fail("Error: Alpha value is not valid. Choose a float between 0 and 1.")

        require_bool(nodeLabels, "Add labels")

        if not is_number(fontSize):
            fail("Error: Font size is not valid. Choose a float or integer value.")

        require_bool(keepSingletons, "Keep singletons")

        if filter_column not in col_list:
            fail("Error: Filter column not valid. Choose one of {}.".format(col_choices))
        if filter_column != 'none':
            for node in g.nodes():
                try:
                    float(g.nodes[node][filter_column])
                except (ValueError, TypeError):
                    fail("Error: Filter column contains invalid values. Choose a filter column containing float or integer values.")

        if not is_number(filter_threshold):
            fail("Error: Filter threshold is not valid. Choose a float or integer value.")
        if filter_threshold == 0:
            fail("Error: Filter threshold should not be zero. Choose a value close to zero or above.")

        if operator not in self.__OPERATORS:
            fail("Error: Operator not valid. Choose either \"<\", \">\", \"<=\" or \">=\".")

        if not isinstance(sign, str) or sign.lower() not in self.__SIGNS:
            fail("Error: Sign not valid. Choose either \"pos\", \"neg\", or \"both\".")
        sign = sign.lower()

        return (imageFileName, edgeLabels, saveImage, layout, transparent,
                dpi, figSize, node_cmap, colorScale, node_color_column,
                sizeScale, size_range, sizing_column, alpha, nodeLabels,
                fontSize, keepSingletons, filter_column, filter_threshold,
                operator, sign)

    def __get_colors(self, x, cmap):
        """Scale ``x`` into [0, 1] with the colour scale and apply ``cmap``."""
        scaled_colors = transform(x, self.__colorScale, 0, 1)
        return cmap(scaled_colors)