forked from niemasd/GEMF
-
Notifications
You must be signed in to change notification settings - Fork 0
/
GEMF_FAVITES.py
executable file
·484 lines (398 loc) · 19.6 KB
/
GEMF_FAVITES.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
#! /usr/bin/env python3
'''
User-friendly GEMF wrapper for use in FAVITES (or elsewhere).
Niema Moshiri 2022
'''
# imports
from datetime import datetime
from json import dump as jdump
from os import chdir, getcwd, makedirs
from os.path import abspath, expanduser, isdir, isfile
import argparse
import random
import subprocess
import sys
# useful variables
VERSION = '1.0.4'
C_UINT_MAX = 2**32 - 1 # == 4294967295, the largest 32-bit unsigned integer (default for --max_events)

# defaults: names of the files written inside the output directory
DEFAULT_FN_GEMF_PARA = 'para.txt'                             # GEMF parameter file
DEFAULT_FN_GEMF_NETWORK = 'network.txt'                       # GEMF network (edge list) file
DEFAULT_FN_GEMF_STATUS = 'status.txt'                         # GEMF initial status file (one state number per line)
DEFAULT_FN_GEMF_NODE2NUM = 'node2num.txt'                     # JSON: node label -> GEMF node number
DEFAULT_FN_GEMF_STATE2NUM = 'state2num.txt'                   # JSON: state label -> GEMF state number
DEFAULT_FN_GEMF_LOG = 'log.txt'                               # captured GEMF standard output
DEFAULT_FN_GEMF_OUT = 'output.txt'                            # GEMF event output file
DEFAULT_FN_TRANSITION = 'all_state_transitions.txt'           # optional: every simulated state transition
DEFAULT_FN_TRANSMISSIONS_FAVITES = 'transmission_network.txt' # FAVITES-format transmission network

# default GEMF executable: a bare name, which subprocess resolves via PATH
DEFAULT_GEMF_PATH = 'GEMF'
def get_time():
    '''
    Get the current local time as a string
    Returns:
        `str`: Current time formatted as `YYYY-MM-DD HH:MM:SS`
    '''
    now = datetime.now()
    return now.strftime("%Y-%m-%d %H:%M:%S")
def print_log(s='', end='\n'):
    '''
    Print a timestamped message to standard output and flush it immediately
    Args:
        `s` (`str`): Message to print (prefixed with `[YYYY-MM-DD HH:MM:SS] `)
        `end` (`str`): Line termination string
    '''
    stamped = "[%s] %s" % (get_time(), s)
    print(stamped, end=end, flush=True) # flush so progress shows up promptly even when stdout is piped
def parse_args():
    '''
    Parse the command-line arguments
    Returns:
        `argparse.Namespace`: Parsed user arguments; input files and the output directory are converted to absolute paths
    '''
    # NOTE: running with no arguments is reserved as a place-holder for a possible future GUI;
    # for now argparse simply reports the missing required arguments
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    add = parser.add_argument
    # required arguments
    add('-c', '--contact_network', required=True, type=str, help="Contact Network (TSV)")
    add('-s', '--initial_states', required=True, type=str, help="Initial States (TSV)")
    add('-i', '--infected_states', required=True, type=str, help="Infected States (one per line)")
    add('-r', '--rates', required=True, type=str, help="State Transition Rates (TSV)")
    add('-t', '--end_time', required=True, type=float, help="End Time")
    add('-o', '--output', required=True, type=str, help="Output Directory")
    # optional arguments
    add('--max_events', required=False, type=int, default=C_UINT_MAX, help="Max Number of Events")
    add('--output_all_transitions', action="store_true", help="Output All Transition Events (slower)")
    add('--quiet', action="store_true", help="Suppress log messages")
    add('--rng_seed', required=False, type=int, default=None, help="Random Number Generation Seed")
    add('--gemf_path', required=False, type=str, default=DEFAULT_GEMF_PATH, help="Path to GEMF Executable")
    parsed = parser.parse_args()
    # make local / home-relative paths absolute
    for attr in ('contact_network', 'initial_states', 'infected_states', 'rates', 'output'):
        setattr(parsed, attr, abspath(expanduser(getattr(parsed, attr))))
    return parsed
