helpers.py (forked from dvgodoy/PyTorchStepByStep)
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import random_split, WeightedRandomSampler, TensorDataset

def make_train_step_fn(model, loss_fn, optimizer):
    # Builds function that performs a step in the train loop
    def perform_train_step_fn(x, y):
        # Sets model to TRAIN mode
        model.train()
        # Step 1 - Computes our model's predicted output - forward pass
        yhat = model(x)
        # Step 2 - Computes the loss
        loss = loss_fn(yhat, y)
        # Step 3 - Computes gradients for all parameters
        loss.backward()
        # Step 4 - Updates parameters using gradients and the learning rate
        optimizer.step()
        optimizer.zero_grad()
        # Returns the loss
        return loss.item()
    # Returns the function that will be called inside the train loop
    return perform_train_step_fn
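
# Usage sketch (illustrative names, not part of this file), assuming a
# simple regression model; each call performs one full update and returns
# the loss as a Python float:
#   model = torch.nn.Linear(1, 1)
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   train_step_fn = make_train_step_fn(model, torch.nn.MSELoss(), optimizer)
#   loss = train_step_fn(x_batch, y_batch)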

def mini_batch(device, data_loader, step_fn):
    # Loops over the data loader, executing one step per mini-batch,
    # and returns the average loss over the whole pass
    mini_batch_losses = []
    for x_batch, y_batch in data_loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        mini_batch_loss = step_fn(x_batch, y_batch)
        mini_batch_losses.append(mini_batch_loss)
    loss = np.mean(mini_batch_losses)
    return loss
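
# Usage sketch: mini_batch runs one epoch over a DataLoader with the given
# step function and returns the average loss; the loader name is illustrative:
#   epoch_loss = mini_batch('cpu', train_loader, train_step_fn)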

def make_val_step_fn(model, loss_fn):
    # Builds function that performs a step in the validation loop
    def perform_val_step_fn(x, y):
        # Sets model to EVAL mode
        model.eval()
        # Step 1 - Computes our model's predicted output - forward pass
        yhat = model(x)
        # Step 2 - Computes the loss
        loss = loss_fn(yhat, y)
        # There is no need to compute Steps 3 and 4,
        # since we don't update parameters during evaluation
        return loss.item()
    return perform_val_step_fn
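
# Usage sketch: the step itself skips backward(), but callers should still
# disable autograd for speed and memory; names are illustrative:
#   val_step_fn = make_val_step_fn(model, torch.nn.MSELoss())
#   with torch.no_grad():
#       val_loss = mini_batch('cpu', val_loader, val_step_fn)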

def index_splitter(n, splits, seed=13):
    idx = torch.arange(n)
    # Makes the splits argument a tensor
    splits_tensor = torch.as_tensor(splits)
    # Finds the correct multiplier, so we don't have
    # to worry about the splits summing up to N (or one)
    multiplier = n / splits_tensor.sum()
    splits_tensor = (multiplier * splits_tensor).long()
    # If there is a rounding difference, adds it to the first
    # split so random_split does not complain
    diff = n - splits_tensor.sum()
    splits_tensor[0] += diff
    # Uses PyTorch's random_split to split the indices
    torch.manual_seed(seed)
    return random_split(idx, splits_tensor)
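
# Usage sketch: an 80/20 split of 100 indices; ratios need not add up to
# either n or one, since the multiplier above rescales them:
#   train_subset, val_subset = index_splitter(100, [80, 20])
#   train_idx = train_subset.indices  # indices drawn for the first split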

def make_balanced_sampler(y):
    # Computes weights for compensating imbalanced classes
    classes, counts = y.unique(return_counts=True)
    weights = 1.0 / counts.float()
    sample_weights = weights[y.squeeze().long()]
    # Builds sampler with the computed weights
    generator = torch.Generator()
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        generator=generator,
        replacement=True
    )
    return sampler
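
# Usage sketch: pass the sampler to a DataLoader so mini-batches draw all
# classes with roughly equal frequency; dataset/label names are illustrative:
#   sampler = make_balanced_sampler(y_train_tensor)
#   loader = torch.utils.data.DataLoader(
#       train_dataset, batch_size=16, sampler=sampler
#   )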

def freeze_model(model):
    # Disables gradient computation for every parameter,
    # effectively freezing the model's weights
    for parameter in model.parameters():
        parameter.requires_grad = False
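
# Usage sketch (transfer learning): freeze a pretrained backbone so only a
# newly attached layer gets trained; the ResNet-style 'fc' head below is
# illustrative, not from this file:
#   freeze_model(pretrained_model)
#   pretrained_model.fc = torch.nn.Linear(512, n_classes)  # still trainable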

def preprocessed_dataset(model, loader, device=None):
    if device is None:
        device = next(model.parameters()).device
    features = None
    labels = None
    for i, (x, y) in enumerate(loader):
        model.eval()
        x = x.to(device)
        output = model(x)
        if i == 0:
            features = output.detach().cpu()
            labels = y.cpu()
        else:
            features = torch.cat([features, output.detach().cpu()])
            labels = torch.cat([labels, y.cpu()])
    dataset = TensorDataset(features, labels)
    return dataset
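
# Usage sketch: run a frozen feature extractor once over the data and cache
# its outputs as a TensorDataset, so later epochs skip the costly forward
# pass; names are illustrative:
#   freeze_model(feature_extractor)
#   train_preproc = preprocessed_dataset(feature_extractor, train_loader)
#   new_loader = torch.utils.data.DataLoader(train_preproc, batch_size=16)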

def inception_loss(outputs, labels):
    # In training mode, Inception-V3 returns both a main and an auxiliary
    # output; in eval mode, only the main output
    try:
        main, aux = outputs
    except ValueError:
        main = outputs
        aux = None
    loss_aux = 0
    multi_loss_fn = nn.CrossEntropyLoss(reduction='mean')
    loss_main = multi_loss_fn(main, labels)
    if aux is not None:
        loss_aux = multi_loss_fn(aux, labels)
    # Combines the losses, down-weighting the auxiliary one
    return loss_main + 0.4 * loss_aux
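
# Usage sketch: torchvision's Inception-V3 returns a two-field named tuple
# (logits, aux_logits) in training mode, which unpacks into main and aux
# above; assumes torchvision is installed:
#   import torchvision
#   model = torchvision.models.inception_v3(weights=None, aux_logits=True)
#   outputs = model(x_batch)  # training mode -> InceptionOutputs named tuple
#   loss = inception_loss(outputs, y_batch)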