diff --git a/setup.py b/setup.py
index 42e4d3f..8c42b67 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.4.2',
+  version = '1.4.3',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
diff --git a/vit_pytorch/ats_vit.py b/vit_pytorch/ats_vit.py
index 779c400..f3c788c 100644
--- a/vit_pytorch/ats_vit.py
+++ b/vit_pytorch/ats_vit.py
@@ -110,18 +110,11 @@ def forward(self, attn, value, mask):
 
 # classes
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -138,6 +131,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., output_num_token
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
@@ -154,6 +148,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., output_num_token
     def forward(self, x, *, mask):
         num_tokens = x.shape[1]
 
+        x = self.norm(x)
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
 
@@ -189,8 +184,8 @@ def __init__(self, dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, d
         self.layers = nn.ModuleList([])
         for _, output_num_tokens in zip(range(depth), max_tokens_per_depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
 
     def forward(self, x):
diff --git a/vit_pytorch/cait.py b/vit_pytorch/cait.py
index eac9185..f95cae2 100644
--- a/vit_pytorch/cait.py
+++ b/vit_pytorch/cait.py
@@ -44,18 +44,11 @@ def __init__(self, dim, fn, depth):
     def forward(self, x, **kwargs):
         return self.fn(x, **kwargs) * self.scale
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -72,6 +65,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
 
@@ -89,6 +83,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
     def forward(self, x, context = None):
         b, n, _, h = *x.shape, self.heads
 
+        x = self.norm(x)
         context = x if not exists(context) else torch.cat((x, context), dim = 1)
 
         qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
@@ -115,8 +110,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dro
         for ind in range(depth):
             self.layers.append(nn.ModuleList([
-                LayerScale(dim, PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), depth = ind + 1),
-                LayerScale(dim, PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)), depth = ind + 1)
+                LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = ind + 1),
+                LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = ind + 1)
             ]))
 
     def forward(self, x, context = None):
         layers = dropout_layers(self.layers, dropout = self.layer_dropout)
diff --git a/vit_pytorch/cross_vit.py b/vit_pytorch/cross_vit.py
index b894a2f..210cbe0 100644
--- a/vit_pytorch/cross_vit.py
+++ b/vit_pytorch/cross_vit.py
@@ -13,22 +13,13 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
-# pre-layernorm
-
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 # feedforward
 
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -47,6 +38,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
@@ -60,6 +52,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
     def forward(self, x, context = None, kv_include_self = False):
         b, n, _, h = *x.shape, self.heads
 
+        x = self.norm(x)
         context = default(context, x)
 
         if kv_include_self:
@@ -86,8 +79,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         self.norm = nn.LayerNorm(dim)
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
 
     def forward(self, x):
@@ -121,8 +114,8 @@ def __init__(self, sm_dim, lg_dim, depth, heads, dim_head, dropout):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                ProjectInOut(sm_dim, lg_dim, PreNorm(lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout))),
-                ProjectInOut(lg_dim, sm_dim, PreNorm(sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout)))
+                ProjectInOut(sm_dim, lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout)),
+                ProjectInOut(lg_dim, sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout))
             ]))
 
     def forward(self, sm_tokens, lg_tokens):
diff --git a/vit_pytorch/cvt.py b/vit_pytorch/cvt.py
index 2750284..6f214f7 100644
--- a/vit_pytorch/cvt.py
+++ b/vit_pytorch/cvt.py
@@ -34,19 +34,11 @@ def forward(self, x):
         mean = torch.mean(x, dim = 1, keepdim = True)
         return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        x = self.norm(x)
-        return self.fn(x, **kwargs)
-
 class FeedForward(nn.Module):
     def __init__(self, dim, mult = 4, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            LayerNorm(dim),
             nn.Conv2d(dim, dim * mult, 1),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -75,6 +67,7 @@ def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, d
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = LayerNorm(dim)
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
@@ -89,6 +82,8 @@ def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, d
     def forward(self, x):
         shape = x.shape
         b, n, _, y, h = *shape, self.heads
+
+        x = self.norm(x)
 
         q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
         q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))
@@ -107,8 +102,8 @@ def __init__(self, dim, proj_kernel, kv_proj_stride, depth, heads, dim_head = 64
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
+                Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_mult, dropout = dropout)
             ]))
     def forward(self, x):
         for attn, ff in self.layers:
diff --git a/vit_pytorch/deepvit.py b/vit_pytorch/deepvit.py
index c86f122..a62cb7a 100644
--- a/vit_pytorch/deepvit.py
+++ b/vit_pytorch/deepvit.py
@@ -5,25 +5,11 @@
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 
-class Residual(nn.Module):
-    def __init__(self, fn):
-        super().__init__()
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(x, **kwargs) + x
-
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -40,6 +26,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
         self.dropout = nn.Dropout(dropout)
@@ -59,6 +46,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
     def forward(self, x):
         b, n, _, h = *x.shape, self.heads
 
+        x = self.norm(x)
+
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
@@ -86,13 +75,13 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
-                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
     def forward(self, x):
         for attn, ff in self.layers:
-            x = attn(x)
-            x = ff(x)
+            x = attn(x) + x
+            x = ff(x) + x
         return x
 
 class DeepViT(nn.Module):
diff --git a/vit_pytorch/local_vit.py b/vit_pytorch/local_vit.py
index 3163e41..afaf858 100644
--- a/vit_pytorch/local_vit.py
+++ b/vit_pytorch/local_vit.py
@@ -26,16 +26,6 @@ def forward(self, x, **kwargs):
         x = self.fn(x, **kwargs)
         return torch.cat((cls_token, x), dim = 1)
 
-# prenorm
-
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 # feed forward related classes
 
 class DepthWiseConv2d(nn.Module):
@@ -52,6 +42,7 @@ class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Conv2d(dim, hidden_dim, 1),
             nn.Hardswish(),
             DepthWiseConv2d(hidden_dim, hidden_dim, 3, padding = 1),
@@ -77,6 +68,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
@@ -88,6 +80,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
     def forward(self, x):
         b, n, _, h = *x.shape, self.heads
+
+        x = self.norm(x)
 
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
@@ -106,8 +100,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
-                ExcludeCLS(Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))))
+                Residual(Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
+                ExcludeCLS(Residual(FeedForward(dim, mlp_dim, dropout = dropout)))
             ]))
     def forward(self, x):
         for attn, ff in self.layers:
diff --git a/vit_pytorch/max_vit.py b/vit_pytorch/max_vit.py
index 8359f0c..1c76d34 100644
--- a/vit_pytorch/max_vit.py
+++ b/vit_pytorch/max_vit.py
@@ -19,20 +19,20 @@ def cast_tuple(val, length = 1):
 
 # helper classes
 
-class PreNormResidual(nn.Module):
+class Residual(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
-        self.norm = nn.LayerNorm(dim)
         self.fn = fn
 
     def forward(self, x):
-        return self.fn(self.norm(x)) + x
+        return self.fn(x) + x
 
 class FeedForward(nn.Module):
     def __init__(self, dim, mult = 4, dropout = 0.):
         super().__init__()
         inner_dim = int(dim * mult)
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, inner_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -132,6 +132,7 @@ def __init__(
         self.heads = dim // dim_head
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
 
         self.attend = nn.Sequential(
@@ -160,6 +161,8 @@ def __init__(
     def forward(self, x):
         batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads
 
+        x = self.norm(x)
+
         # flatten
 
         x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')
@@ -259,13 +262,13 @@ def __init__(
                         shrinkage_rate = mbconv_shrinkage_rate
                     ),
                     Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w),  # block-like attention
-                    PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
-                    PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
+                    Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
+                    Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                     Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
 
                     Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w),  # grid-like attention
-                    PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
-                    PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
+                    Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
+                    Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                     Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
                 )
 
diff --git a/vit_pytorch/mobile_vit.py b/vit_pytorch/mobile_vit.py
index e0b7b8b..e391742 100644
--- a/vit_pytorch/mobile_vit.py
+++ b/vit_pytorch/mobile_vit.py
@@ -22,20 +22,11 @@ def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
 
 # classes
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
-
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout=0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.SiLU(),
             nn.Dropout(dropout),
@@ -53,6 +44,7 @@ def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.attend = nn.Softmax(dim=-1)
         self.dropout = nn.Dropout(dropout)
 
@@ -64,9 +56,10 @@ def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
         )
 
     def forward(self, x):
+        x = self.norm(x)
         qkv = self.to_qkv(x).chunk(3, dim=-1)
-        q, k, v = map(lambda t: rearrange(
-            t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
+
+        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
 
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
 
@@ -88,8 +81,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
+                Attention(dim, heads, dim_head, dropout),
+                FeedForward(dim, mlp_dim, dropout)
             ]))
 
     def forward(self, x):
@@ -167,11 +160,9 @@ def forward(self, x):
 
         # Global representations
         _, _, h, w = x.shape
-        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d',
-                      ph=self.ph, pw=self.pw)
-        x = self.transformer(x)
-        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
-                      h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
+        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
+        x = self.transformer(x)
+        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
 
         # Fusion
         x = self.conv3(x)
diff --git a/vit_pytorch/nest.py b/vit_pytorch/nest.py
index 237e106..68a1fa8 100644
--- a/vit_pytorch/nest.py
+++ b/vit_pytorch/nest.py
@@ -24,19 +24,11 @@ def forward(self, x):
         mean = torch.mean(x, dim = 1, keepdim = True)
         return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = LayerNorm(dim)
-        self.fn = fn
-
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 class FeedForward(nn.Module):
     def __init__(self, dim, mlp_mult = 4, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            LayerNorm(dim),
             nn.Conv2d(dim, dim * mlp_mult, 1),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -54,6 +46,7 @@ def __init__(self, dim, heads = 8, dropout = 0.):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = LayerNorm(dim)
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
         self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
@@ -66,6 +59,8 @@ def forward(self, x):
         b, c, h, w, heads = *x.shape, self.heads
 
+        x = self.norm(x)
+
         qkv = self.to_qkv(x).chunk(3, dim = 1)
         q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), qkv)
 
@@ -93,8 +88,8 @@ def __init__(self, dim, seq_len, depth, heads, mlp_mult, dropout = 0.):
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, heads = heads, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
+                Attention(dim, heads = heads, dropout = dropout),
+                FeedForward(dim, mlp_mult, dropout = dropout)
             ]))
 
     def forward(self, x):
         *_, h, w = x.shape
diff --git a/vit_pytorch/parallel_vit.py b/vit_pytorch/parallel_vit.py
index bd736d2..7b5ca1f 100644
--- a/vit_pytorch/parallel_vit.py
+++ b/vit_pytorch/parallel_vit.py
@@ -19,18 +19,11 @@ def __init__(self, *fns):
     def forward(self, x):
         return sum([fn(x) for fn in self.fns])
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -49,6 +42,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
@@ -60,6 +54,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         ) if project_out else nn.Identity()
 
     def forward(self, x):
+        x = self.norm(x)
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
 
@@ -77,8 +72,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_parallel_branches =
         super().__init__()
         self.layers = nn.ModuleList([])
 
-        attn_block = lambda: PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
-        ff_block = lambda: PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+        attn_block = lambda: Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)
+        ff_block = lambda: FeedForward(dim, mlp_dim, dropout = dropout)
 
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
diff --git a/vit_pytorch/pit.py b/vit_pytorch/pit.py
index 7ed257a..ba7ddef 100644
--- a/vit_pytorch/pit.py
+++ b/vit_pytorch/pit.py
@@ -17,18 +17,11 @@ def conv_output_size(image_size, kernel_size, stride, padding = 0):
 
 # classes
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -47,6 +40,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
@@ -58,6 +52,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
     def forward(self, x):
         b, n, _, h = *x.shape, self.heads
+
+        x = self.norm(x)
 
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
@@ -76,8 +72,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
     def forward(self, x):
         for attn, ff in self.layers:
diff --git a/vit_pytorch/rvt.py b/vit_pytorch/rvt.py
index 5e95442..1ad51dc 100644
--- a/vit_pytorch/rvt.py
+++ b/vit_pytorch/rvt.py
@@ -55,14 +55,6 @@ def forward(self, x):
 
 # helper classes
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 class SpatialConv(nn.Module):
     def __init__(self, dim_in, dim_out, kernel, bias = False):
         super().__init__()
@@ -86,6 +78,7 @@ class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0., use_glu = True):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim * 2 if use_glu else hidden_dim),
             GEGLU() if use_glu else nn.GELU(),
             nn.Dropout(dropout),
@@ -103,6 +96,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_rotary = Tru
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
@@ -121,6 +115,9 @@ def forward(self, x, pos_emb, fmap_dims):
         b, n, _, h = *x.shape, self.heads
 
         to_q_kwargs = {'fmap_dims': fmap_dims} if self.use_ds_conv else {}
+
+        x = self.norm(x)
+
         q = self.to_q(x, **to_q_kwargs)
 
         qkv = (q, *self.to_kv(x).chunk(2, dim = -1))
@@ -162,8 +159,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, image_size, dropout = 0
         self.pos_emb = AxialRotaryEmbedding(dim_head, max_freq = image_size)
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv)),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout, use_glu = use_glu))
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv),
+                FeedForward(dim, mlp_dim, dropout = dropout, use_glu = use_glu)
             ]))
     def forward(self, x, fmap_dims):
         pos_emb = self.pos_emb(x[:, 1:])
diff --git a/vit_pytorch/scalable_vit.py b/vit_pytorch/scalable_vit.py
index b6cf8ed..9dd3cb3 100644
--- a/vit_pytorch/scalable_vit.py
+++ b/vit_pytorch/scalable_vit.py
@@ -33,15 +33,6 @@ def forward(self, x):
         mean = torch.mean(x, dim = 1, keepdim = True)
         return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = ChanLayerNorm(dim)
-        self.fn = fn
-
-    def forward(self, x):
-        return self.fn(self.norm(x))
-
 class Downsample(nn.Module):
     def __init__(self, dim_in, dim_out):
         super().__init__()
@@ -65,6 +56,7 @@ def __init__(self, dim, expansion_factor = 4, dropout = 0.):
         super().__init__()
         inner_dim = dim * expansion_factor
         self.net = nn.Sequential(
+            ChanLayerNorm(dim),
             nn.Conv2d(dim, inner_dim, 1),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -92,6 +84,7 @@ def __init__(
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
+        self.norm = ChanLayerNorm(dim)
         self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
         self.to_k = nn.Conv2d(dim, dim_key * heads, reduction_factor, stride = reduction_factor, bias = False)
         self.to_v = nn.Conv2d(dim, dim_value * heads, reduction_factor, stride = reduction_factor, bias = False)
@@ -104,6 +97,8 @@ def __init__(
     def forward(self, x):
         height, width, heads = *x.shape[-2:], self.heads
 
+        x = self.norm(x)
+
         q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
 
         # split out heads
@@ -145,6 +140,7 @@ def __init__(
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
+        self.norm = ChanLayerNorm(dim)
         self.local_interactive_module = nn.Conv2d(dim_value * heads, dim_value * heads, 3, padding = 1)
 
         self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
@@ -159,6 +155,8 @@ def __init__(
     def forward(self, x):
         height, width, heads, wsz = *x.shape[-2:], self.heads, self.window_size
 
+        x = self.norm(x)
+
         wsz_h, wsz_w = default(wsz, height), default(wsz, width)
         assert (height % wsz_h) == 0 and (width % wsz_w) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz_h}, {wsz_w})'
 
@@ -217,11 +215,11 @@ def __init__(
             is_first = ind == 0
 
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
+                ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout),
+                FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout),
                 PEG(dim) if is_first else None,
-                PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
-                PreNorm(dim, InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout))
+                FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout),
+                InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout)
             ]))
 
         self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()
diff --git a/vit_pytorch/sep_vit.py b/vit_pytorch/sep_vit.py
index 4b16270..d735722 100644
--- a/vit_pytorch/sep_vit.py
+++ b/vit_pytorch/sep_vit.py
@@ -25,15 +25,6 @@ def forward(self, x):
         mean = torch.mean(x, dim = 1, keepdim = True)
         return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = ChanLayerNorm(dim)
-        self.fn = fn
-
-    def forward(self, x):
-        return self.fn(self.norm(x))
-
 class OverlappingPatchEmbed(nn.Module):
     def __init__(self, dim_in, dim_out, stride = 2):
         super().__init__()
@@ -59,6 +50,7 @@ def __init__(self, dim, mult = 4, dropout = 0.):
         super().__init__()
         inner_dim = int(dim * mult)
         self.net = nn.Sequential(
+            ChanLayerNorm(dim),
             nn.Conv2d(dim, inner_dim, 1),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -85,6 +77,8 @@ def __init__(
         self.window_size = window_size
         inner_dim = dim_head * heads
 
+        self.norm = ChanLayerNorm(dim)
+
         self.attend = nn.Sequential(
             nn.Softmax(dim = -1),
             nn.Dropout(dropout)
@@ -138,6 +132,8 @@ def forward(self, x):
         assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
         num_windows = (height // wsz) * (width // wsz)
 
+        x = self.norm(x)
+
windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz) @@ -225,8 +221,8 @@ def __init__( for ind in range(depth): self.layers.append(nn.ModuleList([ - PreNorm(dim, DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)), - PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = dropout)), + DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout), + FeedForward(dim, mult = ff_mult, dropout = dropout), ])) self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity() diff --git a/vit_pytorch/twins_svt.py b/vit_pytorch/twins_svt.py index ea888b8..703a506 100644 --- a/vit_pytorch/twins_svt.py +++ b/vit_pytorch/twins_svt.py @@ -42,20 +42,11 @@ def forward(self, x): mean = torch.mean(x, dim = 1, keepdim = True) return (x - mean) / (var + self.eps).sqrt() * self.g + self.b -class PreNorm(nn.Module): - def __init__(self, dim, fn): - super().__init__() - self.norm = LayerNorm(dim) - self.fn = fn - - def forward(self, x, **kwargs): - x = self.norm(x) - return self.fn(x, **kwargs) - class FeedForward(nn.Module): def __init__(self, dim, mult = 4, dropout = 0.): super().__init__() self.net = nn.Sequential( + LayerNorm(dim), nn.Conv2d(dim, dim * mult, 1), nn.GELU(), nn.Dropout(dropout), @@ -99,6 +90,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., patch_size = 7): self.heads = heads self.scale = dim_head ** -0.5 + self.norm = LayerNorm(dim) self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False) self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False) @@ -108,6 +100,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., patch_size = 7): ) def forward(self, fmap): + fmap = self.norm(fmap) + shape, p = fmap.shape, self.patch_size b, n, x, y, h = *shape, self.heads x, y = map(lambda t: t // p, (x, y)) @@ -132,6 +126,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., k = 7): self.heads = heads self.scale = dim_head ** -0.5 + self.norm = LayerNorm(dim) + self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False) self.to_kv = nn.Conv2d(dim, inner_dim * 2, k, stride = k, bias = False) @@ -143,6 +139,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., k = 7): ) def forward(self, x): + x = self.norm(x) + shape = x.shape b, n, _, y, h = *shape, self.heads q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1)) @@ -164,10 +162,10 @@ def __init__(self, dim, depth, heads = 8, dim_head = 64, mlp_mult = 4, local_pat self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append(nn.ModuleList([ - Residual(PreNorm(dim, LocalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, patch_size = local_patch_size))) if has_local else nn.Identity(), - Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))) if has_local else nn.Identity(), - Residual(PreNorm(dim, GlobalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, k = global_k))), - Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))) + Residual(LocalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, patch_size = local_patch_size)) if has_local else nn.Identity(), + Residual(FeedForward(dim, mlp_mult, dropout = dropout)) if has_local else nn.Identity(), + Residual(GlobalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, k = global_k)), + Residual(FeedForward(dim, mlp_mult, dropout = dropout)) ])) def forward(self, x): for local_attn, ff1, 
         for local_attn, ff1, global_attn, ff2 in self.layers:
diff --git a/vit_pytorch/vit_1d.py b/vit_pytorch/vit_1d.py
index a0e130e..c67e135 100644
--- a/vit_pytorch/vit_1d.py
+++ b/vit_pytorch/vit_1d.py
@@ -6,18 +6,11 @@
 
 # classes
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -36,6 +29,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
@@ -47,6 +41,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         ) if project_out else nn.Identity()
 
     def forward(self, x):
+        x = self.norm(x)
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
 
@@ -65,8 +60,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
     def forward(self, x):
         for attn, ff in self.layers:
diff --git a/vit_pytorch/vit_3d.py b/vit_pytorch/vit_3d.py
index 26d07c4..a2058fb 100644
--- a/vit_pytorch/vit_3d.py
+++ b/vit_pytorch/vit_3d.py
@@ -11,18 +11,11 @@ def pair(t):
 
 # classes
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -41,6 +34,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
@@ -52,6 +46,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         ) if project_out else nn.Identity()
 
     def forward(self, x):
+        x = self.norm(x)
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
 
@@ -70,8 +65,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
     def forward(self, x):
         for attn, ff in self.layers:
diff --git a/vit_pytorch/vit_for_small_dataset.py b/vit_pytorch/vit_for_small_dataset.py
index 4884f22..1ec79ad 100644
--- a/vit_pytorch/vit_for_small_dataset.py
+++ b/vit_pytorch/vit_for_small_dataset.py
@@ -13,18 +13,11 @@ def pair(t):
 
 # classes
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -41,6 +34,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.heads = heads
         self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))
 
+        self.norm = nn.LayerNorm(dim)
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
@@ -52,6 +46,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         )
 
     def forward(self, x):
+        x = self.norm(x)
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
 
@@ -74,8 +69,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
     def forward(self, x):
         for attn, ff in self.layers:
diff --git a/vit_pytorch/vit_with_patch_dropout.py b/vit_pytorch/vit_with_patch_dropout.py
index 16278d6..17648e5 100644
--- a/vit_pytorch/vit_with_patch_dropout.py
+++ b/vit_pytorch/vit_with_patch_dropout.py
@@ -30,18 +30,11 @@ def forward(self, x):
 
         return x[batch_indices, patch_indices_keep]
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -60,6 +53,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
@@ -71,6 +65,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         ) if project_out else nn.Identity()
 
     def forward(self, x):
+        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
 
@@ -89,8 +84,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
     def forward(self, x):
         for attn, ff in self.layers:
diff --git a/vit_pytorch/vivit.py b/vit_pytorch/vivit.py
index 2df8f01..b95afdc 100644
--- a/vit_pytorch/vivit.py
+++ b/vit_pytorch/vivit.py
@@ -14,18 +14,11 @@ def pair(t):
 
 # classes
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -44,6 +37,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
@@ -55,6 +49,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         ) if project_out else nn.Identity()
 
     def forward(self, x):
+        x = self.norm(x)
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
 
@@ -74,8 +69,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
     def forward(self, x):
         for attn, ff in self.layers:
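Reviewer note: every file above receives the same treatment. The standalone PreNorm wrapper is deleted, the layernorm becomes the first layer of each FeedForward's nn.Sequential, and each attention variant gains a self.norm applied at the top of forward; the residual additions stay where they already were (in a Residual wrapper or directly in Transformer.forward, as in the deepvit hunk). Below is a minimal, self-contained sketch of why the two parameterizations compute the same thing. It is illustrative only: the dim / hidden_dim values are arbitrary placeholders, not taken from this diff, and the FeedForward tail mirrors the library's usual Linear/GELU/Dropout stack for completeness.

import torch
from torch import nn

# old style: a generic pre-norm wrapper applied around the block
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

# new style: the layernorm lives inside the block itself
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),            # previously supplied by PreNorm
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# quick equivalence check: wrapping the norm-free tail of the new module in
# PreNorm reproduces the old computation exactly (the weights are shared)
if __name__ == '__main__':
    torch.manual_seed(0)
    dim, hidden_dim = 8, 16                      # arbitrary illustrative sizes
    new = FeedForward(dim, hidden_dim).eval()
    old = PreNorm(dim, nn.Sequential(*list(new.net.children())[1:])).eval()
    old.norm.load_state_dict(new.net[0].state_dict())
    x = torch.randn(2, 4, dim)
    assert torch.allclose(new(x), old(x), atol = 1e-6)

Since the norm now sits inside each block, the per-layer call sites simplify to plain module calls (x = attn(x) + x; x = ff(x) + x), which is exactly what the Transformer hunks above switch to.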