-
Notifications
You must be signed in to change notification settings - Fork 3
/
dynamics.py
747 lines (653 loc) · 30.8 KB
/
dynamics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
import time
import bisect
import numpy as np
import pandas as pd
import networkx as nx
import scipy
import scipy.optimize
import scipy as sp
import random as rd
import heapq
import collections
import itertools
import os
import copy
from counterfactual_tpp import sample_counterfactual, combine
from sampling_utils import thinning_T
# from . import maxcut
from settings import DATA_DIR
def sample_seeds(graph, delta, method='data', n_seeds=None, max_date=None, verbose=True):
    """
    Extract seeds from the Ebola cases datasets, by choosing either:
        * the first `n_seeds`.
        * the first seed until the date `max_date`.
    For each seed, we then simulate its recovery time and attribute it to a random node in the
    corresponding district. We then start the epidemic at the time of infection of the last seed.
    Note that some seeds may have already recovered at this time. In this case, they are just
    ignored from the simulation altogether.

    Arguments:
    ---------
    graph : nx.Graph
        The graph of individuals in districts. Nodes must have the attribute `district`.
    delta : float
        Recovery rate of the epidemic process. Used to sample recovery times of seeds.
    method : str ('data' or 'random')
        Method to sample the seeds. Can be one of:
            - 'data': Use the seeds from the dataset and sample recovery time
            - 'random': Sample random seeds along with their recovery time
    n_seeds : int
        Number of seeds to sample.
    max_date : str
        Maximum date to sample seeds (max_date is included in sampling).
    verbose : bool
        Indicate whether or not to print seed generation process.

    Returns:
    -------
    list of initial events `[(node, event_type, infector), time]`.
    """
    assert (n_seeds is not None) or (max_date is not None), "Either `n_seeds` or `max_date` must be given"
    if method == 'data':
        # Load real data
        df = pd.read_csv(os.path.join(DATA_DIR, 'ebola', 'rstb20160308_si_001_cleaned.csv'))
        # Use explicit None checks so that e.g. `n_seeds=0` is not silently ignored
        if n_seeds is not None:
            df = df.sort_values('infection_timestamp').iloc[:n_seeds]
        elif max_date is not None:
            df = df[df.infection_date <= max_date].sort_values('infection_timestamp')
        # Extract district name for each node in the graph
        node_names = np.array([u for u, d in graph.nodes(data=True)])
        node_districts = np.array([d['district'] for u, d in graph.nodes(data=True)])
        # Get last infection time of seeds (this is time zero for the simulation)
        last_inf_time = df.infection_timestamp.max()
        # Init list of seed events
        init_event_list = list()
        for _, row in df.iterrows():
            inf_time = row['infection_timestamp']
            # Sample recovery time, shifted so that time zero is the last seed infection
            rec_time = inf_time + rd.expovariate(delta) - last_inf_time
            # Ignore seed if recovered before time zero
            if rec_time > 0:
                # Randomly sample one node for each seed in the corresponding district
                idx = np.random.choice(np.where(node_districts == row['district'])[0])
                node = node_names[idx]
                # Add infection event (a node "infecting itself" flags an initial seed)
                init_event_list.append([(node, 'inf', node), 0.0])  # Gets infection at the start
                # Add recovery event
                init_event_list.append([(node, 'rec', None), rec_time])
                if verbose:
                    print(f'Add seed {node} from district {row["district"]} - inf: {0.0}, rec: {rec_time} ')
        return init_event_list
    elif method == 'random':
        if n_seeds is None:
            raise ValueError("`n_seeds` must be provided for method `random`")
        init_event_list = list()
        for _ in range(n_seeds):
            # `np.random.choice` requires a 1-D sequence; a networkx NodeView
            # is not one, so materialize the node list first.
            node = np.random.choice(list(graph.nodes()))
            init_event_list.append([(node, 'inf', node), 0.0])
            rec_time = rd.expovariate(delta)
            init_event_list.append([(node, 'rec', None), rec_time])
        return init_event_list
    else:
        raise ValueError('Invalid method.')