def check_args(args):
    '''
    Check user arguments for validity (and seed Python's RNG if a seed was given)
    Args:
        `args` (`argparse.Namespace`): Parsed user arguments
    Raises:
        `ValueError`: If an input file is missing, the end time is not positive, the output path
            already exists, the max number of events is negative, or the RNG seed is negative
    '''
    # check that input files exist
    for fn in [args.contact_network, args.initial_states, args.infected_states, args.rates]:
        if not isfile(fn):
            raise ValueError("File not found: %s" % fn)
    # check that end time is positive
    if args.end_time <= 0:
        raise ValueError("End time must be positive: %s" % args.end_time)
    # check that output directory doesn't already exist
    if isdir(args.output) or isfile(args.output):
        raise ValueError("Output directory exists: %s" % args.output)
    # check that max number of events is non-negative (it is written verbatim into GEMF's parameter file);
    # getattr keeps callers that build `args` without this attribute working
    max_events = getattr(args, 'max_events', None)
    if max_events is not None and max_events < 0:
        raise ValueError("Max number of events must be non-negative: %d" % max_events)
    # check that RNG seed is non-negative, then seed Python's RNG with it
    if args.rng_seed is not None:
        if args.rng_seed < 0:
            raise ValueError("RNG seed must be non-negative: %d" % args.rng_seed)
        random.seed(args.rng_seed)
def prepare_outdir(outdir, para_fn=DEFAULT_FN_GEMF_PARA, network_fn=DEFAULT_FN_GEMF_NETWORK, node2num_fn=DEFAULT_FN_GEMF_NODE2NUM, status_fn=DEFAULT_FN_GEMF_STATUS, state2num_fn=DEFAULT_FN_GEMF_STATE2NUM, transition_fn=DEFAULT_FN_TRANSITION, transmission_fn=DEFAULT_FN_TRANSMISSIONS_FAVITES, output_transitions=False):
    '''
    Prepare GEMF output directory: create it and open (for writing) every file written there
    Args:
        `outdir` (`str`): Path to output directory (created with `makedirs`, which fails if it already exists)
        `para_fn` (`str`): File name of GEMF parameter file
        `network_fn` (`str`): File name of GEMF network file
        `node2num_fn` (`str`): File name of "node label to GEMF number" mapping file
        `status_fn` (`str`): File name of GEMF status file
        `state2num_fn` (`str`): File name of "state label to GEMF number" mapping file
        `transition_fn` (`str`): File name of output "all simulation state transitions" file
        `transmission_fn` (`str`): File name of output FAVITES-format transmission network
        `output_transitions` (`bool`): Whether to create the "all simulation state transitions" file
    Returns:
        `file`: Write-mode file object to GEMF parameter file
        `file`: Write-mode file object to GEMF network file
        `file`: Write-mode file object to "node label to GEMF number" mapping file
        `file`: Write-mode file object to GEMF status file
        `file`: Write-mode file object to "state label to GEMF number" mapping file
        `file` or `None`: Write-mode file object to "all simulation state transitions" file (`None` unless `output_transitions`)
        `file`: Write-mode file object to output FAVITES-format transmission network
    Raises:
        `OSError`: If the directory cannot be created or a file cannot be opened; files already
            opened by this call are closed before the error propagates
    '''
    makedirs(outdir)
    opened = list() # every file opened so far, so they can be closed if a later open fails
    def _open_w(fn):
        # open `fn` inside `outdir` for writing and remember it for cleanup
        f = open('%s/%s' % (outdir, fn), 'w')
        opened.append(f)
        return f
    try:
        para_f = _open_w(para_fn)
        network_f = _open_w(network_fn)
        node2num_f = _open_w(node2num_fn)
        status_f = _open_w(status_fn)
        state2num_f = _open_w(state2num_fn)
        transition_f = _open_w(transition_fn) if output_transitions else None
        transmission_f = _open_w(transmission_fn)
    except BaseException:
        for f in opened:
            f.close()
        raise
    return para_f, network_f, node2num_f, status_f, state2num_f, transition_f, transmission_f
