From baad2d7c8df599d8d9b081ba2e946626eaa2dc34 Mon Sep 17 00:00:00 2001 From: Dongxu Date: Mon, 6 Mar 2023 20:14:55 +0800 Subject: [PATCH] Fix BLIP2 mixed precision on CPU (#179) * allow both cpu string and torch.device to be identified for model loading. * blip2 amp cpu compatibility. * use dtype=float16 by default. --- lavis/models/__init__.py | 3 +- lavis/models/blip2_models/blip2.py | 30 +++++++++++------ .../blip2_models/blip2_image_text_matching.py | 4 +-- lavis/models/blip2_models/blip2_opt.py | 32 +++++++++++-------- lavis/models/blip2_models/blip2_qformer.py | 20 ++++++------ lavis/models/blip2_models/blip2_t5.py | 16 +++++----- 6 files changed, 61 insertions(+), 44 deletions(-) diff --git a/lavis/models/__init__.py b/lavis/models/__init__.py index dcb7d55d8..c658eb8f6 100644 --- a/lavis/models/__init__.py +++ b/lavis/models/__init__.py @@ -6,6 +6,7 @@ """ import logging +import torch from omegaconf import OmegaConf from lavis.common.registry import registry @@ -211,7 +212,7 @@ def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"): """ ) - if device == "cpu": + if device == "cpu" or device == torch.device("cpu"): model = model.float() return model.to(device), vis_processors, txt_processors diff --git a/lavis/models/blip2_models/blip2.py b/lavis/models/blip2_models/blip2.py index 03a98e0d1..2259e1ac8 100644 --- a/lavis/models/blip2_models/blip2.py +++ b/lavis/models/blip2_models/blip2.py @@ -4,6 +4,7 @@ SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ +import contextlib import logging import os import time @@ -32,6 +33,16 @@ def init_tokenizer(cls): tokenizer.add_special_tokens({"bos_token": "[DEC]"}) return tokenizer + def maybe_autocast(self, dtype=torch.float16): + # if on cpu, don't use autocast + # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 + enable_autocast = self.device != torch.device("cpu") + + if enable_autocast: + return torch.cuda.amp.autocast(dtype=dtype) + else: + return contextlib.nullcontext() + @classmethod def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2): encoder_config = BertConfig.from_pretrained("bert-base-uncased") @@ -42,7 +53,7 @@ def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2): encoder_config.query_length = num_query_token Qformer = BertLMHeadModel.from_pretrained( "bert-base-uncased", config=encoder_config - ) + ) query_tokens = nn.Parameter( torch.zeros(1, num_query_token, encoder_config.hidden_size) ) @@ -52,16 +63,17 @@ def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2): @classmethod def init_vision_encoder( cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision - ): - assert model_name in ["eva_clip_g","clip_L"], "vit model must be eva_clip_g or clip_L" - if model_name=="eva_clip_g": + ): + assert model_name in [ + "eva_clip_g", + "clip_L", + ], "vit model must be eva_clip_g or clip_L" + if model_name == "eva_clip_g": visual_encoder = create_eva_vit_g( img_size, drop_path_rate, use_grad_checkpoint, precision ) - elif model_name=="clip_L": - visual_encoder = create_clip_vit_L( - img_size, use_grad_checkpoint, precision - ) + elif model_name == "clip_L": + visual_encoder = create_clip_vit_L(img_size, use_grad_checkpoint, precision) ln_vision = LayerNorm(visual_encoder.num_features) return visual_encoder, ln_vision @@ -80,7 +92,7 @@ def load_from_pretrained(self, url_or_filename): msg = 
self.load_state_dict(state_dict, strict=False) - logging.info("Missing keys {}".format(msg.missing_keys)) + # logging.info("Missing keys {}".format(msg.missing_keys)) logging.info("load checkpoint from %s" % url_or_filename) return msg diff --git a/lavis/models/blip2_models/blip2_image_text_matching.py b/lavis/models/blip2_models/blip2_image_text_matching.py index 24eb83b2c..f32db24d0 100644 --- a/lavis/models/blip2_models/blip2_image_text_matching.py +++ b/lavis/models/blip2_models/blip2_image_text_matching.py @@ -54,9 +54,9 @@ def forward(self, samples, match_head="itm"): image = samples["image"] caption = samples["text_input"] - with torch.cuda.amp.autocast(enabled=(self.device != torch.device("cpu"))): + with self.maybe_autocast(): image_embeds = self.ln_vision(self.visual_encoder(image)) - image_embeds = image_embeds.float() + image_embeds = image_embeds.float() image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( image.device ) diff --git a/lavis/models/blip2_models/blip2_opt.py b/lavis/models/blip2_models/blip2_opt.py index 6e38e08a4..14cb4ea5d 100644 --- a/lavis/models/blip2_models/blip2_opt.py +++ b/lavis/models/blip2_models/blip2_opt.py @@ -59,7 +59,7 @@ def __init__( ) if freeze_vit: for name, param in self.visual_encoder.named_parameters(): - param.requires_grad = False + param.requires_grad = False self.visual_encoder = self.visual_encoder.eval() self.visual_encoder.train = disabled_train logging.info("freeze vision encoder") @@ -95,7 +95,8 @@ def __init__( def forward(self, samples): image = samples["image"] - image_embeds = self.ln_vision(self.visual_encoder(image)) + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)) image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( image.device ) @@ -138,12 +139,13 @@ def forward(self, samples): inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1) attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1) - outputs = self.opt_model( - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - return_dict=True, - labels=targets, - ) + with self.maybe_autocast(): + outputs = self.opt_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=targets, + ) loss = outputs.loss return {"loss": loss} @@ -177,9 +179,7 @@ def generate( captions (list): A list of strings of length batch_size * num_captions. 
""" image = samples["image"] - with torch.cuda.amp.autocast( - enabled=(self.device != torch.device("cpu")) - ): + with self.maybe_autocast(): image_embeds = self.ln_vision(self.visual_encoder(image)) image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( image.device @@ -194,7 +194,9 @@ def generate( ) inputs_opt = self.opt_proj(query_output.last_hidden_state) - atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(image.device) + atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to( + image.device + ) if "prompt" in samples.keys(): prompt = samples["prompt"] @@ -203,7 +205,9 @@ def generate( prompt = [prompt] * image.size(0) - opt_tokens = self.opt_tokenizer(prompt, return_tensors="pt").to(image.device) + opt_tokens = self.opt_tokenizer(prompt, return_tensors="pt").to( + image.device + ) input_ids = opt_tokens.input_ids attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1) @@ -238,7 +242,7 @@ def generate( @classmethod def from_config(cls, cfg): - vit_model = cfg.get("vit_model","eva_clip_g") + vit_model = cfg.get("vit_model", "eva_clip_g") img_size = cfg.get("image_size") num_query_token = cfg.get("num_query_token") opt_model = cfg.get("opt_model") diff --git a/lavis/models/blip2_models/blip2_qformer.py b/lavis/models/blip2_models/blip2_qformer.py index e817c7181..3fb078042 100644 --- a/lavis/models/blip2_models/blip2_qformer.py +++ b/lavis/models/blip2_models/blip2_qformer.py @@ -64,9 +64,9 @@ def __init__( ) if freeze_vit: for name, param in self.visual_encoder.named_parameters(): - param.requires_grad = False + param.requires_grad = False self.visual_encoder = self.visual_encoder.eval() - self.visual_encoder.train = disabled_train + self.visual_encoder.train = disabled_train logging.info("freeze vision encoder") self.Qformer, self.query_tokens = self.init_Qformer( num_query_token, self.visual_encoder.num_features, cross_attention_freq @@ -90,7 +90,7 @@ def __init__( def forward(self, samples): image = samples["image"] text = samples["text_input"] - + image_embeds = self.ln_vision(self.visual_encoder(image)) image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( image.device @@ -247,7 +247,7 @@ def forward(self, samples): return_dict=True, labels=labels, ) - + loss_lm = lm_output.loss return BlipOutput( @@ -403,9 +403,9 @@ def extract_features(self, samples, mode="multimodal"): image is not None ), "Image is not provided for mode 'image' or 'multimodal'" # return query features - with torch.cuda.amp.autocast(enabled=(self.device != torch.device("cpu"))): + with self.maybe_autocast(): image_embeds_frozen = self.ln_vision(self.visual_encoder(image)) - image_embeds_frozen = image_embeds_frozen.float() + image_embeds_frozen = image_embeds_frozen.float() image_atts = torch.ones( image_embeds_frozen.size()[:-1], dtype=torch.long ).to(self.device) @@ -443,9 +443,9 @@ def extract_features(self, samples, mode="multimodal"): elif mode == "multimodal": # return multimodel query features - with torch.cuda.amp.autocast(enabled=(self.device != torch.device("cpu"))): + with self.maybe_autocast(): image_embeds_frozen = self.ln_vision(self.visual_encoder(image)) - image_embeds_frozen = image_embeds_frozen.float() + image_embeds_frozen = image_embeds_frozen.float() image_atts = torch.ones( image_embeds_frozen.size()[:-1], dtype=torch.long ).to(self.device) @@ -482,10 +482,10 @@ def extract_features(self, samples, mode="multimodal"): @classmethod def from_config(cls, cfg): - vit_model = cfg.get("vit_model","eva_clip_g") + 
vit_model = cfg.get("vit_model", "eva_clip_g") img_size = cfg.get("image_size") num_query_token = cfg.get("num_query_token") - cross_attention_freq = cfg.get("cross_attention_freq",2) + cross_attention_freq = cfg.get("cross_attention_freq", 2) drop_path_rate = cfg.get("drop_path_rate", 0) use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) diff --git a/lavis/models/blip2_models/blip2_t5.py b/lavis/models/blip2_models/blip2_t5.py index 4df710385..ba98e4318 100644 --- a/lavis/models/blip2_models/blip2_t5.py +++ b/lavis/models/blip2_models/blip2_t5.py @@ -101,7 +101,9 @@ def __init__( def forward(self, samples): image = samples["image"] - image_embeds = self.ln_vision(self.visual_encoder(image)) + + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)) image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( image.device ) @@ -117,7 +119,7 @@ def forward(self, samples): inputs_t5 = self.t5_proj(query_output.last_hidden_state) atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device) - with torch.cuda.amp.autocast(dtype=torch.bfloat16): + with self.maybe_autocast(dtype=torch.bfloat16): input_tokens = self.t5_tokenizer( samples["text_input"], padding="longest", @@ -182,9 +184,8 @@ def generate( captions (list): A list of strings of length batch_size * num_captions. """ image = samples["image"] - enable_autocast = self.device != torch.device("cpu") - with torch.cuda.amp.autocast(enabled=enable_autocast): + with self.maybe_autocast(): image_embeds = self.ln_vision(self.visual_encoder(image)) image_embeds = image_embeds.float() image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( @@ -220,7 +221,7 @@ def generate( encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1) - with torch.cuda.amp.autocast(enabled=enable_autocast, dtype=torch.bfloat16): + with self.maybe_autocast(dtype=torch.bfloat16): inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids) inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1) @@ -257,7 +258,7 @@ def predict_answers( **kwargs ): image = samples["image"] - with torch.cuda.amp.autocast(enabled=(self.device != torch.device("cpu"))): + with self.maybe_autocast(): image_embeds = self.ln_vision(self.visual_encoder(image)) image_embeds = image_embeds.float() image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( @@ -288,8 +289,7 @@ def predict_answers( encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1) - device_type = "cuda" if "cuda" in str(self.device) else "cpu" - with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16): + with self.maybe_autocast(dtype=torch.bfloat16): inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids) inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
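
For reference, a minimal standalone sketch of the autocast guard this patch introduces: on CUDA the helper enters torch.cuda.amp.autocast with the requested dtype (float16 by default), while on CPU it falls back to contextlib.nullcontext so the float32 weights are used as-is. The TinyEncoder module below is a hypothetical stand-in for illustration only, not part of LAVIS.

import contextlib

import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for a BLIP-2 tower; illustrates the pattern only."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @property
    def device(self):
        # Same idea as Blip2Base.device: infer placement from the parameters.
        return next(self.parameters()).device

    def maybe_autocast(self, dtype=torch.float16):
        # Same guard as the helper added in blip2.py: autocast only when
        # the model is not on CPU, otherwise return a no-op context manager.
        if self.device != torch.device("cpu"):
            return torch.cuda.amp.autocast(dtype=dtype)
        return contextlib.nullcontext()

    def forward(self, x):
        with self.maybe_autocast():  # fp16 matmuls on CUDA, plain fp32 on CPU
            return self.proj(x)


if __name__ == "__main__":
    model = TinyEncoder()                    # stays on CPU here; move to .cuda() to exercise the GPU path
    print(model(torch.randn(2, 8)).dtype)    # torch.float32 on CPU

Keeping the branch inside one helper lets every forward/generate path share the same CPU fallback instead of repeating enabled=(self.device != torch.device("cpu")) at each autocast call site, which is the refactor the diff applies across blip2_opt.py, blip2_qformer.py, blip2_t5.py, and blip2_image_text_matching.py.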