-
Notifications
You must be signed in to change notification settings - Fork 5
/
expand_tfrecords_mpi.py
executable file
·51 lines (39 loc) · 1.5 KB
/
expand_tfrecords_mpi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#!/usr/bin/env python3
#
# Copyright (c) 2018 Dell Inc., or its subsidiaries. All Rights Reserved.
#
# Written by Claudio Fahey <[email protected]>
#
"""
This script will make multiple copies of the original TFRecord training files.
"""
import os
import argparse
from os.path import join, basename, splitext
from shutil import copyfile
def worker(rank, size, input_dir, output_dir, num_copies, num_files=1024):
    """Copy this rank's share of the TFRecord training files.

    Input files are expected to be named ``train-NNNNN-of-MMMMM`` inside
    ``input_dir``, where ``NNNNN`` is the zero-padded file index and
    ``MMMMM`` is ``num_files``. File indices are assigned round-robin:
    this rank handles ``rank``, ``rank + size``, ``rank + 2 * size``, ...
    up to (but excluding) ``num_files``. Each handled file is copied
    ``num_copies`` times into ``output_dir`` as
    ``train-NNNNN-of-MMMMM-copy-CCCCC``.

    Args:
        rank: Index of this worker; the first file index it handles.
        size: Total number of workers; the stride between file indices.
        input_dir: Directory containing the original files.
        output_dir: Directory that receives the copies.
        num_copies: Number of copies to make of each input file.
        num_files: Total number of input files in the data set.

    Raises:
        ValueError: If ``size`` is less than 1 (the stride would never
            advance past ``num_files``).
    """
    if size < 1:
        raise ValueError('size must be a positive integer, got %r' % (size,))

    # Round-robin scheduling: start at our rank and stride by the world size.
    for file_index in range(rank, num_files, size):
        in_file_name = join(
            input_dir, 'train-%05d-of-%05d' % (file_index, num_files))
        for copy_index in range(num_copies):
            out_file_name = join(
                output_dir,
                'train-%05d-of-%05d-copy-%05d' % (file_index, num_files, copy_index))
            print('%s => %s' % (in_file_name, out_file_name))
            copyfile(in_file_name, out_file_name)
def main():
    """Parse the command line and run this process's share of the copy job.

    The rank and world size are read from the ``OMPI_COMM_WORLD_RANK`` and
    ``OMPI_COMM_WORLD_SIZE`` environment variables, so one instance of this
    script is expected to be started per rank with those variables set.
    If either variable is missing, the script exits with a usage error
    instead of an unhandled ``KeyError``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('-i', '--input_dir', help='Input directory', required=True)
    parser.add_argument('-o', '--output_dir', help='Output directory', required=True)
    parser.add_argument('-n', '--num_copies', type=int, help='Number of copies', required=True)
    args = parser.parse_args()

    try:
        rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    except KeyError as err:
        # parser.error() prints the message with the usage line and exits.
        parser.error(
            'environment variable %s is not set; '
            'launch this script with Open MPI (mpirun)' % err.args[0])

    worker(rank, size, args.input_dir, args.output_dir, args.num_copies)
# Run the copy job only when executed as a script, not when imported.
if __name__ == "__main__":
    main()