def create_gemf_network(contact_network_fn, network_f, node2num_f):
    '''
    Load contact network and convert to GEMF network format
    Nodes are numbered 1, 2, ... in the order their NODE rows appear. Each EDGE row is written to
    `network_f` as a `u_num<TAB>v_num` line; undirected (`u`) edges are also written in reverse.
    `network_f` and `node2num_f` are closed before returning (also when an error is raised).
    Args:
        `contact_network_fn` (`str`): Path to input contact network (FAVITES format)
        `network_f` (`file`): Write-mode file object to GEMF network file
        `node2num_f` (`file`): Write-mode file object to "node label to GEMF number" mapping file (written as JSON)
    Returns:
        `dict`: A mapping from node label to node number
        `list`: A mapping from node number to node label (index 0 is a `None` placeholder)
    Raises:
        `ValueError`: If a row is malformed, a node is duplicated, or an edge references an unknown node
    '''
    node2num = dict(); num2node = [None] # None is dummy (GEMF starts node numbers at 1)
    try:
        with open(contact_network_fn) as contact_f:
            for l in contact_f:
                # skip empty (or whitespace-only) and comment/header lines
                if len(l.strip()) == 0 or l[0] == '#':
                    continue
                parts = l.split('\t')
                # parse NODE line: NODE <label> <attributes>
                if l.startswith('NODE'):
                    if len(parts) != 3:
                        raise ValueError("Contact network NODE row must have exactly 3 tab-separated columns: %s" % l.rstrip('\n'))
                    u = parts[1].strip()
                    if u in node2num:
                        raise ValueError("Duplicate node encountered in contact network file: %s" % u)
                    node2num[u] = len(num2node); num2node.append(u)
                # parse EDGE line: EDGE <u> <v> <attributes> <d|u>
                elif l.startswith('EDGE'):
                    if len(parts) != 5:
                        raise ValueError("Contact network EDGE row must have exactly 5 tab-separated columns: %s" % l.rstrip('\n'))
                    u = parts[1].strip(); v = parts[2].strip(); d_or_u = parts[4].strip()
                    if d_or_u != 'd' and d_or_u != 'u':
                        raise ValueError("Last column of contact network EDGE row must be exactly d or u")
                    for w in (u, v):
                        if w not in node2num:
                            raise ValueError("Node found in EDGE section but not in NODE section: %s" % w)
                    u_num = node2num[u]; v_num = node2num[v]
                    network_f.write('%d\t%d\n' % (u_num, v_num))
                    if d_or_u == 'u': # undirected: GEMF needs both directions
                        network_f.write('%d\t%d\n' % (v_num, u_num))
                # non-comment and non-empty lines must start with NODE or EDGE
                else:
                    raise ValueError("Invalid contact network file: %s" % contact_network_fn)
        jdump(node2num, node2num_f)
    finally:
        node2num_f.close(); network_f.close()
    return node2num, num2node
