From b4a1e2fd0285b04eabf5250cbd771c2abb123a37 Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Mon, 23 Dec 2024 17:54:54 -0800 Subject: [PATCH] vlm: add tensor parallel support for vision transformer models (#971) --- aphrodite/modeling/models/blip.py | 73 ++++++++- aphrodite/modeling/models/blip2.py | 3 +- aphrodite/modeling/models/clip.py | 128 ++++++++++++++- aphrodite/modeling/models/intern_vit.py | 61 ++++--- aphrodite/modeling/models/paligemma.py | 48 +++--- aphrodite/modeling/models/phi3v.py | 46 ++++-- aphrodite/modeling/models/siglip.py | 209 ++++-------------------- 7 files changed, 310 insertions(+), 258 deletions(-) diff --git a/aphrodite/modeling/models/blip.py b/aphrodite/modeling/models/blip.py index 4e352ca59..d3bb01e5c 100644 --- a/aphrodite/modeling/models/blip.py +++ b/aphrodite/modeling/models/blip.py @@ -7,14 +7,16 @@ import torch.nn as nn from PIL import Image from transformers import Blip2VisionConfig, BlipVisionConfig -from transformers.models.blip.modeling_blip import BlipAttention +from xformers import ops as xops from aphrodite.common.config import ModelConfig from aphrodite.common.sequence import SequenceData from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE +from aphrodite.distributed import divide, get_tensor_model_parallel_world_size from aphrodite.inputs import LLMInputs from aphrodite.modeling.layers.activation import get_act_fn from aphrodite.modeling.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) from aphrodite.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) @@ -155,6 +157,71 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: return embeddings +class BlipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + def __init__( + self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.qkv = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + ) + self.projection = RowParallelLinear( + self.embed_dim, + self.embed_dim, + quant_config=quant_config, + ) + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + ): + """Input shape: Batch x Time x Channel""" + bsz, tgt_len, _ = hidden_states.size() + qkv_states, _ = self.qkv(hidden_states) + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + query_states = query_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + key_states = key_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + value_states = value_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + out = 
xops.memory_efficient_attention_forward(query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale) + out = out.view(bsz, tgt_len, -1) + attn_output, _ = self.projection(out) + return attn_output + + class BlipMLP(nn.Module): def __init__(self, @@ -189,7 +256,7 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None): super().__init__() - self.self_attn = BlipAttention(config) + self.self_attn = BlipAttention(config, quant_config=quant_config) self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = BlipMLP(config, quant_config=quant_config) @@ -200,7 +267,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states = residual + hidden_states residual = hidden_states diff --git a/aphrodite/modeling/models/blip2.py b/aphrodite/modeling/models/blip2.py index 5ef63eece..4c043faa6 100644 --- a/aphrodite/modeling/models/blip2.py +++ b/aphrodite/modeling/models/blip2.py @@ -700,8 +700,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): use_default_weight_loading = False if "vision" in name: if self.vision_model is not None: - # We only do sharding for language model and - # not vision model for now. + # BlipVisionModel does not need sharding use_default_weight_loading = True else: for (param_name, weight_name, diff --git a/aphrodite/modeling/models/clip.py b/aphrodite/modeling/models/clip.py index 9ba17931d..77b8a50eb 100644 --- a/aphrodite/modeling/models/clip.py +++ b/aphrodite/modeling/models/clip.py @@ -1,21 +1,24 @@ -"""Minimal implementation of CLIPVisionModel intended to be only used +"""Minimal implementation of CLIPVisionModel intended to be only used within a vision language model.""" from array import array -from typing import List, Optional, Union +from typing import Iterable, List, Optional, Tuple, Union import torch import torch.nn as nn from PIL import Image from transformers import CLIPVisionConfig -from transformers.models.clip.modeling_clip import CLIPAttention +from xformers import ops as xops from aphrodite.common.config import ModelConfig -from aphrodite.common.sequence import SequenceData -from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE +from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE, + SequenceData) +from aphrodite.distributed import divide, get_tensor_model_parallel_world_size from aphrodite.inputs import LLMInputs from aphrodite.modeling.layers.activation import get_act_fn from aphrodite.modeling.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) +from aphrodite.modeling.model_loader.weight_utils import default_weight_loader from aphrodite.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) from aphrodite.quantization import QuantizationConfig @@ -34,7 +37,7 @@ def get_clip_num_patches(*, image_size: int, patch_size: int) -> int: def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int: return get_clip_num_patches(image_size=hf_config.image_size, - patch_size=hf_config.patch_size) + patch_size=hf_config.patch_size) + 1 def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int: @@ -160,6 +163,78 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: return embeddings +class CLIPAttention(nn.Module): + """Multi-headed attention from 
'Attention Is All You Need' paper""" + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + quant_config=quant_config, + ) + + self.out_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + quant_config=quant_config, + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + ): + """Input shape: Batch x Time x Channel""" + bsz, tgt_len, _ = hidden_states.size() + + qkv_states, _ = self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + + query_states = query_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + key_states = key_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + value_states = value_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + + out = xops.memory_efficient_attention_forward(query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale) + out = out.view(bsz, tgt_len, -1) + attn_output, _ = self.out_proj(out) + + return attn_output + + class CLIPMLP(nn.Module): def __init__(self, @@ -192,7 +267,7 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None): super().__init__() - self.self_attn = CLIPAttention(config) + self.self_attn = CLIPAttention(config, quant_config=quant_config) self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = CLIPMLP(config, quant_config=quant_config) @@ -204,7 +279,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states = residual + hidden_states residual = hidden_states @@ -217,7 +292,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class CLIPEncoder(nn.Module): """ - Transformer encoder consisting of `config.num_hidden_layers` self + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`CLIPEncoderLayer`]. 
Args: @@ -303,3 +378,38 @@ def forward(self, pixel_values: Optional[torch.Tensor] = None): @property def device(self): return next(self.parameters()).device + + # TODO: Add prefix argument for filtering out weights to be loaded + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + layer_count = len(self.vision_model.encoder.layers) + + for name, loaded_weight in weights: + # post_layernorm is not needed in CLIPVisionModel + if "vision_model.post_layernorm" in name: + continue + # omit layers when num_hidden_layers_override is set + if "vision_model.encoder.layers." in name: + layer_idx = int(name.split(".")[3]) + if layer_idx >= layer_count: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/aphrodite/modeling/models/intern_vit.py b/aphrodite/modeling/models/intern_vit.py index 4b7b0b85d..032b0bf90 100644 --- a/aphrodite/modeling/models/intern_vit.py +++ b/aphrodite/modeling/models/intern_vit.py @@ -10,10 +10,13 @@ import torch.nn as nn import torch.nn.functional as F from transformers import PretrainedConfig +from xformers import ops as xops +from aphrodite.distributed import divide, get_tensor_model_parallel_world_size from aphrodite.modeling.layers.activation import get_act_fn from aphrodite.modeling.layers.layernorm import RMSNorm from aphrodite.modeling.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) from aphrodite.modeling.model_loader.weight_utils import default_weight_loader from aphrodite.quantization import QuantizationConfig @@ -81,7 +84,11 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: class InternAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: PretrainedConfig): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -94,9 +101,13 @@ def __init__(self, config: PretrainedConfig): f' {self.num_heads}).') self.scale = self.head_dim**-0.5 - self.qkv = nn.Linear(self.embed_dim, - 3 * self.embed_dim, - bias=config.qkv_bias) + self.qkv = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + ) self.qk_normalization = config.qk_normalization @@ -104,25 +115,37 @@ def __init__(self, config: PretrainedConfig): self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps) self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps) - self.proj = nn.Linear(self.embed_dim, self.embed_dim) + self.proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + quant_config=quant_config, + ) + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) def forward(self, x): B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, - C // self.num_heads).permute(2, 0, 3, 
1, 4) - q, k, v = qkv.unbind(0) - - if self.qk_normalization: - B_, H_, N_, D_ = q.shape - q = self.q_norm.forward_native(q.transpose(1, 2).flatten( - -2, -1)).view(B_, N_, H_, D_).transpose(1, 2) - k = self.k_norm.forward_native(k.transpose(1, 2).flatten( - -2, -1)).view(B_, N_, H_, D_).transpose(1, 2) + qkv, _ = self.qkv(x) + q, k, v = qkv.chunk(3, dim=-1) - x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) - x = x.transpose(1, 2).reshape(B, N, C) + q = q.view(B, N, self.num_heads_per_partition, self.head_dim) + k = k.view(B, N, self.num_heads_per_partition, self.head_dim) + v = v.view(B, N, self.num_heads_per_partition, self.head_dim) - x = self.proj(x) + if self.qk_normalization: + B_, N_, H_, D_ = q.shape + q = self.q_norm.forward_native(q.flatten(-2, + -1)).view(B_, N_, H_, D_) + k = self.k_norm.forward_native(k.flatten(-2, + -1)).view(B_, N_, H_, D_) + x = xops.memory_efficient_attention_forward( + q, + k, + v, + scale=self.scale, + ) + x = x.view(B, N, -1) + x, _ = self.proj(x) return x @@ -161,7 +184,7 @@ def __init__(self, self.intermediate_size = config.intermediate_size self.norm_type = config.norm_type - self.attn = InternAttention(config) + self.attn = InternAttention(config, quant_config=quant_config) self.mlp = InternMLP(config, quant_config=quant_config) self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) diff --git a/aphrodite/modeling/models/paligemma.py b/aphrodite/modeling/models/paligemma.py index eea4dde05..48b0e926a 100644 --- a/aphrodite/modeling/models/paligemma.py +++ b/aphrodite/modeling/models/paligemma.py @@ -145,7 +145,6 @@ def __init__(self, self.config = config self.multimodal_config = multimodal_config - # TODO: Port over SiglipVisionModel & TP self.vision_tower = SiglipVisionModel(config.vision_config) self.multi_modal_projector = PaliGemmaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, @@ -309,34 +308,27 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if key_to_modify in name: name = name.replace(key_to_modify, new_key) use_default_weight_loading = False - if "vision" in name: - if self.vision_tower is not None: - # We only do sharding for language model and - # not vision model for now. - use_default_weight_loading = True + for (param_name, shard_name, shard_id) in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break else: - for (param_name, shard_name, - shard_id) in stacked_params_mapping: - if shard_name not in name: - continue - name = name.replace(shard_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # lm_head is not used in vllm as it is tied with - # embed_token. To prevent errors, skip loading - # lm_head.weight. - if "lm_head.weight" in name: - continue - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - use_default_weight_loading = True + # lm_head is not used in vllm as it is tied with + # embed_token. To prevent errors, skip loading + # lm_head.weight. 
+ if "lm_head.weight" in name: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + use_default_weight_loading = True if use_default_weight_loading: param = params_dict[name] diff --git a/aphrodite/modeling/models/phi3v.py b/aphrodite/modeling/models/phi3v.py index 50d4b90d5..8de32a0cf 100644 --- a/aphrodite/modeling/models/phi3v.py +++ b/aphrodite/modeling/models/phi3v.py @@ -71,6 +71,20 @@ projection_dim=768) +def _init_img_processor(hf_config: PretrainedConfig): + clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG + layer_idx = hf_config.img_processor.get('layer_idx', -2) + # Initialize the CLIP only up to the required feature layer + if layer_idx < 0: + num_hidden_layers = clip_config.num_hidden_layers + \ + layer_idx + 1 + else: + num_hidden_layers = layer_idx + 1 + img_processor = CLIPVisionModel( + clip_config, num_hidden_layers_override=num_hidden_layers) + return img_processor + + class Phi3VImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: Union[torch.Tensor, List[torch.Tensor]] @@ -136,18 +150,8 @@ def __init__(self, config: PretrainedConfig) -> None: hidden_size = config.n_embd if hasattr( config, 'n_embd') else config.hidden_size - clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG - self.layer_idx = config.img_processor.get('layer_idx', -2) - # Initialize the CLIP only up to the required feature layer - if self.layer_idx < 0: - num_hidden_layers = clip_config.num_hidden_layers + \ - self.layer_idx + 1 - else: - num_hidden_layers = self.layer_idx + 1 - - self.img_processor = CLIPVisionModel( - clip_config, num_hidden_layers_override=num_hidden_layers) + self.img_processor = _init_img_processor(config) image_dim_out = config.img_processor['image_dim_out'] self.num_img_tokens = config.img_processor['num_img_tokens'] @@ -649,21 +653,22 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] + # TODO: This is a temporary fix to load + # the vision weights with CLIPVisionModel.load_weights() + vision_weights = [] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - # post_layernorm is not needed in CLIPVisionModel - if "vision_model.post_layernorm" in name: + # Skip loading the img_processor weights since they are + # loaded separately. + if "vision_embed_tokens.img_processor" in name: + vision_weights.append((name, loaded_weight)) continue for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): if key_to_modify in name: name = name.replace(key_to_modify, new_key) for (param_name, weight_name, shard_id) in stacked_params_mapping: - # We only do sharding for language model - # and not vision model for now. 
- if "vision_embed_tokens" in name and self.vision_embed_tokens: - continue if weight_name not in name: continue param = params_dict[name.replace(weight_name, param_name)] @@ -679,3 +684,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + # We use regex to extract the sub-module name + # from "model.vision_embed_tokens.img_processor.*" + vision_weights = [ + (re.search(r"vision_embed_tokens\.img_processor\.(.*)", + n).group(1), w) for n, w in vision_weights + ] + self.vision_embed_tokens.img_processor.load_weights(vision_weights) diff --git a/aphrodite/modeling/models/siglip.py b/aphrodite/modeling/models/siglip.py index dca408374..430f7acd3 100644 --- a/aphrodite/modeling/models/siglip.py +++ b/aphrodite/modeling/models/siglip.py @@ -6,17 +6,15 @@ from typing import Iterable, List, Optional, Tuple, Union import torch -from aphrodite_flash_attn import flash_attn_func from PIL import Image from torch import nn from transformers import SiglipVisionConfig -from transformers.models.siglip.modeling_siglip import SiglipAttention -from xformers.ops import memory_efficient_attention +from xformers import ops as xops from aphrodite.common.config import ModelConfig from aphrodite.common.sequence import SequenceData from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE -from aphrodite.distributed import get_tensor_model_parallel_world_size +from aphrodite.distributed import divide, get_tensor_model_parallel_world_size from aphrodite.inputs import LLMInputs from aphrodite.modeling.layers.activation import get_act_fn from aphrodite.modeling.layers.linear import (ColumnParallelLinear, @@ -222,9 +220,7 @@ def forward(self, return embeddings -# NOTE: Not used - kept for later when we TP the ViT -# TODO(ChristopherCho): Implement TP version of Attention -class SiglipTPAttention(nn.Module): +class SiglipAttention(nn.Module): def __init__( self, @@ -234,29 +230,18 @@ def __init__( super().__init__() self.config = config self.embed_dim = config.hidden_size - - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = config.num_attention_heads - if self.total_num_heads % tp_size != 0: - raise ValueError( - f"Number of attention heads ({self.total_num_heads}) " - "must be divisible by the tensor model parallel size" - f" ({tp_size}).") - - self.num_heads = self.total_num_heads // tp_size - self.head_dim = self.embed_dim // self.total_num_heads - if self.head_dim * self.total_num_heads != self.embed_dim: + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: raise ValueError(f"embed_dim must be divisible by num_heads (got " - "`embed_dim`: {self.embed_dim} and `num_heads`:" + f"`embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads}).") - self.qkv_size = self.num_heads * self.head_dim self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout - self.qkv_proj = QKVParallelLinear( hidden_size=self.embed_dim, head_size=self.head_dim, - total_num_heads=self.total_num_heads, + total_num_heads=self.num_heads, quant_config=quant_config, ) self.out_proj = RowParallelLinear( @@ -265,7 +250,8 @@ def __init__( quant_config=quant_config, ) - self.attn_fn = self._basic_attention_forward + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) def forward( self, @@ -275,163 +261,27 
@@ def forward( batch_size, q_len, _ = hidden_states.size() qkv_states, _ = self.qkv_proj(hidden_states) - query_states, key_states, value_states = qkv_states.split( - [self.qkv_size] * 3, dim=-1) - - attn_output = self.attn_fn( - q=query_states, - k=key_states, - v=value_states, - batch_size=batch_size, - q_len=q_len, - ) - - attn_output, _ = self.out_proj(attn_output) - return attn_output - - def _basic_attention_forward(self, q, k, v, batch_size, q_len): - q = q.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - k = k.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - v = v.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - - k_v_seq_len = k.shape[-2] - attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale - - if attn_weights.size() != ( - batch_size, - self.num_heads, - q_len, - k_v_seq_len, - ): - raise ValueError( - "Attention weights should be of size " - f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" - f" {attn_weights.size()}") - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, - dim=-1, - dtype=torch.float32).to(q.dtype) - attn_weights = nn.functional.dropout(attn_weights, - p=self.dropout, - training=self.training) - attn_output = torch.matmul(attn_weights, v) - - if attn_output.size() != ( - batch_size, - self.num_heads, - q_len, - self.head_dim, - ): - raise ValueError( - "`attn_output` should be of size " - f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}") - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) - - return attn_output - - -# NOTE: Not used - kept for later when we TP the ViT -# TODO(ChristopherCho): flash_attn_func is not working properly. -# It constantly throws a CUDA error. -class SiglipFlashAttention2(SiglipTPAttention): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.attn_fn = self._flash_attention_forward - - # Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449 - # and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133 - def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args, - **kwargs): - """Implements the multihead softmax attention. - Arguments - --------- - q, k, v: The tensor containing the - query, key, and value. 
(B, S, H, D) - """ - - q = q.view(batch_size, q_len, self.num_heads, self.head_dim) - k = k.view(batch_size, q_len, self.num_heads, self.head_dim) - v = v.view(batch_size, q_len, self.num_heads, self.head_dim) - - attn_output = flash_attn_func( - q, - k, - v, - dropout_p=self.dropout, - causal=False, - ) - - attn_output = attn_output.reshape(batch_size, q_len, - self.embed_dim).contiguous() - - return attn_output - - -# NOTE: Not used - kept for later when we TP the ViT -class SiglipSdpaAttention(SiglipTPAttention): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.is_causal = False - self.attn_fn = self._sdpa_attention_forward - - def _sdpa_attention_forward(self, q, k, v, batch_size, q_len): - q = q.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - k = k.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - v = v.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - - attn_output = torch.nn.functional.scaled_dot_product_attention( - q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(batch_size, q_len, self.embed_dim) - - return attn_output - - -# NOTE: Not used - kept for later when we TP the ViT -class SiglipxFormersAttention(SiglipTPAttention): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.attn_fn = self._xformers_attention_forward - - def _xformers_attention_forward(self, q, k, v, batch_size, q_len): - q = q.view(batch_size, q_len, self.num_heads, self.head_dim) - k = k.view(batch_size, q_len, self.num_heads, self.head_dim) - v = v.view(batch_size, q_len, self.num_heads, self.head_dim) - - attn_output = memory_efficient_attention(q, - k, - v, - p=0.0, - scale=self.scale) - attn_output = attn_output.reshape(batch_size, q_len, - self.embed_dim).contiguous() + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + query_states = query_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + key_states = key_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + value_states = value_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + out = xops.memory_efficient_attention_forward(query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale) + out = out.view(batch_size, q_len, -1) + attn_output, _ = self.out_proj(out) return attn_output -# NOTE: Not used - kept for later when we TP the ViT -SIGLIP_ATTENTION_CLASSES = { - "eager": SiglipTPAttention, - "flash_attention_2": SiglipFlashAttention2, - "sdpa": SiglipSdpaAttention, - "xformers": SiglipxFormersAttention, -} - - class SiglipMLP(nn.Module): def __init__( @@ -474,8 +324,7 @@ def __init__( super().__init__() self.embed_dim = config.hidden_size - # TODO(ChristopherCho): use TP'ed Attention block - self.self_attn = SiglipAttention(config) + self.self_attn = SiglipAttention(config, quant_config=quant_config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP( @@ -492,7 +341,7 @@ def forward( residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states = residual + hidden_states residual = hidden_states
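
Reviewer note (appended after the diff, not inside any hunk): the sketch below
illustrates the head-partitioning arithmetic shared by the new BlipAttention,
CLIPAttention, InternAttention and SiglipAttention forward passes. It is a
single-process approximation only: plain tensors stand in for
QKVParallelLinear / RowParallelLinear (which require an initialized
tensor-parallel group), torch's scaled_dot_product_attention stands in for
xops.memory_efficient_attention_forward, and the tp_size value is made up for
illustration. Only the shapes are meant to match the patch.

import torch
import torch.nn.functional as F

bsz, seq_len = 2, 577           # e.g. CLIP ViT-L/14-336: 576 patches + 1 CLS token
embed_dim, num_heads = 1024, 16
head_dim = embed_dim // num_heads
tp_size = 4                     # hypothetical tensor-parallel world size
heads_per_partition = num_heads // tp_size  # what divide(num_heads, tp_size) yields

# Per rank, QKVParallelLinear emits only its shard of the fused projection:
# 3 * heads_per_partition * head_dim features rather than 3 * embed_dim.
qkv_states = torch.randn(bsz, seq_len, 3 * heads_per_partition * head_dim)

q, k, v = qkv_states.chunk(3, dim=-1)
q = q.view(bsz, seq_len, heads_per_partition, head_dim)
k = k.view(bsz, seq_len, heads_per_partition, head_dim)
v = v.view(bsz, seq_len, heads_per_partition, head_dim)

# xops.memory_efficient_attention_forward consumes (B, S, H, D); SDPA wants
# (B, H, S, D), hence the transposes in this stand-in.
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
    scale=head_dim**-0.5,
).transpose(1, 2)

# Flatten back to the per-rank hidden size; RowParallelLinear's all-reduce then
# recombines the partial projections into the full embed_dim output.
out = out.view(bsz, seq_len, -1)
assert out.shape == (bsz, seq_len, heads_per_partition * head_dim)

With tp_size = 1 the same arithmetic reduces to the previous single-GPU
behaviour, which is why none of the rewritten forward passes need a special
case for running without tensor parallelism.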