-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
71 lines (58 loc) · 2.67 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""Module with training of model"""
import os

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.neptune import NeptuneLogger

from modules.config import DataConfig, ModelConfig
from modules.model.model_lightning import ClassificationModel
def main() -> None:
    """Train ``ClassificationModel`` and log the run to Neptune.

    Builds the model from ``DataConfig`` / ``ModelConfig`` and the
    hyper-parameters below. Checkpoints to ``model_config.weights_folder``
    (monitoring ``val_loss``, mode ``min``, keeping the top 1). Logs
    hyper-parameters and metrics to the ``vadbeg/birds`` Neptune project.

    The Neptune API token is read from the ``NEPTUNE_API_TOKEN`` environment
    variable. Never hard-code credentials in this file.

    Raises:
        RuntimeError: if ``NEPTUNE_API_TOKEN`` is not set or is empty.
    """
    # Fail fast, before touching the GPU or loading data, if the logger
    # has no credentials to authenticate with.
    neptune_api_token = os.environ.get('NEPTUNE_API_TOKEN')
    if not neptune_api_token:
        raise RuntimeError(
            'Set the NEPTUNE_API_TOKEN environment variable to your Neptune '
            'API token before running training.'
        )

    # Release unused memory held by PyTorch's CUDA caching allocator.
    torch.cuda.empty_cache()

    data_config = DataConfig()
    model_config = ModelConfig()

    hparams = {
        'learning_rate': 0.001,
        'n_classes': data_config.n_classes,
        'max_epochs': 150,
        'batch_size': 15,
        'model_name': 'efficientnet-b1',
        'width': 2048,
        'size': (128, 512),
    }

    model = ClassificationModel(n_classes=data_config.n_classes,
                                file_path=data_config.dataset_path,
                                batch_size=hparams['batch_size'],
                                hparams=hparams,
                                model_name=hparams['model_name'],
                                width=hparams['width'],
                                size=hparams['size'])
    # To resume from a saved checkpoint instead, build the model with
    # ClassificationModel.load_from_checkpoint(
    #     checkpoint_path='5_epoch_effnet0.ckpt', <same kwargs as above>)

    # Keep only the single best checkpoint (lowest validation loss).
    checkpoint_callback = ModelCheckpoint(
        filepath=model_config.weights_folder,
        save_top_k=1,
        verbose=True,
        monitor='val_loss',
        mode='min',
        prefix=hparams['model_name'],
    )

    neptune_logger = NeptuneLogger(
        api_key=neptune_api_token,
        project_name='vadbeg/birds',
        # Derive the width from hparams so the run name cannot drift from
        # the value actually used to build the model.
        experiment_name=(f'{hparams["model_name"]}, CrossEntropyLoss, '
                         f'width={hparams["width"]}'),
        params=hparams,
        tags=['pytorch-lightning', 'birds'],
    )

    trainer = Trainer(gpus=1, num_nodes=1,
                      checkpoint_callback=checkpoint_callback,
                      max_epochs=hparams['max_epochs'],
                      logger=neptune_logger)
    trainer.fit(model=model)


if __name__ == '__main__':
    main()