train_source.py
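"""Train a source classifier on MNIST.

Trains `Net` (from models.py) with Adam, tracks validation accuracy each
epoch, and saves the best-performing weights to trained_models/source.pt.
"""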
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor
from tqdm import tqdm
import config
from models import Net
from utils import GrayscaleToRgb
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
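
# Build train/validation loaders from the single MNIST train set by
# splitting a shuffled index permutation 80/20 and drawing each split
# with a SubsetRandomSampler.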
def create_dataloaders(batch_size):
    dataset = MNIST(config.DATA_DIR/'mnist', train=True, download=True,
                    transform=Compose([GrayscaleToRgb(), ToTensor()]))
    shuffled_indices = np.random.permutation(len(dataset))
    train_idx = shuffled_indices[:int(0.8*len(dataset))]
    val_idx = shuffled_indices[int(0.8*len(dataset)):]

    train_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True,
                              sampler=SubsetRandomSampler(train_idx),
                              num_workers=1, pin_memory=True)
    val_loader = DataLoader(dataset, batch_size=batch_size, drop_last=False,
                            sampler=SubsetRandomSampler(val_idx),
                            num_workers=1, pin_memory=True)
    return train_loader, val_loader
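
# Run one pass over a dataloader. If an optimizer is given, take a gradient
# step per batch (training); otherwise only accumulate metrics (evaluation).
# Returns loss and accuracy averaged over batches.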
def do_epoch(model, dataloader, criterion, optim=None):
    total_loss = 0
    total_accuracy = 0
    for x, y_true in tqdm(dataloader, leave=False):
        x, y_true = x.to(device), y_true.to(device)
        y_pred = model(x)
        loss = criterion(y_pred, y_true)

        if optim is not None:
            optim.zero_grad()
            loss.backward()
            optim.step()

        total_loss += loss.item()
        total_accuracy += (y_pred.max(1)[1] == y_true).float().mean().item()
    mean_loss = total_loss / len(dataloader)
    mean_accuracy = total_accuracy / len(dataloader)
    return mean_loss, mean_accuracy
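
# Training loop: train for args.epochs epochs, validate after each one,
# checkpoint whenever validation accuracy improves, and reduce the learning
# rate when validation loss plateaus.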
def main(args):
    train_loader, val_loader = create_dataloaders(args.batch_size)

    model = Net().to(device)
    optim = torch.optim.Adam(model.parameters())
    lr_schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=1, verbose=True)
    criterion = torch.nn.CrossEntropyLoss()

    best_accuracy = 0
    for epoch in range(1, args.epochs+1):
        model.train()
        train_loss, train_accuracy = do_epoch(model, train_loader, criterion, optim=optim)

        model.eval()
        with torch.no_grad():
            val_loss, val_accuracy = do_epoch(model, val_loader, criterion, optim=None)

        tqdm.write(f'EPOCH {epoch:03d}: train_loss={train_loss:.4f}, train_accuracy={train_accuracy:.4f} '
                   f'val_loss={val_loss:.4f}, val_accuracy={val_accuracy:.4f}')

        if val_accuracy > best_accuracy:
            print('Saving model...')
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), 'trained_models/source.pt')

        lr_schedule.step(val_loss)
if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser(description='Train a network on MNIST')
    arg_parser.add_argument('--batch-size', type=int, default=64)
    arg_parser.add_argument('--epochs', type=int, default=30)
    args = arg_parser.parse_args()
    main(args)
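
# Example invocation (values shown are the defaults, so both flags are
# optional):
#   python train_source.py --batch-size 64 --epochs 30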