-
Notifications
You must be signed in to change notification settings - Fork 2
/
preact_resnet_flc.py
127 lines (105 loc) · 5.02 KB
/
preact_resnet_flc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
'''Generic Class for PreAct ResNet with FLC Pooling
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Identity Mappings in Deep Residual Networks. arXiv:1603.05027
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from flc_pooling import FLC_Pooling
class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.

    Data flow: BN -> ReLU -> conv1 -> BN -> ReLU -> conv2, plus a shortcut
    that is added to the result. When ``stride != 1`` the first convolution
    is replaced by ``FLC_Pooling`` followed by a stride-1 convolution.

    Args:
        in_planes: number of input channels.
        planes: number of output channels (multiplied by ``expansion``).
        stride: ``1`` keeps the plain convolution; any other value selects
            the ``FLC_Pooling`` path. The numeric value itself is not
            forwarded to ``FLC_Pooling``.
        drop: accepted for signature compatibility; not used by this block.
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, drop=0):
        super(PreActBlock, self).__init__()
        downsample = stride != 1
        out_planes = self.expansion * planes

        self.bn1 = nn.BatchNorm2d(in_planes)
        if downsample:
            # NOTE(review): FLC_Pooling is assumed to reduce the spatial
            # resolution (it stands in for the strided conv) - confirm
            # against the flc_pooling module.
            self.conv1 = nn.Sequential(
                FLC_Pooling(),
                nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False))
        else:
            self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        if downsample or in_planes != out_planes:
            # Pool the shortcut only when the main path is pooled as well;
            # a channel-only projection (stride == 1) must keep the input
            # resolution, otherwise the residual addition sees two
            # different shapes. nn.Identity keeps the conv at index 1 so
            # parameter names in the state_dict do not change.
            self.shortcut = nn.Sequential(
                FLC_Pooling() if downsample else nn.Identity(),
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False)
            )

    def forward(self, x):
        '''Apply the block to ``x`` and return the residual sum.'''
        out = F.relu(self.bn1(x))
        # The projection (when present) is fed the raw input, not the
        # pre-activated tensor.
        shortcut = self.shortcut(x) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out += shortcut
        return out
class PreActBottleneck(nn.Module):
    '''Pre-activation version of the original Bottleneck module.

    Data flow: BN -> ReLU -> 1x1 conv -> BN -> ReLU -> 3x3 conv -> BN ->
    ReLU -> 1x1 conv, plus a shortcut that is added to the result. When
    ``stride != 1`` the 3x3 convolution is preceded by ``FLC_Pooling`` and
    runs with stride 1.

    Args:
        in_planes: number of input channels.
        planes: bottleneck width; the block outputs ``expansion * planes``
            channels.
        stride: ``1`` keeps the plain 3x3 convolution; any other value
            selects the ``FLC_Pooling`` path. The numeric value itself is
            not forwarded to ``FLC_Pooling``.
        drop: accepted for signature compatibility; not used by this block.
    '''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, drop=0):
        super(PreActBottleneck, self).__init__()
        downsample = stride != 1
        out_planes = self.expansion * planes

        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        if downsample:
            # The pooled variant must be stored on the module; without the
            # assignment forward() has no conv2 to call.
            # NOTE(review): FLC_Pooling is assumed to reduce the spatial
            # resolution (it stands in for the strided conv) - confirm
            # against the flc_pooling module.
            self.conv2 = nn.Sequential(
                FLC_Pooling(),
                nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False))
        else:
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)

        if downsample or in_planes != out_planes:
            # Pool the shortcut only when the main path is pooled as well;
            # a channel-only projection (stride == 1) must keep the input
            # resolution, otherwise the residual addition sees two
            # different shapes. nn.Identity keeps the conv at index 1 so
            # parameter names in the state_dict do not change.
            self.shortcut = nn.Sequential(
                FLC_Pooling() if downsample else nn.Identity(),
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False)
            )

    def forward(self, x):
        '''Apply the block to ``x`` and return the residual sum.'''
        out = F.relu(self.bn1(x))
        # The projection (when present) is fed the pre-activated tensor.
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out = self.conv3(F.relu(self.bn3(out)))
        out += shortcut
        return out
class PreActResNet(nn.Module):
    '''Pre-activation ResNet: a 3x3 stem, four residual stages, then
    BN -> ReLU -> 4x4 average pooling -> linear classifier.

    Args:
        block: residual block class; called as ``block(in_planes, planes,
            stride)`` and expected to expose an ``expansion`` attribute.
        num_blocks: sequence giving the number of blocks in each of the
            four stages.
        num_classes: size of the classifier output.
        drop: accepted but not used by this class.
    '''

    def __init__(self, block, num_blocks, num_classes=10, drop=0):
        super(PreActResNet, self).__init__()
        # Running channel count; _make_layer reads and updates it.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        feature_dim = 512 * block.expansion
        self.bn = nn.BatchNorm2d(feature_dim)
        self.linear = nn.Linear(feature_dim, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        '''Build one stage: the first block receives ``stride``, the
        remaining ones stride 1. Updates ``self.in_planes`` as it goes.'''
        per_block_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in per_block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        '''Return the classifier output for the input batch ``x``.'''
        features = self.conv1(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        activated = F.relu(self.bn(features))
        pooled = F.avg_pool2d(activated, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)
def PreActResNet18(num_classes=10):
    '''Build a PreActResNet from PreActBlock with two blocks per stage.'''
    blocks_per_stage = [2, 2, 2, 2]
    return PreActResNet(PreActBlock, blocks_per_stage, num_classes=num_classes)
class PreActResNet_normalized(PreActResNet):
    '''PreActResNet that normalizes its input before the forward pass.

    The input is transformed as ``(x - mu) / sigma`` with one mean and one
    standard deviation per channel.

    Args:
        block: residual block class, forwarded to ``PreActResNet``.
        num_blocks: blocks per stage, forwarded to ``PreActResNet``.
        num_classes: size of the classifier output.
        mu: three per-channel means.
        sigma: three per-channel standard deviations.
        device: device the normalization constants are initially placed on.
    '''

    def __init__(self, block, num_blocks, num_classes=10, mu=(0.4914, 0.4822, 0.4465), sigma=(0.2471, 0.2435, 0.2616), device='cuda'):
        super(PreActResNet_normalized, self).__init__(block=block, num_blocks=num_blocks, num_classes=num_classes)
        mean = torch.as_tensor(mu, dtype=torch.float32).clone().reshape(3, 1, 1).to(device)
        std = torch.as_tensor(sigma, dtype=torch.float32).clone().reshape(3, 1, 1).to(device)
        # Buffers (rather than plain attributes) follow the module through
        # .to()/.cuda()/.cpu() and are replicated by DataParallel.
        # persistent=False keeps them out of the state_dict, so checkpoint
        # keys stay the same as before.
        self.register_buffer('mu', mean, persistent=False)
        self.register_buffer('sigma', std, persistent=False)

    def forward(self, x):
        '''Normalize ``x`` per channel, then run the parent forward pass.'''
        x = (x - self.mu) / self.sigma
        return super(PreActResNet_normalized, self).forward(x)
def PreActResNet18_normalized(num_classes=10, mu=[0.4914, 0.4822, 0.4465], sigma=[0.2471, 0.2435, 0.2616], device='cuda'):
    '''Build a PreActResNet_normalized from PreActBlock with two blocks
    per stage, forwarding the normalization constants and device.'''
    blocks_per_stage = [2, 2, 2, 2]
    return PreActResNet_normalized(
        PreActBlock,
        blocks_per_stage,
        num_classes=num_classes,
        mu=mu,
        sigma=sigma,
        device=device,
    )