class PriorityQueue(object):
    """
    Min-priority queue with O(log n) push and O(1) update/deletion of tasks.

    Follows the `heapq` documentation recipe: an entry-finder dict maps tasks
    to heap entries, and deleted/updated tasks are lazily marked as removed
    and skipped on pop.
    """

    def __init__(self, initial=None, priorities=None):
        """
        Arguments:
        ---------
        initial : sequence, optional
            Initial tasks to insert.
        priorities : sequence, optional
            Priorities for `initial`; must have the same length.
        """
        self.pq = []                      # heap of [priority, count, task] entries
        self.entry_finder = {}            # mapping of tasks to entries
        self.removed = '<removed-task>'   # placeholder for a removed task
        self.counter = itertools.count()  # unique sequence count (FIFO tie-breaker)
        # Use None sentinels instead of mutable default arguments, which would
        # be shared across every PriorityQueue instance.
        initial = initial if initial is not None else []
        priorities = priorities if priorities is not None else []
        assert len(initial) == len(priorities)
        for task, priority in zip(initial, priorities):
            self.push(task, priority=priority)

    def push(self, task, priority=0):
        """Add a new task or update the priority of an existing task"""
        if task in self.entry_finder:
            self.delete(task)
        count = next(self.counter)
        entry = [priority, count, task]
        self.entry_finder[task] = entry
        heapq.heappush(self.pq, entry)

    def delete(self, task):
        """Mark an existing task as REMOVED. Raise KeyError if not found."""
        entry = self.entry_finder.pop(task)
        entry[-1] = self.removed

    def remove_all_tasks_of_type(self, type):
        """Removes all existing tasks of a specific type (for SIRSimulation)"""
        # Tasks are assumed to be (node, event_type, infector) tuples.
        keys = list(self.entry_finder.keys())
        for event in keys:
            u, type_, v = event
            if type_ == type:
                self.delete(event)

    def pop_priority(self):
        """
        Remove and return the lowest priority task with its priority value.
        Raise KeyError if empty.
        """
        # Skip entries that were lazily marked as removed.
        while self.pq:
            priority, _, task = heapq.heappop(self.pq)
            if task is not self.removed:
                del self.entry_finder[task]
                return task, priority
        raise KeyError('pop from an empty priority queue')

    def pop(self):
        """
        Remove and return the lowest priority task. Raise KeyError if empty.
        """
        task, _ = self.pop_priority()
        return task

    def priority(self, task):
        """Return priority of task"""
        if task in self.entry_finder:
            return self.entry_finder[task][0]
        else:
            raise KeyError('task not in queue')

    def __len__(self):
        return len(self.entry_finder)

    def __str__(self):
        return str(self.pq)

    def __repr__(self):
        return repr(self.pq)

    def __setitem__(self, task, priority):
        self.push(task, priority=priority)
class ProgressPrinter(object):
    """
    Helper object to print relevant information throughout the epidemic.

    Progress lines are throttled to at most one every `PRINT_INTERVAL`
    seconds of wall-clock time (unless forced).
    """
    PRINT_INTERVAL = 0.1
    _PRINT_MSG = ('{t:.2f} days elapsed '
                  '| '
                  '{S:.0f} sus., '
                  '{I:.0f} inf., '
                  '{R:.0f} rec., '
                  '{Tt:.0f} tre ({TI:.2f}% of inf) | '
                  'max_u {max_u:.2e}'
                  )
    _PRINTLN_MSG = ('Epidemic stopped after {t:.2f} days '
                    '| '
                    '{S:.0f} sus., '
                    '{I:.0f} inf., '
                    '{R:.0f} rec., '
                    '{Tt:.0f} tre ({TI:.2f}% of inf) | '
                    'max_u {max_u:.2e}'
                    )

    def __init__(self, verbose=True):
        self.verbose = verbose
        self.last_print = time.time()

    @staticmethod
    def _counts(sir_obj):
        """Return (S, I, R, T) population counts from the simulation state."""
        S = np.sum(sir_obj.is_sus)
        I = np.sum(sir_obj.is_inf * (1 - sir_obj.is_rec))
        R = np.sum(sir_obj.is_rec)
        T = np.sum(sir_obj.is_tre)
        return S, I, R, T

    def print(self, sir_obj, epitime, end='', force=False):
        """Throttled single-line progress report (overwrites current line)."""
        if not self.verbose:
            return
        if (time.time() - self.last_print > self.PRINT_INTERVAL) or force:
            S, I, R, T = self._counts(sir_obj)
            # Percentage of currently infected nodes that are treated
            TI = 100. * T / I if I > 0 else np.nan
            print('\r', self._PRINT_MSG.format(t=epitime, S=S, I=I, R=R, Tt=T, TI=TI,
                                               max_u=sir_obj.max_total_control_intensity),
                  sep='', end='', flush=True)
            self.last_print = time.time()

    def println(self, sir_obj, epitime):
        """Final summary line, printed with a trailing newline."""
        if not self.verbose:
            return
        S, I, R, T = self._counts(sir_obj)
        TI = 100. * T / I if I > 0 else np.nan
        print('\r', self._PRINTLN_MSG.format(
                t=epitime, S=S, I=I, R=R, Tt=T, TI=TI,
                max_u=sir_obj.max_total_control_intensity),
              sep='', end='\n', flush=True)
        self.last_print = time.time()