def create_gemf_status(initial_states_fn, status_f, node2num):
    '''
    Load initial states and convert to GEMF status format
    The status file gets one line per node in node-number order (line i holds the state number of
    node i+1), regardless of the order of rows in the initial states file. State numbers are
    assigned in order of first appearance in the initial states file. `status_f` is closed before
    returning (also when an error is raised).
    Args:
        `initial_states_fn` (`str`): Path to initial states file (FAVITES format: `node<TAB>state` per line)
        `status_f` (`file`): Write-mode file object to GEMF status file
        `node2num` (`dict`): A mapping from node label to node number; assumed to use exactly the
            numbers 1..len(node2num), as produced by `create_gemf_network`
    Returns:
        `dict`: A mapping from state label to state number
        `list`: A mapping from state number to state label
    Raises:
        `ValueError`: If a row is malformed, a node is unknown or listed twice, or a contact network
            node has no initial state
    '''
    state2num = dict(); num2state = list()
    node_states = [None] * len(node2num) # node_states[u_num-1] = state number of node u_num
    try:
        with open(initial_states_fn) as init_f:
            for l in init_f:
                # skip empty (or whitespace-only) and comment/header lines
                if len(l.strip()) == 0 or l[0] == '#':
                    continue
                parts = l.split('\t')
                if len(parts) != 2:
                    raise ValueError("Initial states row must have exactly 2 tab-separated columns: %s" % l.rstrip('\n'))
                u = parts[0].strip(); s = parts[1].strip()
                try:
                    u_num = node2num[u]
                except KeyError:
                    raise ValueError("Encountered node in inital states file that is not in contact network file: %s" % u)
                if node_states[u_num - 1] is not None:
                    raise ValueError("Duplicate node encountered in initial states file: %s" % u)
                if s not in state2num:
                    state2num[s] = len(num2state); num2state.append(s)
                node_states[u_num - 1] = state2num[s]
        # every node needs a status line, otherwise line numbers would no longer match node numbers
        missing = [u for u, u_num in node2num.items() if node_states[u_num - 1] is None]
        if len(missing) != 0:
            raise ValueError("%d node(s) in contact network file have no initial state, e.g.: %s" % (len(missing), ', '.join(missing[:5])))
        status_f.write(''.join('%d\n' % s_num for s_num in node_states))
    finally:
        status_f.close()
    return state2num, num2state
