warmup.py
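
"""Learning-rate warmup wrapper for PyTorch LR schedulers.

WarmupLR wraps any torch.optim.lr_scheduler scheduler and prepends a
linear, cosine, or constant warmup phase before handing control to the
wrapped scheduler.
"""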
import math

import matplotlib.pyplot as plt
import torch.optim as optim
from torch.optim.lr_scheduler import _LRScheduler

from models.model import RetinaNet


class WarmupLR(_LRScheduler):
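    """Wrap an existing LR scheduler with a warmup phase.

    For the first `num_warmup` calls to `step()`, the learning rate of every
    param group is ramped from `init_lr` up to the group's original lr using
    the chosen `warmup_strategy`; after that, stepping is delegated to the
    wrapped scheduler.
    """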
def __init__(self, scheduler, init_lr=1e-3, num_warmup=1, warmup_strategy='linear'):
if warmup_strategy not in ['linear', 'cos', 'constant']:
raise ValueError(
"Expect warmup_strategy to be one of ['linear', 'cos', 'constant'] but got {}".format(warmup_strategy))
self._scheduler = scheduler
self._init_lr = init_lr
self._num_warmup = num_warmup
self._step_count = 0
# Define the strategy to warm up learning rate
self._warmup_strategy = warmup_strategy
if warmup_strategy == 'cos':
self._warmup_func = self._warmup_cos
elif warmup_strategy == 'linear':
self._warmup_func = self._warmup_linear
else:
self._warmup_func = self._warmup_const
        # Save the initial learning rate of each param group;
        # only useful when param groups have different learning rates.
        self._format_param()

    def __getattr__(self, name):
        # Delegate attributes not found on the wrapper to the wrapped scheduler.
        return getattr(self._scheduler, name)

    def state_dict(self):
"""Returns the state of the scheduler as a :class:`dict`.
It contains an entry for every variable in self.__dict__ which
is not the optimizer.
"""
wrapper_state_dict = {key: value for key, value in self.__dict__.items() if
(key != 'optimizer' and key != '_scheduler')}
wrapped_state_dict = {key: value for key, value in self._scheduler.__dict__.items() if key != 'optimizer'}
return {'wrapped': wrapped_state_dict, 'wrapper': wrapper_state_dict}

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.
Arguments:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict['wrapper'])
self._scheduler.__dict__.update(state_dict['wrapped'])

    def _format_param(self):
        # During warmup, each param group's learning rate increases from
        # warmup_initial_lr up to its original lr (stored as warmup_max_lr).
for group in self._scheduler.optimizer.param_groups:
group['warmup_max_lr'] = group['lr']
group['warmup_initial_lr'] = min(self._init_lr, group['lr'])

    def _warmup_cos(self, start, end, pct):
        """Cosine annealing from `start` to `end`:
        current = end + 0.5 * (start - end) * (1 + cos(pct * pi))."""
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    def _warmup_const(self, start, end, pct):
        # Hold the lr at `start` for the whole warmup, then jump to `end`.
        return start if pct < 0.9999 else end

    def _warmup_linear(self, start, end, pct):
        # Linear interpolation between `start` and `end`.
        return (end - start) * pct + start

    def get_lr(self):
lrs = []
step_num = self._step_count
# warm up learning rate
if step_num <= self._num_warmup:
for group in self._scheduler.optimizer.param_groups:
computed_lr = self._warmup_func(group['warmup_initial_lr'],
group['warmup_max_lr'],
step_num / self._num_warmup)
lrs.append(computed_lr)
else:
lrs = self._scheduler.get_lr()
return lrs

    def step(self, *args):
if self._step_count <= self._num_warmup:
values = self.get_lr()
for param_group, lr in zip(self._scheduler.optimizer.param_groups, values):
param_group['lr'] = lr
self._step_count += 1
        else:
            # Method 1 (alternative):
            # self._scheduler.step(epoch=self._step_count)
            # self._step_count += 1
            # Method 2: sync the wrapped scheduler's counters with the
            # wrapper, then delegate the step to it.
            self._scheduler._step_count = self._step_count + 1
self._scheduler.last_epoch = self._step_count
self._scheduler.step()
self._step_count += 1
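
# Usage sketch: in a real training loop, step the wrapper once per epoch
# after the optimizer updates, e.g.
#
#     for epoch in range(num_epochs):
#         train_one_epoch(model, optimizer)  # hypothetical helper, not defined here
#         scheduler.step()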


if __name__ == '__main__':
epochs = 30
model = RetinaNet(backbone='resnet50', loss_func='smooth_l1')
optimizer = optim.Adam(model.parameters(), lr=1e-2)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[round(epochs * x) for x in [0.6, 0.8]],
gamma=0.1)
scheduler = WarmupLR(scheduler, init_lr=1e-5, num_warmup=2, warmup_strategy='cos')
y = []
count_one = 0
count_two = 0
    for i in range(epochs):
        # No optimizer.step() here; we only trace the lr schedule for plotting.
        scheduler.step()
        y.append(optimizer.param_groups[0]['lr'])
        if optimizer.param_groups[0]['lr'] == 0.1 * 1e-2 and count_one == 0:
            print(f'lr first divided at epoch {i + 1}')
            count_one += 1
        if optimizer.param_groups[0]['lr'] == 0.01 * 1e-2 and count_two == 0:
            print(f'lr second divided at epoch {i + 1}')
            count_two += 1
    plt.plot(y, label='LR')
    plt.legend()
plt.xlabel('epoch')
plt.ylabel('LR')
plt.tight_layout()
plt.savefig('LR-warmup.png', dpi=300)