class SimulationSIR(object):
    """
    Simulate continuous-time SIR epidemics with treatement, with exponentially distributed
    inter-event times.

    The simulation algorithm works by leveraging the Markov property of the model and rejection
    sampling. Events are treated in order in a priority queue. An event in the queue is a tuple
    of the form
        `(node, event_type, infector_node)`
    where elements are as follows:
        `node` : is the node where the event occurs,
        `event_type` : is the type of event (i.e. infected 'inf', recovery 'rec', or treatement 'tre')
        `infector_node` : for infections only, the node that caused the infection.
    """
    AVAILABLE_LPSOLVERS = ['scipy', 'cvxopt']

    def __init__(self, G, beta, gamma, delta, rho, verbose=True):
        """
        Init an SIR cascade over a graph

        Arguments:
        ---------
        G : networkx.Graph()
            Graph over which the epidemic propagates
        beta : float
            Exponential infection rate (positive)
        gamma : float
            Reduction in infection rate by treatment
        delta : float
            Exponential recovery rate (non-negative)
        rho : float
            Increase in recovery rate by treatment
        verbose : bool (default: True)
            Indicate the print behavior, if set to False, nothing will be printed
        """
        if not isinstance(G, nx.Graph):
            raise ValueError('Invalid graph type, must be networkx.Graph')
        self.G = G
        # `nx.adjacency_matrix` already returns a SciPy sparse matrix, so build
        # the CSR matrix from it directly instead of densifying with
        # `.toarray()` first (which needs O(n^2) memory on large graphs).
        self.A = sp.sparse.csr_matrix(nx.adjacency_matrix(self.G))
        # Cache the number of nodes and degree extrema
        self.n_nodes = len(G.nodes())
        self.max_deg = np.max([d for n, d in self.G.degree()])
        self.min_deg = np.min([d for n, d in self.G.degree()])
        # Mappings between node labels and contiguous array indices
        self.idx_to_node = dict(zip(range(self.n_nodes), self.G.nodes()))
        self.node_to_idx = dict(zip(self.G.nodes(), range(self.n_nodes)))
        # Check parameters
        if isinstance(beta, (float, int)) and (beta > 0):
            self.beta = beta
        else:
            raise ValueError("`beta` must be a positive float")
        if isinstance(gamma, (float, int)) and (gamma >= 0) and (gamma <= beta):
            self.gamma = gamma
        else:
            # Message matches the actual check (gamma == 0 and gamma == beta allowed)
            raise ValueError("`gamma` must be a non-negative float not larger than `beta`")
        if isinstance(delta, (float, int)) and (delta >= 0):
            self.delta = delta
        else:
            raise ValueError("`delta` must be a non-negative float")
        if isinstance(rho, (float, int)) and (rho >= 0):
            self.rho = rho
        else:
            raise ValueError("`rho` must be a non-negative float")
        # Control pre-computations
        self.lrsr_initiated = False  # flag for initial LRSR computation
        self.mcm_initiated = False   # flag for initial MCM computation
        # Control statistics
        self.max_total_control_intensity = 0.0
        # Printer for logging
        self._printer = ProgressPrinter(verbose=verbose)
