forked from PruneTruong/DenseMatching
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_training.py
74 lines (57 loc) · 2.57 KB
/
run_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import argparse
import importlib
import cv2 as cv
import torch
import torch.backends.cudnn
import random
import numpy as np
from shutil import copyfile
from datetime import date
import admin.settings as ws_settings
def run_training(train_module, train_name, seed, cudnn_benchmark=True):
    """Run a train script from the "train_settings/" folder.

    Builds a ``Settings`` object describing the run, copies the train settings
    file into the run's workspace directory, then imports
    ``train_settings.<train_module>.<train_name>`` and calls its ``run(settings)``.

    args:
        train_module: Name of module in the "train_settings/" folder. A '/' in the
            name is translated to a '.' (sub-package) for the import.
        train_name: Name of the train settings file, without the '.py' extension.
            A '/' in the name is translated to a '.' for the import.
        seed: Seed value recorded as ``settings.seed``. This function only stores
            it; it does not seed any RNG itself.
        cudnn_benchmark: Use cudnn benchmark or not (default is True).

    Note: 'train_settings/...' paths are relative to the current working
    directory, so this presumably has to be launched from the repository root.
    """
    # This is needed to avoid strange crashes related to opencv
    cv.setNumThreads(0)

    torch.backends.cudnn.benchmark = cudnn_benchmark

    # dd/mm/YYYY (four-digit year)
    d1 = date.today().strftime("%d/%m/%Y")
    print('Training: {} {}\nDate: {}'.format(train_module, train_name, d1))

    settings = ws_settings.Settings()
    settings.module_name = train_module
    settings.script_name = train_name
    settings.project_path = 'train_settings/{}/{}'.format(train_module, train_name)
    settings.seed = seed

    # will save the checkpoints there
    save_dir = os.path.join(settings.env.workspace_dir, settings.project_path)
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(save_dir, exist_ok=True)

    # Snapshot the train settings file next to the checkpoints. Only the file's
    # base name is used: train_name may contain '/', and joining it verbatim would
    # point into sub-directories of save_dir that were never created.
    copyfile(settings.project_path + '.py',
             os.path.join(save_dir, os.path.basename(settings.script_name) + '.py'))

    expr_module = importlib.import_module('train_settings.{}.{}'.format(
        train_module.replace('/', '.'), train_name.replace('/', '.')))
    expr_func = getattr(expr_module, 'run')
    expr_func(settings)
def _str2bool(value):
    """Parse a command-line boolean such as '1'/'0', 'true'/'false', 'yes'/'no', 'on'/'off'."""
    # argparse's type=bool would call bool('0') -> True, so any non-empty string
    # (including '0' and 'False') would turn the option on.
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean (1/0, true/false), got {!r}'.format(value))


def main():
    """Parse the command line, seed the Python/NumPy/torch RNGs and launch training."""
    parser = argparse.ArgumentParser(description='Run a train scripts in train_settings.')
    parser.add_argument('train_module', type=str, help='Name of module in the "train_settings/" folder.')
    parser.add_argument('train_name', type=str, help='Name of the train settings file.')
    parser.add_argument('--cudnn_benchmark', type=_str2bool, default=True,
                        help='Set cudnn benchmark on (1) or off (0) (default is on).')
    parser.add_argument('--seed', type=int, default=None,
                        help='Pseudo-RNG seed (default: derived from torch.initial_seed()).')
    args = parser.parse_args()

    if args.seed is None:
        # No explicit seed: keep the previous behavior of reusing torch's initial
        # seed, which is printed below so the run can be reproduced.
        args.seed = torch.initial_seed()
    # np.random.seed only accepts values in [0, 2**32 - 1].
    args.seed &= 2 ** 32 - 1
    print('Seed is {}'.format(args.seed))

    random.seed(int(args.seed))
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    run_training(args.train_module, args.train_name, cudnn_benchmark=args.cudnn_benchmark, seed=args.seed)
# Script entry point: run main() only when this file is executed directly.
if __name__ == '__main__':
    main()