Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor batch processing in calculate_coco_features.py #741

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions app/calculate_coco_features.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

from PIL import Image
import requests
Expand All @@ -18,7 +16,6 @@

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_demo_image():
img_url = (
"https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
Expand All @@ -34,6 +31,17 @@ def read_img(filepath):
return raw_image


def process_batch(images_in_batch, filepaths_in_batch, feature_extractor, caption, path2feat):
    """Extract image features for one batch and record them in ``path2feat``.

    Args:
        images_in_batch: List of image tensors; they are concatenated along
            dim 0 (assumes each already has a leading batch dimension --
            TODO confirm against the caller's preprocessing).
        filepaths_in_batch: File paths aligned one-to-one with
            ``images_in_batch``.
        feature_extractor: Callable invoked as
            ``feature_extractor(images, caption, mode="image", normalized=True)``.
        caption: Passed through to ``feature_extractor`` unchanged.
        path2feat: Dict updated in place; maps ``os.path.basename(filepath)``
            to that image's feature tensor, detached and moved to the CPU.

    Returns:
        The same ``path2feat`` object, so callers may rebind it.
    """
    # torch.cat raises on an empty list, so an empty batch (e.g. no input
    # files at all) is a no-op rather than a crash.
    if not images_in_batch:
        return path2feat

    batch = torch.cat(images_in_batch, dim=0).to(device)
    with torch.no_grad():
        # [:, 0] keeps index 0 along the second dimension of the output
        # (presumably the [CLS] token embedding -- confirm with the extractor).
        image_features = feature_extractor(
            batch, caption, mode="image", normalized=True
        )[:, 0]

    # Keyed by basename only: files that share a name in different
    # directories overwrite each other.
    for filepath, image_feat in zip(filepaths_in_batch, image_features):
        path2feat[os.path.basename(filepath)] = image_feat.detach().cpu()
    return path2feat

# model
# Location of the BLIP "model_base" checkpoint handed to the extractor below.
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth"
# NOTE(review): the URL is passed as `pretrained`; presumably BlipFeatureExtractor
# downloads and loads the weights from it -- confirm against its definition.
feature_extractor = BlipFeatureExtractor(pretrained=model_url)
Expand Down Expand Up @@ -62,14 +70,7 @@ def read_img(filepath):

for i, filename in enumerate(filepaths):
if i % bsz == 0 and i > 0:
images_in_batch = torch.cat(images_in_batch, dim=0).to(device)
with torch.no_grad():
image_features = feature_extractor(
images_in_batch, caption, mode="image", normalized=True
)[:, 0]

for filepath, image_feat in zip(filepaths_in_batch, image_features):
path2feat[os.path.basename(filepath)] = image_feat.detach().cpu()
path2feat = process_batch(images_in_batch, filepaths_in_batch, feature_extractor, caption, path2feat)

images_in_batch = []
filepaths_in_batch = []
Expand All @@ -84,4 +85,7 @@ def read_img(filepath):
images_in_batch.append(image)
filepaths_in_batch.append(filepath)

# The loop only flushes a batch when `i % bsz == 0 and i > 0`, so the images
# appended after the last flush are still pending here and need one more call.
path2feat = process_batch(images_in_batch, filepaths_in_batch, feature_extractor, caption, path2feat) # process remaining images

# Persist the basename -> feature mapping to a file in the current working directory.
torch.save(path2feat, "path2feat_coco_train2014.pth")