def expo(self, rate):
"""Samples a single exponential random variable."""
return rd.expovariate(rate)
def nodes_at_time(self, status, time):
"""
Get the status of all nodes at a given time
"""
if status == 'S':
return self.inf_occured_at > time
elif status == 'I':
return (self.rec_occured_at > time) * (self.inf_occured_at < time)
elif status == 'T':
return (self.tre_occured_at < time) * (self.rec_occured_at > time)
elif status == 'R':
return self.rec_occured_at < time
else:
raise ValueError('Invalid status.')
def _init_run(self, init_event_list, max_time):
"""
Initialize the run of the epidemic
"""
# Max time of the run
self.max_time = max_time
# Priority queue of events by time
# event invariant is ('node', event, 'node') where the second node is the infector if applicable
self.queue = PriorityQueue()
# Cache the number of ins, recs, tres in the queue
self.infs_in_queue = 0
self.recs_in_queue = 0
self.tres_in_queue = 0
# Susceptible nodes tracking: is_sus[node]=1 if node is currently susceptible)
self.initial_seed = np.zeros(self.n_nodes, dtype='bool')
self.is_sus = np.ones(self.n_nodes, dtype='bool') # True if u susceptible
# Infection tracking: is_inf[node]=1 if node has been infected
# (note that the node can be already recovered)
self.inf_occured_at = np.inf * np.ones(self.n_nodes, dtype='float') # time infection of u_idx occurred
self.is_inf = np.zeros(self.n_nodes, dtype='bool') # True if u_idx infected
self.infector = -1 * np.ones(self.n_nodes, dtype='int') # index of node that infected u_idx (if -1, then no infector)
self.num_child_inf = np.zeros(self.n_nodes, dtype='int') # number of neighbors u_idx infected
# Recovery tracking: is_rec[node]=1 if node is currently recovered
self.rec_occured_at = np.inf * np.ones(self.n_nodes, dtype='float') # time recovery of u_idx occured
self.is_rec = np.zeros(self.n_nodes, dtype='bool') # True if u_idx recovered
# Treatment tracking: is_tre[node]=1 if node is currently treated
self.tre_occured_at = np.inf * np.ones(self.n_nodes, dtype='float') # time treatment of u_idx occured
self.is_tre = np.zeros(self.n_nodes, dtype='bool') # True if u_idx treated
# Conrol tracking
self.old_lambdas = np.zeros(self.n_nodes, dtype='float') # control intensity of prev iter
self.max_interventions_reached = False
# Add the initial events to priority queue
for event, time in init_event_list:
u, event_type, _ = event
u_idx = self.node_to_idx[u]
self.initial_seed[u_idx] = True
if event_type == 'inf':
# Initial infections have infections from NaN to u
self.queue.push(event, priority=time)
self.infs_in_queue += 1
elif event_type == 'rec':
self.queue.push(event, priority=time)
self.recs_in_queue += 1
else:
raise ValueError('Invalid Event Type for initial seeds.')
    def _process_infection_event(self, u, time, w):
        """
        Mark node `u` as infected at time `time`
        Sample its recovery time and its neighbors infection times and add to the queue

        Arguments:
        ---------
        u : node label of the newly infected node
        time : float, absolute event time
        w : node label of the infector (ignored for initial seeds)
        """
        # Get node index
        u_idx = self.node_to_idx[u]
        # Handle infection event
        self.is_inf[u_idx] = True
        self.is_sus[u_idx] = False
        self.inf_occured_at[u_idx] = time
        if self.initial_seed[u_idx]:
            # Handle initial seeds: no infector (-1 sentinel); their recovery
            # event was already scheduled from the initial event list.
            self.infector[u_idx] = -1
        else:
            w_idx = self.node_to_idx[w]
            self.infector[u_idx] = w_idx
            self.num_child_inf[w_idx] += 1
            # Sample recovery at rate delta (memoryless, Markov property)
            recovery_time_u = time + self.expo(self.delta)
            self.queue.push((u, 'rec', None), priority=recovery_time_u)
            self.recs_in_queue += 1
        # Set neighbors infection events
        for v in self.G.neighbors(u):
            v_idx = self.node_to_idx[v]
            if self.is_sus[v_idx]:
                # Tentative infection at rate beta; validity (infector not
                # yet recovered) is re-checked when the event is popped.
                infection_time_v = time + self.expo(self.beta)
                self.queue.push((v, 'inf', u), priority=infection_time_v)
                self.infs_in_queue += 1
