-
Notifications
You must be signed in to change notification settings - Fork 0
/
tools.py
69 lines (42 loc) · 1.6 KB
/
tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
import networkx as nx
# For illustration purposes only (easy to follow; runs one shortest-path search per non-leaf node)
# -----------------------------
def pure_cascade_virality(G):
    '''Sum of the mean hop distances from each non-leaf node to the nodes it reaches.

    G is a directed graph (tree). This is the straightforward reference
    implementation: for every node with at least one out-edge, compute the
    shortest-path lengths to all nodes reachable from it, drop the
    zero-length entry (the node itself), average the remaining lengths, and
    add that average to a running total.

    Returns None if G is not weakly connected, and 0 if no node has an
    out-edge.
    '''
    if not nx.is_weakly_connected(G):
        return None

    def _mean_distance_from(src):
        # Hop counts from `src` to everything it reaches, excluding `src` itself.
        hops = nx.single_source_shortest_path_length(G, src).values()
        return np.mean([h for h in hops if h > 0])

    parents = [node for node, out_deg in G.out_degree() if out_deg > 0]
    # Start from 0 and add the per-node means in the same order as the nodes.
    return sum((_mean_distance_from(p) for p in parents), 0)
# Single-pass variant (more efficient): one depth-first traversal from the root collects the descendant depths of every node
# -----------------------------
def recursive_path_length(G, V, seed):
    '''Record, for ``seed`` and every node below it, the depths of all its descendants.

    G is a directed graph (tree); children are taken from ``G.successors``.
    After the call, ``V[node]`` holds one entry per descendant of ``node``:
    the number of hops from ``node`` down to that descendant (1 for a child,
    2 for a grandchild, ...). Entries follow depth-first order in
    ``G.successors`` order: each child's ``1`` is immediately followed by
    that child's own entries shifted by one. Keys are inserted into ``V`` in
    depth-first pre-order.

    Despite the name, the traversal uses an explicit stack instead of Python
    recursion, so deep cascades (long chains) do not hit the interpreter's
    recursion limit.

    Parameters
    ----------
    G : graph exposing ``successors(node)``
    V : dict
        Filled in place; existing entries for visited nodes are overwritten.
    seed : hashable node
        Node to start from.

    Returns
    -------
    list
        ``V[seed]``.

    Raises
    ------
    ValueError
        If a node is reachable from itself, i.e. G contains a cycle below
        ``seed``. Without this check the traversal would never terminate.
    '''
    V[seed] = []
    # Each frame is (node, iterator over the node's not-yet-visited children).
    stack = [(seed, iter(G.successors(seed)))]
    on_path = {seed}
    while stack:
        node, children = stack[-1]
        for child in children:
            if child in on_path:
                raise ValueError(
                    f"cycle detected at node {child!r}; G must be a tree")
            V[node].append(1)  # the child itself is one hop below `node`
            V[child] = []
            stack.append((child, iter(G.successors(child))))
            on_path.add(child)
            break  # descend into `child` before looking at its siblings
        else:
            # Every child of `node` is finished: fold its depths into the parent's.
            stack.pop()
            on_path.discard(node)
            if stack:
                parent = stack[-1][0]
                V[parent] += [d + 1 for d in V[node]]
    return V[seed]
def recursive_cascade_virality(G, source=None):
    '''Cascade virality computed from a single depth-first pass over the tree.

    G is a directed graph (tree) whose edges point from parent to child.
    Every node below ``source`` (inclusive) that has at least one descendant
    contributes the mean depth of its descendants. The result is the sum of
    these means. When ``source`` is the root of a tree, descendant depths
    equal shortest-path lengths, so this matches ``pure_cascade_virality``.

    Parameters
    ----------
    G : directed networkx graph
    source : node, optional
        Node to start from. If None (default), the first node with in-degree
        0 (the root) is used. Falsy node labels such as 0 are honored.

    Returns
    -------
    float or None
        The virality. It is 0 if ``source`` has no descendants. None if G is
        not weakly connected, or if no source was given and no node has
        in-degree 0 (so G contains a cycle and is not a tree).
    '''
    if not nx.is_weakly_connected(G):
        return None
    # `is None`, not truthiness: node 0 (networkx's default first label) is a
    # valid explicit source and must not be replaced by the detected root.
    if source is None:
        source = next((node for node, deg in G.in_degree() if deg == 0), None)
        if source is None:
            # Every node has a parent, so G has a cycle and no root exists.
            return None
    depths_by_node = {}
    recursive_path_length(G, depths_by_node, source)
    # Leaves have an empty depth list and contribute nothing.
    return sum(
        (np.mean(depths) for depths in depths_by_node.values() if depths), 0)