triplane_cls.py
import importlib
import time
from pathlib import Path

import hydra
import pytorch_lightning as pl
import wandb
from omegaconf import OmegaConf
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies.ddp import DDPStrategy

from datamodules.datamodule import load_datamodule

code_pos = Path(__file__)
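
# The Hydra config group used below lives at cfg/triplanes_classify/main.yaml. The sketch
# here lists the fields this script reads; only the key names are taken from the code, the
# concrete values and _target_ paths are illustrative assumptions, not repo defaults.
#
#   wandb:
#     entity: my-entity              # assumption
#     project: triplane-cls          # assumption
#     run_name: baseline             # assumption
#     dir: ./wandb
#   trainer:
#     module_name: classifier        # loaded as trainers/classifier.py (name is an assumption)
#   runtime:
#     gpus: 1
#     precision: 32
#     find_unused_parameters: true
#   train:
#     num_epochs: 100
#   val:
#     checkpoint_period: 1
#   train_transform:
#     random_crop: 0
#   network:
#     _target_: models.triplane_classifier.TriplaneClassifier   # assumption
#     embedding_dim: 128
#   loss:
#     _target_: torch.nn.CrossEntropyLoss
#   opt:
#     _target_: torch.optim.AdamW
#     lr: 1.0e-3
#   scheduler:
#     _target_: torch.optim.lr_scheduler.OneCycleLR   # assumption; must accept (optimizer, max_lr, total_steps=...)
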
@hydra.main(config_path="cfg/triplanes_classify", config_name="main", version_base="1.1")
def main(cfg) -> None:
print("trainers/" + cfg.trainer.module_name)
trainer_module = importlib.import_module("trainers." + cfg.trainer.module_name)
print(OmegaConf.to_yaml(cfg, resolve=True))
wandb.config = OmegaConf.to_container(
cfg, resolve=True, throw_on_missing=True
)
logger = WandbLogger(
entity=cfg.wandb.entity,
project=cfg.wandb.project,
name=cfg.wandb.run_name,
dir=cfg.wandb.dir,
save_dir=cfg.wandb.dir,
config=OmegaConf.to_container(
cfg, resolve=True, throw_on_missing=True
)
)
artifact = wandb.Artifact("Trainer", type="code")
artifact.add_file("trainers/" + cfg.trainer.module_name + ".py")
logger.experiment.log_artifact(artifact)
ckpt_callback = ModelCheckpoint(
dirpath= Path(cfg.wandb.dir) / cfg.wandb.run_name / "ckpts",
filename="{epoch}-{step}",
monitor="val/acc",
mode="max",
save_top_k=1,
save_last=False,
)
lr_monitor = LearningRateMonitor(logging_interval='step')
trainer = pl.Trainer(
strategy=DDPStrategy(
find_unused_parameters=cfg.runtime.get("find_unused_parameters", True)
)
if cfg.runtime.gpus > 1
else "auto",
precision=cfg.runtime.precision,
benchmark=True,
logger=logger,
callbacks=[ckpt_callback, lr_monitor],
max_epochs=cfg.train.num_epochs,
check_val_every_n_epoch=cfg.val.checkpoint_period,
log_every_n_steps=200,
)
dm = load_datamodule(
cfg
)
dl_train = dm.train_dataloader()
dl_val = dm.val_dataloader()
cfg.network.embedding_dim = cfg.train_transform.random_crop if cfg.train_transform.random_crop>0 else cfg.network.embedding_dim
network = hydra.utils.instantiate(cfg.network)
print("num_params:", sum(p.numel() for p in network.parameters()))
loss = hydra.utils.instantiate(cfg.loss)
optimizer = hydra.utils.instantiate(cfg.opt, network.parameters())
scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer, cfg.opt.lr, total_steps=len(dl_train) * cfg.train.num_epochs)
model = trainer_module.TrainModel(
network=network,
loss=loss,
optimizer=optimizer,
scheduler=scheduler,
cfg=cfg
# train_kwargs=train_args["params"],
# model_kwargs=model_args,
)
trainer.fit(
model,
train_dataloaders=dl_train,
val_dataloaders=dl_val,
)
dl_test = dm.test_dataloader()
trainer.test(ckpt_path="best", dataloaders=dl_test)
wandb.finish()

if __name__ == "__main__":
    main()
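
# Example launch using standard Hydra command-line overrides (the override values shown
# here are illustrative assumptions, not defaults from this repo):
#   python triplane_cls.py runtime.gpus=1 train.num_epochs=50 wandb.run_name=debug-run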