def _process_recovery_event(self, u, time):
"""
Mark node `node` as recovered at time `time`
"""
# Get node index
u_idx = self.node_to_idx[u]
# Handle recovery event
self.rec_occured_at[u_idx] = time
self.is_rec[u_idx] = True
    def _process_treatment_event(self, u, time):
        """
        Mark node `u` as treated at time `time`
        Update its recovery time and its neighbors infection times and the queue
        """
        # Get node index
        u_idx = self.node_to_idx[u]
        # Handle treatement event
        self.tre_occured_at[u_idx] = time
        self.is_tre[u_idx] = True
        # Update own recovery event with rejection sampling
        # NOTE(review): this assert contradicts `__init__`, which enforces
        # rho >= 0 ("increase in recovery rate by treatment"). With rho == 0
        # the rejection test below never fires; with rho > 0 the assert fails
        # outright. The code looks written for a negative-rho (rate-reduction)
        # convention — confirm the intended sign of `rho`.
        assert(self.rho <= 0)
        if rd.random() < - self.rho / self.delta:
            # reject previous event
            self.queue.delete((u, 'rec', None))
            # re-sample at the treated recovery rate (delta + rho)
            new_recovery_time_u = time + self.expo(self.delta + self.rho)
            self.queue.push((u, 'rec', None), priority=new_recovery_time_u)
        # Update neighbors infection events triggered by u
        # (thin the infection intensity from beta down to beta - gamma)
        for v in self.G.neighbors(u):
            v_idx = self.node_to_idx[v]
            if self.is_sus[v_idx]:
                if rd.random() < self.gamma / self.beta:
                    # reject previous event
                    self.queue.delete((v, 'inf', u))
                    # re-sample
                    if self.beta - self.gamma > 0:
                        new_infection_time_v = time + self.expo(self.beta - self.gamma)
                    else:
                        # Avoid DivisionByZeroError if beta = gamma
                        # i.e., if no infectivity under treatement, then set infection time to inf
                        # We still set an event to make the algo easier and avoid bugs
                        new_infection_time_v = np.inf
                    self.queue.push((v, 'inf', u), priority=new_infection_time_v)
    def _control(self, u, time, policy='NO'):
        """
        Update the control (treatment) intensity of node `u` at time `time`
        under the given `policy`, and (re)schedule its treatment event.

        Uses rejection sampling when the intensity decreased, and the
        superposition principle (an extra event at the rate difference) when
        it increased.
        """
        # Get node index
        u_idx = self.node_to_idx[u]
        # Check if max interventions were reached (for FL)
        if '-FL' in policy:
            max_interventions = self.policy_dict['front-loading']['max_interventions']
            current_interventions = np.sum(self.is_tre)
            if current_interventions > max_interventions:
                # End interventions for this simulation
                self.max_interventions_reached = True
                self.queue.remove_all_tasks_of_type('tre')
                print('All treatments ended')
                return
        # Compute control intensity
        self.new_lambda = self._compute_lambda(u, time, policy=policy)
        # Sample treatment event
        # NOTE: `delta` here is the intensity change, unrelated to `self.delta`.
        delta = self.new_lambda - self.old_lambdas[u_idx]
        if delta < 0:
            # Update treatment event with rejection sampling as intensity was reduced
            if rd.random() < 1 - self.new_lambda / self.old_lambdas[u_idx]:
                # reject previous event
                self.queue.delete((u, 'tre', None))
                if self.new_lambda > 0:
                    # re-sample
                    new_treatment_time_u = time + self.expo(self.new_lambda)
                    self.queue.push((u, 'tre', None), priority=new_treatment_time_u)
        elif delta > 0:
            # Sample new/additional treatment event with the superposition principle
            new_treatment_time_u = time + self.expo(delta)
            self.queue.push((u, 'tre', None), priority=new_treatment_time_u)
            self.tres_in_queue += 1
        # store lambda
        self.old_lambdas[u_idx] = self.new_lambda
