-
Notifications
You must be signed in to change notification settings - Fork 3
/
trainer.py
77 lines (65 loc) · 2.65 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import json
import os
import shlex
import sys
# Container-side path at which the host qrels file is mounted (read-only).
QRELS_GUEST_PATH = '/output/qrels/qrels.qrel'
# Container-side directory at which the host model folder is mounted (read-write).
MODELS_GUEST_PATH = '/output'


class Trainer:
    """Run the ``/train`` entry point of a previously prepared image.

    ``config`` may be any object exposing the attributes read by
    :meth:`train`: ``tag``, ``load_from_snapshot``, ``repo``,
    ``model_folder``, ``topic``, ``topic_format``, ``qrels``,
    ``test_split``, ``validation_split``, ``collection``, ``opts`` and
    ``gpu``.
    """

    def __init__(self, trainer_config=None):
        """Store the (optional) configuration object."""
        self.config = trainer_config

    def set_config(self, trainer_config):
        """Replace the configuration used by subsequent :meth:`train` calls."""
        self.config = trainer_config

    @staticmethod
    def _parse_opts(opts):
        """Convert ``["key=value", ...]`` into a dict.

        Only the first ``=`` separates key from value, so values may
        themselves contain ``=``. An item without any ``=`` raises
        ``ValueError``. ``None`` is treated as "no options".
        """
        return dict(item.split("=", 1) for item in (opts or ()))

    def train(self, client, topic_path_guest, test_split_path_guest,
              validation_split_path_guest, generate_save_tag):
        """
        Performs training

        Starts a detached container from the image ``<repo>:<save_tag>``,
        runs ``/train --json <args>`` inside it and echoes the container's
        log stream to stdout until that stream ends.

        Args:
            client: object providing ``images.list(filters=...)`` and
                ``containers.run(...)``; presumably a Docker SDK client --
                only these two calls are used here.
            topic_path_guest: container directory under which the topic
                file is mounted (keeping its host basename).
            test_split_path_guest: container path at which
                ``config.test_split`` is mounted.
            validation_split_path_guest: container path at which
                ``config.validation_split`` is mounted.
            generate_save_tag: callable invoked as
                ``generate_save_tag(config.tag, config.load_from_snapshot)``
                whose result is used as the image tag.

        Raises:
            SystemExit: via ``sys.exit`` when no image matching
                ``<repo>:<save_tag>`` is listed by ``client``.
            ValueError: when an entry of ``config.opts`` contains no ``=``.
        """
        config = self.config
        save_tag = generate_save_tag(config.tag, config.load_from_snapshot)
        image = "{}:{}".format(config.repo, save_tag)

        if not client.images.list(filters={"reference": image}):
            sys.exit("Must prepare image first...")

        # The topic file keeps its host file name inside the guest directory.
        topic_guest = os.path.join(topic_path_guest, os.path.basename(config.topic))

        volumes = {
            os.path.abspath(config.model_folder): {
                "bind": MODELS_GUEST_PATH,
                "mode": "rw",
            },
            os.path.abspath(config.topic): {
                "bind": topic_guest,
                "mode": "ro",
            },
            os.path.abspath(config.qrels): {
                "bind": QRELS_GUEST_PATH,
                "mode": "ro",
            },
            os.path.abspath(config.test_split): {
                "bind": test_split_path_guest,
                "mode": "ro",
            },
            os.path.abspath(config.validation_split): {
                "bind": validation_split_path_guest,
                "mode": "ro",
            },
        }

        train_args = {
            "collection": {
                "name": config.collection,
            },
            "opts": self._parse_opts(config.opts),
            "topic": {
                "path": topic_guest,
                "format": config.topic_format,
            },
            "qrels": {
                "path": QRELS_GUEST_PATH,
            },
            "model_folder": {
                "path": MODELS_GUEST_PATH,
            },
        }

        runtime = "nvidia" if config.gpu else "runc"

        # Pass an argv list and shell-quote the JSON so that quotes, `$` or
        # backticks inside config values reach /train verbatim instead of
        # being interpreted (or breaking the quoting) in `sh -c`.
        command = ["sh", "-c", "/train --json {}".format(shlex.quote(json.dumps(train_args)))]

        print("Starting container from saved image...")
        container = client.containers.run(image, command=command, volumes=volumes,
                                          detach=True, runtime=runtime)

        print("Logs for training in container with ID {}...".format(container.id))
        for chunk in container.logs(stream=True):
            # Container output is arbitrary bytes; never let a bad byte
            # abort log streaming.
            print(chunk.decode("utf-8", errors="replace"), end="")