def create_gemf_para(rates_fn, end_time, max_events, network_fn, status_fn, out_fn, para_f, state2num_f, state2num, num2state, rng_seed=None):
    '''
    Load transition rates and convert to GEMF para format
    States first seen in the rates file are appended to `state2num` / `num2state` in place.
    `para_f` and `state2num_f` are closed before returning (also when an error is raised).
    Args:
        `rates_fn` (`str`): Path to transition rates file (`from<TAB>to<TAB>by<TAB>rate` per line; `by` is `None` for nodal transitions)
        `end_time` (`float`): Simulation end time
        `max_events` (`int`): Max number of transition events
        `network_fn` (`str`): File name of GEMF network file (only the part after the last `/` is used)
        `status_fn` (`str`): File name of GEMF status file (only the part after the last `/` is used)
        `out_fn` (`str`): File name of GEMF output file (only the part after the last `/` is used)
        `para_f` (`file`): Write-mode file object to GEMF parameter file
        `state2num_f` (`file`): Write-mode file object to "state label to GEMF number" mapping file (written as JSON)
        `state2num` (`dict`): A mapping from state label to state number
        `num2state` (`list`): A mapping from state number to state label
        `rng_seed` (`int`): Seed for random number generation
    Returns:
        `dict`: Transition rates, where `RATE[x][y][z]` denotes the rate of the transition from `y` to `z` caused by `x` (state numbers, not labels)
        `list`: Sorted list of inducer state numbers
    Raises:
        `ValueError`: If a rates row is malformed or a transition is listed twice
    '''
    def _state_num(label):
        # number of state `label`, registering it in state2num/num2state if it has not been seen yet
        if label not in state2num:
            state2num[label] = len(num2state); num2state.append(label)
        return state2num[label]
    def _matrix_lines(trans, num_states):
        # one tab-separated row per from-state of trans[from_state][to_state]; missing transitions are '0'
        for s_from in range(num_states):
            row = trans.get(s_from, dict())
            yield '\t'.join(str(row[s_to]) if s_to in row else '0' for s_to in range(num_states))
    try:
        # load transition rates
        RATE = dict() # RATE[by_state][from_state][to_state] = transition rate (by_state == None means nodal transition)
        INDUCERS = set() # state numbers of inducer states
        with open(rates_fn) as rates_f:
            for l in rates_f:
                # skip empty (or whitespace-only) and comment/header lines
                if len(l.strip()) == 0 or l[0] == '#':
                    continue
                parts = l.split('\t')
                if len(parts) != 4:
                    raise ValueError("Transition rates row must have exactly 4 tab-separated columns: %s" % l.rstrip('\n'))
                from_s = parts[0].strip(); to_s = parts[1].strip(); by_s = parts[2].strip(); r = float(parts[3])
                from_s_num = _state_num(from_s); to_s_num = _state_num(to_s)
                if by_s.lower() == 'none':
                    by_s = None; by_s_num = None
                else:
                    by_s_num = _state_num(by_s); INDUCERS.add(by_s_num)
                to_rates = RATE.setdefault(by_s_num, dict()).setdefault(from_s_num, dict())
                if to_s_num in to_rates:
                    raise ValueError("Duplicate transition encountered: from '%s' to '%s' by '%s'" % (from_s, to_s, by_s))
                to_rates[to_s_num] = r
        jdump(state2num, state2num_f); NUM_STATES = len(state2num)
        # write nodal transition matrix (by_state == None); all zeros if the rates file has no nodal transitions
        para_f.write("[NODAL_TRAN_MATRIX]\n")
        for line in _matrix_lines(RATE.get(None, dict()), NUM_STATES):
            para_f.write("%s\n" % line)
        para_f.write('\n')
        # write edged transition matrix (by_state != None): one NUM_STATES x NUM_STATES block per inducer, in sorted order
        INDUCERS = sorted(INDUCERS)
        para_f.write("[EDGED_TRAN_MATRIX]\n")
        for s_by in INDUCERS:
            for line in _matrix_lines(RATE[s_by], NUM_STATES):
                para_f.write("%s\n" % line)
            para_f.write('\n')
        # write remaining sections of parameter file
        para_f.write("[STATUS_BEGIN]\n0\n\n")
        para_f.write("[INDUCER_LIST]\n%s\n\n" % ' '.join(str(s) for s in INDUCERS))
        para_f.write("[SIM_ROUNDS]\n1\n\n")
        para_f.write("[INTERVAL_NUM]\n1\n\n")
        para_f.write("[MAX_TIME]\n%s\n\n" % end_time)
        para_f.write("[MAX_EVENTS]\n%d\n\n" % max_events)
        para_f.write("[DIRECTED]\n1\n\n")
        para_f.write("[SHOW_INDUCER]\n1\n\n")
        # one DATA_FILE line per inducer state; all inducers share the same network file
        para_f.write("[DATA_FILE]\n%s\n\n" % '\n'.join([network_fn.split('/')[-1]]*len(INDUCERS)))
        para_f.write("[STATUS_FILE]\n%s\n\n" % status_fn.split('/')[-1])
        if rng_seed is not None:
            para_f.write("[RANDOM_SEED]\n%d\n\n" % rng_seed)
        para_f.write("[OUT_FILE]\n%s\n" % out_fn.split('/')[-1])
    finally:
        state2num_f.close(); para_f.close()
    return RATE, INDUCERS
def run_gemf(outdir, log_fn, gemf_path=DEFAULT_GEMF_PATH):
    '''
    Run GEMF inside the output directory, capturing its standard output in a log file
    GEMF is started with `outdir` as its working directory (the process-wide working directory is
    not changed). A `gemf_path` containing a directory component (e.g. `./GEMF`, `~/bin/GEMF`) is
    resolved against the caller's working directory first; a bare name (e.g. the default `GEMF`) is
    passed through unchanged so subprocess looks it up on PATH. GEMF's exit status is not checked.
    Args:
        `outdir` (`str`): Path to output directory (must contain GEMF's input files)
        `log_fn` (`str`): Log file name, relative to `outdir`
        `gemf_path` (`str`): Path to (or name of) the GEMF executable
    Returns:
        `file`: The (already closed) log file object
    Raises:
        `OSError`: If the GEMF executable cannot be started (e.g. `FileNotFoundError`)
    '''
    import os
    separators = [os.sep] + ([os.altsep] if os.altsep else [])
    if any(sep in gemf_path for sep in separators):
        # resolve now: once GEMF runs inside outdir, a relative path would point somewhere else
        gemf_path = os.path.abspath(os.path.expanduser(gemf_path))
    log_f = open(os.path.join(outdir, log_fn), 'w')
    try:
        subprocess.call([gemf_path], stdout=log_f, cwd=outdir)
    finally:
        log_f.close()
    return log_f