def _compute_lambda(self, u, time, policy='NO'):
"""Computes control intensity of the respective policy"""
if policy == 'NO':
return 0.0
elif policy == 'TR':
# lambda = const.
return self.policy_dict['TR']
elif policy == 'TR-FL':
return self._frontloadPolicy(
self.policy_dict['TR'],
self.policy_dict['TR'], time)
elif policy == 'MN':
# lambda ~ deg(u)
return self.G.degree(u) * self.policy_dict['MN']
elif policy == 'MN-FL':
return self._frontloadPolicy(
self.G.degree(u) * self.policy_dict['MN'],
self.max_deg * self.policy_dict['MN'], time)
elif policy == 'LN':
# lambda ~ (maxdeg - deg(u) + 1)
return (self.max_deg - self.G.degree(u) + 1) * self.policy_dict['LN']
elif policy == 'LN-FL':
return self._frontloadPolicy(
(self.max_deg - self.G.degree(u) + 1) * self.policy_dict['LN'],
(self.max_deg - self.min_deg + 1) * self.policy_dict['LN'], time)
elif policy == 'LRSR':
# lambda ~ 1/rank
# where rank is order of largest reduction in spectral radius of A
intensity, _ = self._compute_LRSR_lambda(u, time)
return intensity
elif policy == 'LRSR-FL':
intensity, max_intensity = self._compute_LRSR_lambda(u, time)
return self._frontloadPolicy(
intensity, max_intensity, time)
elif policy == 'MCM':
# lambda ~ 1/rank
# where rank is MCM heuristic ranking
intensity, _ = self._compute_MCM_lambda(u, time)
return intensity
elif policy == 'MCM-FL':
intensity, max_intensity = self._compute_MCM_lambda(u, time)
return self._frontloadPolicy(
intensity, max_intensity, time)
elif policy == 'SOC':
return self._compute_SOC_lambda(u, time)
else:
raise KeyError('Invalid policy code.')
def launch_epidemic(self, init_event_list, max_time=np.inf, policy='NO', policy_dict={}, stop_criteria=None):
"""
Run the epidemic, starting from initial event list, for at most `max_time` units of time
"""
self._init_run(init_event_list, max_time)
self.policy = policy
self.policy_dict = policy_dict
# Set SOC control parameters
# TODO: Handle policy parameters better
if policy == 'SOC':
self.eta = policy_dict['eta']
self.q_x = policy_dict['q_x']
self.q_lam = policy_dict['q_lam']
if policy_dict.get('lpsolver') in self.AVAILABLE_LPSOLVERS:
self.lpsolver = policy_dict['lpsolver']
else:
raise ValueError("Invalid `lpsolver`")
time = 0.0
while self.queue:
# Get the next event to process
(u, event_type, w), time = self.queue.pop_priority()
# Update queue cache
if event_type == 'inf':
self.infs_in_queue -= 1
elif event_type == 'rec':
self.recs_in_queue -= 1
elif event_type == 'tre':
self.tres_in_queue -= 1
# Get node index
u_idx = self.node_to_idx[u]
# Stop at the end of the observation window
if time > self.max_time:
time = self.max_time
break
# Process the event
# Check validity of infection event (node u is not infected yet)
if (event_type == 'inf') and (not self.is_inf[u_idx]):
assert self.is_sus[u_idx], f"Node `{u}` should be susceptible to be infected"
w_idx = self.node_to_idx[w]
if self.initial_seed[u_idx] or (not self.is_rec[w_idx]):
self._process_infection_event(u, time, w)
# Check validity of recovery event (node u is not recovered yet)
elif (event_type == 'rec') and (not self.is_rec[u_idx]):
assert self.is_inf[u_idx], f"Node `{u}` should be infected to be recovered"
self._process_recovery_event(u, time)
# Check validity of treatement event (node u is not treated yet, and not recovered)
elif (event_type == 'tre') and (not self.is_tre[u_idx]) and (not self.is_rec[u_idx]):
assert self.is_inf[u_idx], f"Node `{u}` should be infected to be treated"
self._process_treatment_event(u, time)
# If no-one is infected, the epidemic is finished. Stop the simulation.
if np.sum(self.is_inf * (1 - self.is_rec)) == 0:
break
if stop_criteria:
if stop_criteria(self):
break
# Update Control for infected nodes still untreated
if not self.max_interventions_reached:
controlled_nodes = np.where(self.is_inf * (1 - self.is_rec) * (1 - self.is_tre))[0]
if self.policy == 'SOC':
self._update_LP_sol()
for u_idx in controlled_nodes:
self._control(self.idx_to_node[u_idx], time, policy=self.policy)
self.max_total_control_intensity = max(
self.max_total_control_intensity, self.old_lambdas.sum())
self._printer.print(self, time)
self._printer.println(self, time)
# Free memory
del self.queue
def calculate_counterfactual(self, new_beta, G_prime, intervention_time = 0, vaccine = np.array([])):
"""
Calculate the counterfactual infection and recovery times for the given intervention.
input:
new_beta: the new beta value (for the vaccination policy)
G_prime: the graph for which the counterfactual infection and recovery times are calculated
(Can be the same as initial graph or different in graph isolation)
intervention_time: the time at which the intervention is assumed to have happened
vaccine: the vaccine policy
"""
# initializing the counterfactual queue with seed nodes
seeds = []
seed_priorities = []
processed = np.zeros(self.n_nodes, dtype = 'bool')
for node in self.G.nodes():
node_id = self.node_to_idx[node]
if self.inf_occured_at[node_id] <= intervention_time and self.rec_occured_at[node_id] > intervention_time:
seeds.append(node)
seed_priorities.append(intervention_time)
if self.rec_occured_at[node_id] <= intervention_time:
processed[node_id] = True
seeds = np.array(seeds)
seed_priorities = np.array(seed_priorities)
self.cf_queue = PriorityQueue(initial=seeds, priorities=seed_priorities)
general_processed = copy.deepcopy(processed)
beta_max = max(self.beta, new_beta)
# initializing the counterfactual infection and recovery times
self.cf_inf = np.zeros(self.n_nodes)
self.cf_rec = np.zeros(self.n_nodes)
self.cf_infector = np.full(self.n_nodes, -1)
for node in self.G.nodes():
node_id = self.node_to_idx[node]
if general_processed[node_id]:
self.cf_inf[node_id] = self.inf_occured_at[node_id]
self.cf_rec[node_id] = self.rec_occured_at[node_id]
else:
if node in seeds:
self.cf_inf[node_id] = intervention_time
self.cf_rec[node_id] = self.cf_inf[node_id] + self.expo(self.delta)
else:
self.cf_inf[node_id]= np.inf
# processing nodes in order and calculate their counterfctual inf and rec times
while self.cf_queue:
node, _ = self.cf_queue.pop_priority()
node_id = self.node_to_idx[node]
if not processed[node_id]:
processed[node_id] = True
for neighbor in G_prime.neighbors(node):
neighbor_id = self.node_to_idx[neighbor]
if general_processed[neighbor_id]:
continue
def gamma2(t):
if len(vaccine) == 0 or vaccine[neighbor_id] == True:
g_beta = new_beta
else:
g_beta = self.beta
if t == self.cf_inf[node_id] or t == self.cf_rec[node_id]: return 1 * g_beta
return g_beta * np.heaviside(t - self.cf_inf[node_id], 1) - g_beta * np.heaviside(t - self.cf_rec[node_id], 0)
if self.infector[neighbor_id] == node_id:
t_inf = self.inf_occured_at[neighbor_id]
def gamma1(t):
if t == self.cf_inf[node_id] or t == min(self.cf_rec[node_id], t_inf): return 1 * self.beta
return self.beta * np.heaviside(t - self.cf_inf[node_id], 1) - self.beta * np.heaviside(t - min(self.cf_rec[node_id], t_inf), 0)
############# CF
gamma_bar = lambda t: beta_max - gamma1(t)
t_rejected= thinning_T(start=self.cf_inf[node_id], intensity=gamma_bar, lambda_max=beta_max, T= min(self.cf_rec[node_id], t_inf))
sample, lambdas, indicators = combine(np.array([t_inf]), np.array([gamma1(t_inf)]), t_rejected, gamma1)
counterfactuals, _ = sample_counterfactual(sample, lambdas, beta_max, indicators, gamma2)
############# CF
if len(counterfactuals) != 0:
t_cf_inf = np.min(counterfactuals)
else:
t_cf_inf = np.inf
else:
H = thinning_T(start=0, intensity=gamma2, lambda_max=beta_max, T= self.max_time)
if len(H) != 0:
t_cf_inf = np.min(H)
else:
t_cf_inf = np.inf
if self.cf_inf[neighbor_id] > t_cf_inf:
self.cf_inf[neighbor_id] = t_cf_inf
self.cf_rec[neighbor_id] = t_cf_inf + self.expo(self.delta)
self.cf_infector[neighbor_id] = node_id
self.cf_queue.push(neighbor, priority=t_cf_inf)
for rec_time in self.cf_rec:
if rec_time > self.max_time:
rec_time = np.inf
# check this later.
for node in self.G.nodes():
if node in seeds:
node_id = self.node_to_idx[node]
self.cf_inf[node_id] = self.inf_occured_at[node_id]