model.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from ssim import msssim

# Number of spectral bands in every image handled by the model.
NUM_BANDS = 4


def conv3x3(in_channels, out_channels, stride=1):
    # 3x3 convolution with replication padding, so spatial size is preserved
    # (for stride 1) without zero-padding artifacts at the borders.
    return nn.Sequential(
        nn.ReplicationPad2d(1),
        nn.Conv2d(in_channels, out_channels, 3, stride=stride)
    )


def interpolate(inputs, size=None, scale_factor=None):
    # Bilinear resampling helper used to bring coarse images up to the fine resolution.
    return F.interpolate(inputs, size=size, scale_factor=scale_factor,
                         mode='bilinear', align_corners=True)


class CompoundLoss(nn.Module):
    """Pixel-space MSE + feature-space MSE (through the `pretrained` network)
    + `alpha`-weighted MS-SSIM term."""

    def __init__(self, pretrained, alpha=0.8, normalize=True):
        super(CompoundLoss, self).__init__()
        self.pretrained = pretrained
        self.alpha = alpha
        self.normalize = normalize

    def forward(self, prediction, target):
        return (F.mse_loss(prediction, target) +
                F.mse_loss(self.pretrained(prediction), self.pretrained(target)) +
                self.alpha * (1.0 - msssim(prediction, target,
                                           normalize=self.normalize)))


class FEncoder(nn.Sequential):
    """Encodes a single fine-resolution image into a 128-channel feature map."""

    def __init__(self):
        channels = [NUM_BANDS, 32, 64, 128]
        super(FEncoder, self).__init__(
            conv3x3(channels[0], channels[1]),
            nn.ReLU(True),
            conv3x3(channels[1], channels[2]),
            nn.ReLU(True),
            conv3x3(channels[2], channels[3]),
            nn.ReLU(True)
        )


class REncoder(nn.Sequential):
    """Encodes the channel-wise concatenation of three images into a 128-channel
    residual feature map; the last convolution has no ReLU so the residual can
    take negative values."""

    def __init__(self):
        channels = [NUM_BANDS * 3, 32, 64, 128]
        super(REncoder, self).__init__(
            conv3x3(channels[0], channels[1]),
            nn.ReLU(True),
            conv3x3(channels[1], channels[2]),
            nn.ReLU(True),
            conv3x3(channels[2], channels[3])
        )


class Decoder(nn.Sequential):
    """Maps fused 128-channel features back to a NUM_BANDS image via a final 1x1 convolution."""

    def __init__(self):
        channels = [128, 64, 32, NUM_BANDS]
        super(Decoder, self).__init__(
            conv3x3(channels[0], channels[1]),
            nn.ReLU(True),
            conv3x3(channels[1], channels[2]),
            nn.ReLU(True),
            nn.Conv2d(channels[2], channels[3], 1)
        )


class Pretrained(nn.Sequential):
    """Feature extractor intended as the `pretrained` network of CompoundLoss;
    the stride-2 stage halves the spatial resolution."""

    def __init__(self):
        channels = [NUM_BANDS, 32, 64, 128, 128]
        super(Pretrained, self).__init__(
            conv3x3(channels[0], channels[1]),
            nn.ReLU(True),
            conv3x3(channels[1], channels[2]),
            nn.ReLU(True),
            conv3x3(channels[2], channels[3], 2),
            nn.ReLU(True),
            conv3x3(channels[3], channels[4]),
            nn.ReLU(True)
        )


class FusionNet(nn.Module):
    """Fusion network combining coarse- and fine-resolution images.

    `inputs` is a list of NUM_BANDS images:
      * 3 items: coarse reference, fine reference, coarse image at the
        prediction date;
      * 5 items: coarse/fine reference pair before the prediction date,
        coarse/fine reference pair after it, and the coarse image at the
        prediction date.
    Coarse images are upsampled by a factor of 16 to the fine resolution.
    """

    def __init__(self):
        super(FusionNet, self).__init__()
        self.encoder = FEncoder()
        self.residual = REncoder()
        self.decoder = Decoder()

    def forward(self, inputs):
        # Bring the first coarse reference and the prediction-date coarse image
        # up to the fine resolution (in place).
        inputs[0] = interpolate(inputs[0], scale_factor=16)
        inputs[-1] = interpolate(inputs[-1], scale_factor=16)
        # Residual features predicted from the first (previous) reference pair.
        prev_diff = self.residual(torch.cat((inputs[0], inputs[1], inputs[-1]), 1))
        if len(inputs) == 5:
            inputs[2] = interpolate(inputs[2], scale_factor=16)
            # Residual features predicted from the second (next) reference pair.
            next_diff = self.residual(torch.cat((inputs[2], inputs[3], inputs[-1]), 1))
            if self.training:
                # During training both branches are decoded and supervised separately.
                prev_fusion = self.encoder(inputs[1]) + prev_diff
                next_fusion = self.encoder(inputs[3]) + next_diff
                return self.decoder(prev_fusion), self.decoder(next_fusion)
            else:
                # At inference the two branches are blended with weights inversely
                # proportional to the magnitude of their residuals.
                one = inputs[0].new_tensor(1.0)
                epsilon = inputs[0].new_tensor(1e-8)
                prev_dist = torch.abs(prev_diff) + epsilon
                next_dist = torch.abs(next_diff) + epsilon
                prev_mask = one.div(prev_dist).div(one.div(prev_dist) + one.div(next_dist))
                prev_mask = prev_mask.clamp_(0.0, 1.0)
                next_mask = one - prev_mask
                result = (prev_mask * (self.encoder(inputs[1]) + prev_diff) +
                          next_mask * (self.encoder(inputs[3]) + next_diff))
                result = self.decoder(result)
                return result
        else:
            # Single reference pair: decode the encoder features plus the residual.
            return self.decoder(self.encoder(inputs[1]) + prev_diff)
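

# A minimal smoke test of the inference path: a sketch, not part of the original
# file. It assumes 4-band inputs whose coarse images are 1/16 the spatial size
# of the fine images; the batch and image sizes below are illustrative only.
if __name__ == '__main__':
    net = FusionNet()
    net.eval()
    coarse = torch.rand(1, NUM_BANDS, 16, 16)    # coarse image (reference / prediction date)
    fine = torch.rand(1, NUM_BANDS, 256, 256)    # fine reference image
    with torch.no_grad():
        # Five-input case: two reference pairs plus the prediction-date coarse image.
        # forward() mutates the list in place, so pass fresh clones.
        out = net([coarse.clone(), fine.clone(),
                   coarse.clone(), fine.clone(),
                   coarse.clone()])
    print(out.shape)  # expected: torch.Size([1, 4, 256, 256])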