-
Notifications
You must be signed in to change notification settings - Fork 0
/
hawkes_process.py
98 lines (77 loc) · 2.77 KB
/
hawkes_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from __future__ import print_function
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from pyhawkes.models import DiscreteTimeNetworkHawkesModelSpikeAndSlab
# Seed NumPy's global (legacy) random state at import time so that any
# subsequent np.random draws in this process are reproducible run to run.
# NOTE(review): this is an import-time side effect on global state.
np.random.seed(0)
def demo(X, N_iter=10, window=-1, dt_max=4, p=0.25):
    """Fit a discrete-time network Hawkes model to ``X`` with Gibbs sampling.

    A ``DiscreteTimeNetworkHawkesModelSpikeAndSlab`` is built with
    self-connections disallowed, ``X`` is attached as its data, and
    ``N_iter`` Gibbs sweeps are run.  After every sweep the model's
    ``W_effective`` matrix is accumulated, and the average over all sweeps
    is returned.

    Parameters
    ----------
    X : numpy.ndarray
        2-D array; ``X.shape[0]`` is taken as the number of time bins T and
        ``X.shape[1]`` as the number of processes K.  It is passed unchanged
        to ``add_data`` (presumably non-negative integer event counts per
        bin -- confirm against the pyhawkes documentation).
    N_iter : int
        Number of Gibbs sweeps; must be at least 1.
    window : int
        Unused; kept only so existing call sites keep working.
    dt_max : int
        Forwarded to the model constructor as ``dt_max``.
    p : float
        Forwarded as the ``"p"`` network hyperparameter (presumably the
        prior edge probability -- verify against pyhawkes).

    Returns
    -------
    numpy.ndarray
        K x K mean of ``W_effective`` over all ``N_iter`` sweeps.  No
        burn-in sweeps are discarded.

    Raises
    ------
    ValueError
        If ``N_iter`` is less than 1 (the average would be 0/0 = NaN).
    """
    if N_iter < 1:
        raise ValueError("N_iter must be at least 1, got %r" % (N_iter,))

    T, K = X.shape[0], X.shape[1]
    print('=====%s===' % str(K))

    network_hypers = {"p": p, "allow_self_connections": False}
    weight_hypers = {"parallel_resampling": False}
    test_model = DiscreteTimeNetworkHawkesModelSpikeAndSlab(
        K=K, dt_max=dt_max,
        network_hypers=network_hypers, weight_hypers=weight_hypers)
    test_model.add_data(X)

    # Plotting is best-effort. If the initial plot fails there are no
    # handles to update, so later per-sweep updates are skipped. Before,
    # an unbound name was raised on every sweep and hidden by a bare except.
    test_handles = None
    try:
        _, test_handles = test_model.plot(color="#e41a1c", T_slice=(0, T))
    except Exception:
        print('No causal graph constructed')

    w = np.zeros(shape=(K, K))
    for itr in range(N_iter):
        print("Gibbs iteration ", itr)
        test_model.resample_model()
        w = w + test_model.W_effective
        if test_handles is not None:
            try:
                test_model.plot(handles=test_handles)
            except Exception:
                print('No causal graph constructed')

    # Global NumPy print setting (suppress scientific notation); kept from
    # the original so printed results look the same for callers.
    np.set_printoptions(suppress=True)
    return w / N_iter
def analyze_samples(true_model, samples, lps):
    """Plot sampler diagnostics and link-prediction AUC against a known model.

    Two figures are shown, in order:

    1. the log-probability trace ``lps`` against iteration index;
    2. the ROC AUC of every sample's ``weight_model.A`` scored against
       ``true_model.weight_model.A``, plus two horizontal reference lines.
       The dashed red line is the AUC of the mean ``A`` and the dashed blue
       line is the AUC of the mean ``W``.  Both means are taken over the
       second half of ``samples`` only.

    A final ``plt.ioff()`` / ``plt.show()`` follows.

    Parameters
    ----------
    true_model
        Model whose ``weight_model.A`` supplies the ground-truth labels.
        Presumably a binary adjacency matrix; ``roc_auc_score`` requires both
        classes to be present.
    samples : sequence
        Sampled models, each exposing ``weight_model.A`` and
        ``weight_model.W``.
    lps : sequence of float
        Per-sample log probabilities, plotted against
        ``range(len(samples))``, so it must have the same length.
    """
    n = len(samples)
    half = n // 2

    adj_draws = np.array([draw.weight_model.A for draw in samples])
    weight_draws = np.array([draw.weight_model.W for draw in samples])
    lp_trace = np.array(lps)

    # Means use only the second half of the draws (the first half is
    # dropped, presumably as burn-in).
    adj_mean = adj_draws[half:, ...].mean(axis=0)
    weight_mean = weight_draws[half:, ...].mean(axis=0)

    plt.figure()
    plt.plot(np.arange(n), lp_trace, 'k')
    plt.xlabel("Iteration")
    plt.ylabel("Log probability")
    plt.show()

    # Ground-truth labels are the same for every AUC, so flatten them once.
    truth = true_model.weight_model.A.ravel()
    auc_adj_mean = roc_auc_score(truth, adj_mean.ravel())
    auc_weight_mean = roc_auc_score(truth, weight_mean.ravel())
    per_sample_auc = [roc_auc_score(truth, draw.ravel()) for draw in adj_draws]

    baseline = np.ones_like(per_sample_auc)
    plt.figure()
    plt.plot(per_sample_auc, '-r')
    plt.plot(auc_adj_mean * baseline, '--r')
    plt.plot(auc_weight_mean * baseline, '--b')
    plt.xlabel("Iteration")
    plt.ylabel("Link prediction AUC")
    plt.show()

    plt.ioff()
    plt.show()
if __name__ == "__main__":
    # demo() requires an observation matrix X; the previous bare `demo()`
    # call always failed with TypeError.  Run it on a small synthetic
    # T x K matrix of Poisson event counts instead (reproducible thanks to
    # the module-level np.random.seed call) and show the averaged W.
    T_demo, K_demo = 1000, 5
    X_demo = np.random.poisson(lam=0.1, size=(T_demo, K_demo))
    print(demo(X_demo))