-
Notifications
You must be signed in to change notification settings - Fork 6
/
main.py
94 lines (84 loc) · 2.48 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import argparse
from datetime import datetime
import gin
from loguru import logger
from torch.utils.data import DataLoader
from utils.common import set_random_seed
from dataset.ray_dataset import RayDataset, ray_collate
from neural_field.model import get_model
from trainer import Trainer
@gin.configurable()
def main(
    seed: int = 42,
    num_workers: int = 0,
    train_split: str = "train",
    stages: str = "train_eval",
    batch_size: int = 16,
    model_name: str = "Rip-NeRF",
):
    """Build the data loaders, model and trainer, then run the requested stages.

    Args:
        seed: Value forwarded to ``set_random_seed``.
        num_workers: Worker-process count for the training ``DataLoader``.
        train_split: Split name passed to ``RayDataset`` for the training set.
        stages: Stage selector, matched by substring: training runs when it
            contains ``"train"`` and evaluation runs when it contains
            ``"eval"``.
        batch_size: Batch size of the training ``DataLoader``.
        model_name: Name handed to ``get_model`` to look up the model factory.
    """
    set_random_seed(seed)

    logger.info("==> Init dataloader ...")
    train_dataset = RayDataset(split=train_split)
    train_loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        collate_fn=ray_collate,
        pin_memory=True,
        worker_init_fn=None,
        pin_memory_device='cuda',
    )
    # ``prefetch_factor`` is only meaningful with worker processes. Recent
    # PyTorch releases raise ValueError when it is set while num_workers == 0,
    # and older releases reject an explicit ``None`` in that case, so the key
    # is supplied only when workers are actually used.
    if num_workers > 0:
        train_loader_kwargs["prefetch_factor"] = 2
    train_loader = DataLoader(train_dataset, **train_loader_kwargs)

    test_dataset = RayDataset(split='test')
    # batch_size=None turns off automatic batching for the test loader.
    test_loader = DataLoader(
        test_dataset,
        batch_size=None,
        num_workers=1,
        shuffle=False,
        pin_memory=True,
        worker_init_fn=None,
        pin_memory_device='cuda',
    )

    logger.info("==> Init model ...")
    model = get_model(model_name=model_name)(aabb=train_dataset.aabb)
    logger.info(model)

    logger.info("==> Init trainer ...")
    trainer = Trainer(model, train_loader, eval_loader=test_loader)

    if "train" in stages:
        trainer.fit()
    if "eval" in stages:
        # An eval-only run has no freshly trained weights, so load a checkpoint.
        if "train" not in stages:
            trainer.load_ckpt()
        trainer.eval(save_results=True, rendering_channels=["rgb", "depth"])
if __name__ == "__main__":
    # Both flags use action="append", so each may be given several times.
    cli = argparse.ArgumentParser()
    cli.add_argument("--ginc", action="append", help="gin config file")
    cli.add_argument("--ginb", action="append", help="gin bindings")
    args = cli.parse_args()

    # argparse leaves an unused "append" option as None; normalise to a list.
    bindings = list(args.ginb) if args.ginb else []

    # Finalisation is deferred so Trainer.exp_name can still be rebound below.
    gin.parse_config_files_and_bindings(args.ginc, bindings, finalize_config=False)

    run_label = gin.query_parameter("Trainer.exp_name")
    scene_type = gin.query_parameter("RayDataset.scene_type")
    scene = gin.query_parameter("RayDataset.scene")
    model_label = gin.query_parameter("main.model_name")
    if run_label is None:
        # No explicit experiment name configured: fall back to a timestamp.
        run_label = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # Final experiment name: <scene_type>/<scene>/<model_name>/<run label>.
    exp_name = "%s/%s/%s/%s" % (scene_type, scene, model_label, run_label)
    gin.bind_parameter("Trainer.exp_name", exp_name)
    gin.finalize()

    main()