def convert_transmissions_to_favites(infected_states_fn, status_fn, out_fn, transition_f, transmission_f, num2node, node2num, num2state, state2num, RATE, INDUCERS):
    '''
    Convert GEMF transmission network to FAVITES format
    Nodes whose initial state is infected are written as seeds (`None<TAB>node<TAB>0`). For every
    GEMF event that moves a node from a non-infected to an infected state, the inducer state is
    picked with `roll_die`, weighting each inducer list by (rate of this transition caused by that
    state) x (number of nodes listed); a node from the chosen list is then picked uniformly at random
    as the source (a nodal transition is written with `None` as the source).
    `transition_f` and `transmission_f` are closed before returning (also when an error is raised).
    Args:
        `infected_states_fn` (`str`): Path to infected states file (one state label per line; empty and `#` lines are skipped)
        `status_fn` (`str`): Path to GEMF status file (line i holds the initial state number of node i+1)
        `out_fn` (`str`): Path to GEMF output file
        `transition_f` (`file`): Write-mode file object to "all simulation state transitions" file, or `None` to skip it
        `transmission_f` (`file`): Write-mode file object to output FAVITES-format transmission network
        `num2node` (`list`): A mapping from node number to node label
        `node2num` (`dict`): A mapping from node label to node number (unused; kept for interface compatibility)
        `num2state` (`list`): A mapping from state number to state label
        `state2num` (`dict`): A mapping from state label to state number
        `RATE` (`dict`): Transition rates, where `RATE[x][y][z]` denotes the rate of the transition from `y` to `z` caused by `x` (state numbers, not labels)
        `INDUCERS` (`list`): Sorted list of inducer state numbers
    Raises:
        `ValueError`: If the infected states file lists a state that is not in `state2num`
    '''
    try:
        # load and check infected states
        with open(infected_states_fn) as infected_in:
            infected_labels = {l.strip() for l in infected_in if len(l.strip()) != 0 and l[0] != '#'}
        for s in infected_labels:
            if s not in state2num:
                raise ValueError("Encountered state in infectious states file that didn't appear in rates or initial states files: %s" % s)
        infected_states = {state2num[s] for s in infected_labels}
        # write seeds to output FAVITES file (and every node's initial state to the transitions file)
        with open(status_fn) as status_in:
            for u_ind, s_num_s in enumerate(status_in):
                u = num2node[u_ind + 1]; s_num = int(s_num_s)
                if s_num in infected_states:
                    transmission_f.write("None\t%s\t0\n" % u)
                if transition_f is not None:
                    transition_f.write("%s\tNone\t%s\t0\n" % (u, num2state[s_num]))
        # convert GEMF output to FAVITES format
        INDUCER_STATES = [None] + list(INDUCERS) # inducer list i of a GEMF event belongs to INDUCER_STATES[i]
        with open(out_fn) as gemf_out:
            for l in gemf_out:
                if len(l.strip()) == 0:
                    continue # tolerate blank lines
                # parse easy components (parts[1], the total rate of ALL transitions in the network, is not needed)
                parts = l.split(' ')
                t = float(parts[0])        # time of current transition event
                v_num = int(parts[2])      # number of individual who transitioned
                v = num2node[v_num]        # name of individual who transitioned
                from_s_num = int(parts[3]) # number of individual's previous state
                to_s_num = int(parts[4])   # number of individual's current state
                if transition_f is not None:
                    transition_f.write('%s\t%s\t%s\t%s\n' % (v, num2state[from_s_num], num2state[to_s_num], t))
                if from_s_num in infected_states or to_s_num not in infected_states:
                    continue # only write inducer to transmission file if v went to infected state
                # parse inducer lists: inducers[0] = nodal transition, inducers[1] = first inducer state, inducers[2] = second inducer state, etc.
                inducers = [[int(tmp) for tmp in inds.split(',') if len(tmp) != 0] for inds in parts[-1].rstrip().lstrip('[').rstrip(']').split('],[')]
                # a state that cannot cause this particular transition gets weight 0 (instead of a KeyError)
                inducer_state_rates = [(RATE.get(INDUCER_STATES[i], dict()).get(from_s_num, dict()).get(to_s_num, 0) * len(u_nums), i) for i, u_nums in enumerate(inducers) if len(u_nums) != 0]
                by_s_inducer_ind = roll_die(inducer_state_rates)[1]
                if INDUCER_STATES[by_s_inducer_ind] is None:
                    transmission_f.write("None\t%s\t%s\n" % (v, t))
                else:
                    u = num2node[random.choice(inducers[by_s_inducer_ind])]
                    transmission_f.write("%s\t%s\t%s\n" % (u, v, t))
    finally:
        transmission_f.close()
        if transition_f is not None:
            transition_f.close()
