-
Notifications
You must be signed in to change notification settings - Fork 4
/
inference.py
81 lines (69 loc) · 2.55 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from absl import app
from absl import flags
import os
from pathlib import Path
import logging
from tqdm import tqdm
import torch
from e3_layers.utils import build
from e3_layers import configs
from e3_layers.data import CondensedDataset, DataLoader, Batch
# Command-line interface: absl flag registrations for the inference script.
FLAGS = flags.FLAGS
# Name of a config factory in `e3_layers.configs`; required (see mark_flags_as_required).
flags.DEFINE_string("config", None, "The name of the config.")
flags.DEFINE_string("config_spec", 'eval', "Config specification, the argument of get_config().")
flags.DEFINE_string("output_path", "results.hdf5", "Path to the output file to create.")
flags.DEFINE_string("name", "default", "Name of the experiment.")
# Path to a saved state_dict; when unset, the freshly-built model weights are used.
flags.DEFINE_string("model_path", None, "The name of the model checkpoint.")
# Keys of the model output dict to collect and dump to HDF5 (comma-separated on the CLI).
flags.DEFINE_list("output_keys", [], "The output keys to save.")
flags.DEFINE_integer("seed", None, "The RNG seed.")
flags.DEFINE_integer(
    "dataloader_num_workers", 4, "Number of workers per training process."
)
# NOTE(review): this flag is registered but not read anywhere in this file — confirm
# whether the equivariance test was removed or lives elsewhere.
flags.DEFINE_boolean(
    "equivariance_test",
    False,
    "Whether to test the equivariance of the neural network.",
)
flags.DEFINE_string("verbose", "INFO", "Logging verbosity.")
flags.mark_flags_as_required(["config"])
def evaluate(argv):
    """Run inference over a dataset and dump the requested outputs to HDF5.

    Loads the config named by --config, builds the model, optionally restores
    a checkpoint (stripping DistributedDataParallel 'module.' prefixes),
    iterates the dataset in eval mode without autograd, and writes the
    collected --output_keys to --output_path.

    Args:
        argv: positional CLI arguments forwarded by app.run; unused.

    Raises:
        ValueError: if the named config does not exist.
        RuntimeError: if the dataloader yields no batches.
    """
    config_name = FLAGS.config
    config = getattr(configs, config_name, None)
    # Raise instead of assert: asserts are stripped under `python -O`.
    if config is None:
        raise ValueError(f"Config {config_name} not found.")
    config = config(FLAGS.config_spec)

    # Honor the --seed flag (previously declared but never consumed).
    if FLAGS.seed is not None:
        torch.manual_seed(FLAGS.seed)

    model = build(config.model_config)
    # Fall back to CPU when CUDA is unavailable instead of crashing at startup.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device=device)

    if FLAGS.model_path:
        state_dict = torch.load(FLAGS.model_path, map_location=device)
        # Strip the 'module.' prefix that DistributedDataParallel adds to keys.
        model_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith('module.'):
                key = key[len('module.'):]
            model_state_dict[key] = value
        model.load_state_dict(model_state_dict)

    data_config = config.data_config
    dataset = CondensedDataset(**data_config)
    dl_kwargs = dict(
        batch_size=config.batch_size,
        num_workers=FLAGS.dataloader_num_workers,
        pin_memory=(device != torch.device("cpu")),
        # avoid getting stuck
        timeout=(10 if FLAGS.dataloader_num_workers > 0 else 0),
    )
    dataloader = DataLoader(
        dataset=dataset,
        shuffle=False,
        **dl_kwargs,
    )

    model.eval()
    lst = []
    attrs = None  # attrs of the last batch, needed to assemble the result
    # Inference only: disable autograd to avoid building graphs (saves memory/time).
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = batch.to(device)
            out = model(batch.clone())
            dic = {key: out[key] for key in FLAGS.output_keys}
            lst.append(dic)
            attrs = batch.attrs
    # Previously an empty dataset raised a confusing NameError on `batch`.
    if attrs is None:
        raise RuntimeError("Dataloader produced no batches; nothing to save.")
    result = Batch.from_data_list(lst, attrs=attrs)
    result.dumpHDF5(FLAGS.output_path)
# Script entry point: absl parses the flags, then invokes evaluate(argv).
if __name__ == "__main__":
    app.run(evaluate)