From 9992a615d1f3e0117a62ced39a88d48b6be68858 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 19 Jul 2024 19:23:38 -0700
Subject: [PATCH] attention re-use in lookup vit should use pre-softmax
 attention matrix

---
 setup.py                |  2 +-
 vit_pytorch/look_vit.py | 23 ++++++++++++-----------
 2 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/setup.py b/setup.py
index eea2d68..0b3450f 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.7.3',
+  version = '1.7.4',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,
diff --git a/vit_pytorch/look_vit.py b/vit_pytorch/look_vit.py
index 1651796..c8b7ffe 100644
--- a/vit_pytorch/look_vit.py
+++ b/vit_pytorch/look_vit.py
@@ -99,8 +99,8 @@ def forward(
         self,
         x,
         context = None,
-        return_attn = False,
-        attn = None
+        return_qk_sim = False,
+        qk_sim = None
     ):
         x = self.norm(x)
 
@@ -119,20 +119,21 @@ def forward(
             q, k = tuple(self.split_heads(t) for t in qk)
 
             q = q * self.scale
-            sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
+            qk_sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
 
-            attn = self.attend(sim)
-            attn = self.dropout(attn)
         else:
-            assert exists(attn), 'attention matrix must be passed in for reusing previous attention'
+            assert exists(qk_sim), 'qk sim matrix must be passed in for reusing previous attention'
+
+        attn = self.attend(qk_sim)
+        attn = self.dropout(attn)
 
         out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
         out = self.to_out(out)
 
-        if not return_attn:
+        if not return_qk_sim:
             return out
 
-        return out, attn
+        return out, qk_sim
 
 # LookViT
 
@@ -228,7 +229,7 @@ def forward(self, img):
 
         # main tokens cross attends (lookup) on the high res tokens
 
-        lookup_out, lookup_attn = lookup_cross_attn(tokens, highres_tokens, return_attn = True) # return attention as they reuse the attention matrix
+        lookup_out, qk_sim = lookup_cross_attn(tokens, highres_tokens, return_qk_sim = True) # return the pre-softmax qk similarity, as it is reused for the reverse cross attention
         tokens = lookup_out + tokens
 
         tokens = attn(tokens) + tokens
@@ -236,9 +237,9 @@ def forward(self, img):
 
         # attention-reuse
 
-        lookup_attn = rearrange(lookup_attn, 'b h i j -> b h j i') # transpose for reverse cross attention
+        qk_sim = rearrange(qk_sim, 'b h i j -> b h j i') # transpose for reverse cross attention
 
-        highres_tokens = highres_attn(highres_tokens, tokens, attn = lookup_attn) + highres_tokens
+        highres_tokens = highres_attn(highres_tokens, tokens, qk_sim = qk_sim) + highres_tokens
         highres_tokens = highres_norm(highres_tokens)
 
         highres_tokens = highres_mlp(highres_tokens) + highres_tokens
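
Why the reuse has to happen before the softmax: the post-softmax attention matrix is normalised over its last (key) axis, so its transpose no longer sums to one along the axis the reverse cross attention normalises over. Transposing the raw qk similarity and applying the softmax afterwards yields a proper attention distribution for the high-res tokens. Below is a minimal standalone sketch of that point, not code from look_vit.py; the tensor names and shapes are illustrative only.

import torch
from einops import rearrange, einsum

# illustrative shapes: batch, heads, main tokens, high-res tokens, head dim
b, h, i, j, d = 1, 2, 4, 6, 8
q = torch.randn(b, h, i, d)   # queries from the main (lookup) tokens
k = torch.randn(b, h, j, d)   # keys from the high-res tokens

# pre-softmax similarity, same einsum pattern as in the patch
qk_sim = einsum(q, k, 'b h i d, b h j d -> b h i j')

# forward lookup cross attention: main tokens attend over high-res tokens (softmax over j)
attn = qk_sim.softmax(dim = -1)

# naive reuse: transposing the post-softmax matrix leaves rows that no longer sum to 1
reused_post = rearrange(attn, 'b h i j -> b h j i')
print(reused_post.sum(dim = -1))   # generally not all ones

# pre-softmax reuse: transpose the similarity, then softmax over the new last axis (main tokens)
reused_pre = rearrange(qk_sim, 'b h i j -> b h j i').softmax(dim = -1)
print(reused_pre.sum(dim = -1))    # all ones, a valid attention distribution

This is the reason the Attention module in the patch returns and accepts qk_sim instead of a finished attention matrix, and applies self.attend and self.dropout in both the fresh and the reuse paths.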