vlm: add tensor parallel support for vision transformer models (#971)
AlpinDale authored Dec 24, 2024
1 parent 61103b9 commit b4a1e2f
Showing 7 changed files with 310 additions and 258 deletions.
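
The change replaces the stock Hugging Face attention modules in the BLIP and CLIP vision towers with tensor-parallel implementations: QKVParallelLinear column-parallelizes the fused QKV projection so each rank computes attention for only num_heads / tp_size heads, and RowParallelLinear all-reduces the output projection back to the full hidden size. A minimal sketch of the partitioning arithmetic (illustrative values, not taken from the diff):

# Illustrative only: how attention heads are split across tensor-parallel
# ranks; sizes match a CLIP ViT-L/14 vision tower.
num_heads = 16
head_dim = 64
tp_size = 4  # tensor-parallel world size

assert num_heads % tp_size == 0, "divide() requires an even head split"
num_heads_per_partition = num_heads // tp_size            # 4 heads per rank
local_qkv_width = 3 * num_heads_per_partition * head_dim  # 768 columns per rank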
aphrodite/modeling/models/blip.py: 70 additions & 3 deletions

@@ -7,14 +7,16 @@
import torch.nn as nn
from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig
-from transformers.models.blip.modeling_blip import BlipAttention
+from xformers import ops as xops

from aphrodite.common.config import ModelConfig
from aphrodite.common.sequence import SequenceData
from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
from aphrodite.distributed import divide, get_tensor_model_parallel_world_size
from aphrodite.inputs import LLMInputs
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from aphrodite.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
@@ -155,6 +157,71 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return embeddings


class BlipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: BlipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
        )
        self.projection = RowParallelLinear(
            self.embed_dim,
            self.embed_dim,
            quant_config=quant_config,
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel"""
        bsz, tgt_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
        query_states = query_states.view(bsz, tgt_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)
        key_states = key_states.view(bsz, tgt_len,
                                     self.num_heads_per_partition,
                                     self.head_dim)
        value_states = value_states.view(bsz, tgt_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)

        out = xops.memory_efficient_attention_forward(query_states,
                                                      key_states,
                                                      value_states,
                                                      p=self.dropout,
                                                      scale=self.scale)
        out = out.view(bsz, tgt_len, -1)
        attn_output, _ = self.projection(out)

        return attn_output
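
As a reading aid, a hedged shape walkthrough of the new forward path (a standalone sketch with assumed sizes, not code from the commit). xformers' memory_efficient_attention_forward consumes [batch, seq, heads, head_dim] tensors, which is why the views above need no transpose; and because the module returns a plain tensor rather than Hugging Face's (attn_output, attn_weights) tuple, the call sites below change from "hidden_states, _ = self.self_attn(...)" to "hidden_states = self.self_attn(...)".

import torch

# Assumed sizes: embed_dim=1024, 16 heads, tp_size=4 -> 4 heads per rank.
bsz, tgt_len, embed_dim, num_heads, tp_size = 2, 257, 1024, 16, 4
head_dim = embed_dim // num_heads   # 64
local_heads = num_heads // tp_size  # 4

# Per-rank fused QKV output, as QKVParallelLinear would produce it.
qkv_states = torch.randn(bsz, tgt_len, 3 * embed_dim // tp_size)
q, k, v = qkv_states.chunk(3, dim=-1)
q = q.view(bsz, tgt_len, local_heads, head_dim)
assert q.shape == (2, 257, 4, 64)
# After attention, out.view(bsz, tgt_len, -1) gives [2, 257, 256] per rank;
# RowParallelLinear then all-reduces to the full [2, 257, 1024].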


class BlipMLP(nn.Module):

def __init__(self,
@@ -189,7 +256,7 @@ def __init__(self,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()

-        self.self_attn = BlipAttention(config)
+        self.self_attn = BlipAttention(config, quant_config=quant_config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = BlipMLP(config, quant_config=quant_config)
@@ -200,7 +267,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
-        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
+        hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states

residual = hidden_states
aphrodite/modeling/models/blip2.py: 1 addition & 2 deletions

@@ -700,8 +700,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
use_default_weight_loading = False
if "vision" in name:
if self.vision_model is not None:
-                # We only do sharding for language model and
-                # not vision model for now.
+                # BlipVisionModel does not need sharding
use_default_weight_loading = True
else:
for (param_name, weight_name,
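The new comment is terse, but the likely rationale: Hugging Face BLIP checkpoints already store the vision tower's QKV projection as a single fused qkv weight, so the default weight loader can copy it straight into QKVParallelLinear without any q/k/v stacking. CLIP checkpoints instead keep separate q_proj/k_proj/v_proj tensors, which is why clip.py below gains an explicit load_weights with a stacked-parameter mapping.
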
aphrodite/modeling/models/clip.py: 119 additions & 9 deletions
@@ -1,21 +1,24 @@
"""Minimal implementation of CLIPVisionModel intended to be only used
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from array import array
-from typing import List, Optional, Union
+from typing import Iterable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig
-from transformers.models.clip.modeling_clip import CLIPAttention
+from xformers import ops as xops

from aphrodite.common.config import ModelConfig
-from aphrodite.common.sequence import SequenceData
-from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
+from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
+                                       SequenceData)
from aphrodite.distributed import divide, get_tensor_model_parallel_world_size
from aphrodite.inputs import LLMInputs
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from aphrodite.quantization import QuantizationConfig
@@ -34,7 +37,7 @@ def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:

def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
return get_clip_num_patches(image_size=hf_config.image_size,
-                                patch_size=hf_config.patch_size)
+                                patch_size=hf_config.patch_size) + 1
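
The + 1 accounts for the class embedding that CLIP prepends to the patch tokens, which the previous count missed. A quick worked example with assumed ViT-L/14 values:

image_size, patch_size = 224, 14
num_patches = (image_size // patch_size) ** 2  # 16 * 16 = 256 patch tokens
feature_size = num_patches + 1                 # plus the class token -> 257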


def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
@@ -160,6 +163,78 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return embeddings


class CLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel"""
        bsz, tgt_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

        query_states = query_states.view(bsz, tgt_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)
        key_states = key_states.view(bsz, tgt_len,
                                     self.num_heads_per_partition,
                                     self.head_dim)
        value_states = value_states.view(bsz, tgt_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)

        out = xops.memory_efficient_attention_forward(query_states,
                                                      key_states,
                                                      value_states,
                                                      p=self.dropout,
                                                      scale=self.scale)
        out = out.view(bsz, tgt_len, -1)
        attn_output, _ = self.out_proj(out)

        return attn_output
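
Aside from the projection names, CLIPAttention mirrors the BlipAttention added above. The names qkv_proj and out_proj are deliberate: they line up with the stacked_params_mapping in the new load_weights further down, which rewrites checkpoint keys ending in q_proj/k_proj/v_proj to qkv_proj and loads them shard by shard.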


class CLIPMLP(nn.Module):

def __init__(self,
@@ -192,7 +267,7 @@ def __init__(self,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()

-        self.self_attn = CLIPAttention(config)
+        self.self_attn = CLIPAttention(config, quant_config=quant_config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config, quant_config=quant_config)
@@ -204,7 +279,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
-        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
+        hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states

residual = hidden_states
@@ -217,7 +292,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

class CLIPEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self
attention layers. Each layer is a [`CLIPEncoderLayer`].
Args:
@@ -303,3 +378,38 @@ def forward(self, pixel_values: Optional[torch.Tensor] = None):
    @property
    def device(self):
        return next(self.parameters()).device

    # TODO: Add prefix argument for filtering out weights to be loaded
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue
            # omit layers when num_hidden_layers_override is set
            if "vision_model.encoder.layers." in name:
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue

                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
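
To make the control flow concrete, a hedged sketch of how a single checkpoint entry is routed; the weight name is hypothetical but follows the HF CLIP naming this mapping targets:

# Hypothetical checkpoint entry, following HF CLIP naming.
name = "vision_model.encoder.layers.0.self_attn.q_proj.weight"

stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name:
        target = name.replace(weight_name, param_name)
        # target == "vision_model.encoder.layers.0.self_attn.qkv_proj.weight";
        # params_dict[target].weight_loader(param, loaded_weight, "q") copies
        # the tensor into this rank's Q shard of the fused QKV weight.
        break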
