Skip to content

Commit

Permalink
Fix data type conversion issue in MLP and
Browse files Browse the repository at this point in the history
Attention modules
  • Loading branch information
guobentian committed Nov 27, 2023
1 parent ebe3bf2 commit 5cb2cc9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ git clone https://github.com/salesforce/LAVIS.git
cd LAVIS
pip install -e .
```
If you are using arm cpu(for example Mac with Apple Silicon), please use `requirements_arm.txt` instead of `requirements.txt` to install the dependencies.

## Getting Started
### Model Zoo
Expand Down
16 changes: 16 additions & 0 deletions lavis/models/eva_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,18 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay
self.drop = nn.Dropout(drop)

def forward(self, x):
if self.fc1.weight.dtype == torch.float16:
x = x.half()
elif self.fc1.weight.dtype == torch.float32:
x = x.float()
x = self.fc1(x)
x = self.act(x)
# x = self.drop(x)
# commit this for the orignal BERT implement
if self.fc2.weight.dtype == torch.float16:
x = x.half()
elif self.fc2.weight.dtype == torch.float32:
x = x.float()
x = self.fc2(x)
x = self.drop(x)
return x
Expand Down Expand Up @@ -143,6 +151,10 @@ def forward(self, x, rel_pos_bias=None):
attn = self.attn_drop(attn)

x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
if self.proj.weight.dtype == torch.float16:
x = x.half()
elif self.proj.weight.dtype == torch.float32:
x = x.float()
x = self.proj(x)
x = self.proj_drop(x)
return x
Expand Down Expand Up @@ -200,6 +212,10 @@ def forward(self, x, **kwargs):
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
if self.proj.weight.dtype == torch.float16:
x = x.half()
elif self.proj.weight.dtype == torch.float32:
x = x.float()
x = self.proj(x).flatten(2).transpose(1, 2)
return x

Expand Down

0 comments on commit 5cb2cc9

Please sign in to comment.