def roll_die(faces):
    '''
    Roll a multi-faced die: each face is chosen with probability proportional to its weight
    Args:
        `faces` (`list`): Die faces as `(weight, label)` `tuple`s; weights must be non-negative and sum to a positive value
    Returns:
        `tuple`: The `(prob, label)` die face that succeeded, where `prob` is the face's normalized probability
    Raises:
        `ValueError`: If `faces` is empty or its weights do not sum to a positive value
    '''
    face_tot = sum(p for p, s in faces)
    if face_tot <= 0:
        raise ValueError("Die faces must have a positive total weight")
    faces = [(p/face_tot, s) for p, s in faces]
    x = random.random(); tot = 0. # x is uniform in [0, 1)
    for face in faces:
        tot += face[0] # cumulative probability: face i wins when x falls in [cum_{i-1}, cum_i)
        if x < tot:
            return face
    # floating-point rounding can leave the cumulative total just below 1: fall back to the last face that has nonzero probability
    return next(face for face in reversed(faces) if face[0] > 0)
def main():
    '''
    Main function: validate the user arguments, write the GEMF input files, run GEMF,
    and convert its output to FAVITES format
    '''
    # `version` / `--version` (any case) prints the version and exits before argparse runs
    if len(sys.argv) > 1 and sys.argv[1].lower().lstrip('-') == 'version':
        print("GEMF_FAVITES v%s" % VERSION)
        sys.exit() # the builtin exit() is added by the site module and is not guaranteed to exist
    args = parse_args(); check_args(args)
    def log(msg):
        # progress messages are suppressed by --quiet
        if not args.quiet:
            print_log(msg)
    log("Running GEMF_FAVITES v%s" % VERSION)
    log("Preparing output directory: %s" % args.output)
    para_f, network_f, node2num_f, status_f, state2num_f, transition_f, transmission_f = prepare_outdir(args.output, output_transitions=args.output_all_transitions)
    log("Creating GEMF network file...")
    node2num, num2node = create_gemf_network(args.contact_network, network_f, node2num_f) # closes network_f and node2num_f
    log("Creating GEMF status file...")
    state2num, num2state = create_gemf_status(args.initial_states, status_f, node2num) # closes status_f
    log("Creating GEMF parameter file...")
    RATE, INDUCERS = create_gemf_para(args.rates, args.end_time, args.max_events, network_f.name, status_f.name, DEFAULT_FN_GEMF_OUT, para_f, state2num_f, state2num, num2state, args.rng_seed) # closes para_f and state2num_f
    log("Running GEMF...")
    run_gemf(args.output, DEFAULT_FN_GEMF_LOG, args.gemf_path) # closes its own log file
    log("Converting GEMF output to FAVITES format...")
    convert_transmissions_to_favites(args.infected_states, status_f.name, '%s/%s' % (args.output, DEFAULT_FN_GEMF_OUT), transition_f, transmission_f, num2node, node2num, num2state, state2num, RATE, INDUCERS) # closes transition_f and transmission_f

# execute main function
if __name__ == "__main__":
    main()