-
Notifications
You must be signed in to change notification settings - Fork 5
/
pretrain_pl.py
140 lines (124 loc) · 3.96 KB
/
pretrain_pl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import pathlib
import time
import click
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, random_split
from datasets import ModelNet40
from executor import MeshDataClassifierPL
@click.command()
@click.option('--train_dataset', help='The training dataset file path')
@click.option(
    '--split_ratio',
    default=0.8,
    help='The proportion of training samples out of the whole training dataset',
)
@click.option('--eval_dataset', help='The evaluation dataset file path')
@click.option('--hidden_dim', default=1024, help='The dimension of the used models')
@click.option(
    '--checkpoint_path',
    type=click.Path(file_okay=True, path_type=pathlib.Path),
    help='The path of checkpoint',
)
@click.option(
    '--output_path',
    type=click.Path(file_okay=True, path_type=pathlib.Path),
    help='The path of output files',
)
@click.option('--model_name', default='pointnet', help='The model name')
@click.option('--batch_size', default=128, help='The size of each batch')
@click.option('--epochs', default=50, help='The epochs of training process')
@click.option('--use-gpu/--no-use-gpu', default=False, help='If True to use gpu')
@click.option(
    '--devices', default=7, help='The number of gpus/tpus you can use for training'
)
@click.option('--seed', default=10, help='The random seed for reproducing results')
def main(
    train_dataset,
    split_ratio,
    eval_dataset,
    model_name,
    hidden_dim,
    batch_size,
    epochs,
    use_gpu,
    checkpoint_path,
    output_path,
    devices,
    seed,
):
    """Train a MeshDataClassifierPL model and report validation/test metrics.

    The training dataset is split into train and validation parts according
    to --split_ratio, the model is fitted for --epochs epochs, and the fitted
    model is then evaluated on the validation split and on the evaluation
    dataset.
    """
    # Honour the --seed option. It was previously overwritten with the current
    # time, which made the option a no-op and runs impossible to reproduce.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    device = 'cuda' if use_gpu else 'cpu'

    # Resume from a checkpoint when one is given, otherwise build a new model.
    if checkpoint_path:
        model = MeshDataClassifierPL.load_from_checkpoint(
            checkpoint_path, map_location=device
        )
    else:
        model = MeshDataClassifierPL(
            model_name=model_name,
            device=device,
            hidden_dim=hidden_dim,
            batch_size=batch_size,
        )

    train_and_val_data = ModelNet40(train_dataset, seed=seed)
    tot_len = len(train_and_val_data)
    train_len = int(tot_len * split_ratio)
    validate_len = tot_len - train_len
    # Use a dedicated generator so the split depends only on the seed and not
    # on how much of the global RNG stream was consumed above (building a new
    # model and loading a checkpoint may consume different amounts). This keeps
    # the train/validation partition identical across fresh and resumed runs.
    split_generator = torch.Generator().manual_seed(seed)
    train_data, validate_data = random_split(
        train_and_val_data, [train_len, validate_len], generator=split_generator
    )
    test_data = ModelNet40(eval_dataset, seed=seed)

    # drop_last=True, avoid batch=1 error from BatchNorm
    train_loader = DataLoader(
        train_data, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True
    )
    validate_loader = DataLoader(
        validate_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
        drop_last=True,
    )
    test_loader = DataLoader(
        test_data, batch_size=batch_size, shuffle=False, num_workers=8
    )

    # The run name records the hyper-parameters, including the seed in use.
    logger = TensorBoardLogger(
        save_dir='./logs' if output_path is None else output_path,
        log_graph=True,
        name='{}_dim_{}_batch_{}_epochs_{}_seed_{}'.format(
            model_name, hidden_dim, batch_size, epochs, seed
        ),
    )

    # Keep the five checkpoints with the lowest validation loss.
    checkpoint_callback = ModelCheckpoint(
        save_top_k=5,
        monitor='val_loss',
        mode='min',
        filename='{epoch:02d}-{val_loss:.2f}-{val_acc:.4f}',
    )

    trainer = Trainer(
        accelerator='gpu' if use_gpu else 'cpu',
        devices=devices,
        max_epochs=epochs,
        check_val_every_n_epoch=1,
        enable_checkpointing=True,
        logger=logger,
        callbacks=[checkpoint_callback],
        gradient_clip_val=1.0,
    )

    model.train()
    trainer.fit(model, train_loader, validate_loader)
    print(checkpoint_callback.best_model_path)

    # Evaluate the in-memory model as it stands after fitting.
    model.eval()
    print('Validation set:')
    trainer.test(model, dataloaders=validate_loader)
    print('Testing set:')
    trainer.test(model, dataloaders=test_loader)
# Script entry point: run the click command only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()