distributed: support pipeline parallelism for internvl and internlm2 (#965)

* distributed: support pipeline parallelism for internvl and internlm2

* register the new models
AlpinDale authored Dec 23, 2024
1 parent cbd51a2 commit a8bdd48
Showing 6 changed files with 91 additions and 33 deletions.
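
At a glance: the InternLM2 decoder now builds only its own pipeline stage's slice of layers (make_layers), threads IntermediateTensors between ranks in forward(), and skips parameters that live on other ranks during weight loading (is_pp_missing_parameter). InternVL inherits all of this by delegating to its wrapped language model. Both architectures are added to the pipeline-parallel allow-list in config.py, and the PP test matrix gains an InternLM2 case, which requires --trust-remote-code end to end.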
2 changes: 2 additions & 0 deletions aphrodite/common/config.py
@@ -48,6 +48,8 @@
     "NemotronForCausalLM",
     "Qwen2ForCausalLM",
     "Qwen2MoeForCausalLM",
+    "InternLM2ForCausalLM",
+    "InternVLChatModel",
 ]
 
 _OPTIMIZED_QUANTS = [
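
The hunk lands inside what is presumably the pipeline-parallel allow-list; the list's name sits above the hunk and is not shown, so _PP_SUPPORTED_MODELS and the check below are assumptions, sketched only for orientation:

# Names are illustrative, not confirmed by this diff.
_PP_SUPPORTED_MODELS = [
    "Qwen2ForCausalLM",
    "InternLM2ForCausalLM",
    "InternVLChatModel",
]


def verify_pp_support(architectures, pipeline_parallel_size):
    # Reject pipeline parallelism for architectures not on the allow-list.
    if pipeline_parallel_size > 1 and not any(
            arch in _PP_SUPPORTED_MODELS for arch in architectures):
        raise NotImplementedError(
            "Pipeline parallelism is not supported for this architecture; "
            f"supported: {_PP_SUPPORTED_MODELS}")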
57 changes: 42 additions & 15 deletions aphrodite/modeling/models/internlm2.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -8,7 +8,8 @@
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors
-from aphrodite.distributed import get_tensor_model_parallel_world_size
+from aphrodite.distributed import (get_pp_group,
+                                   get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -23,6 +24,9 @@
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.quantization.base_config import QuantizationConfig
 
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers)
+
 
 class InternLM2MLP(nn.Module):
 
@@ -212,6 +216,7 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -221,11 +226,15 @@
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList([
-            InternLMDecoderLayer(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: InternLMDecoderLayer(config, cache_config,
+                                                quant_config),
+            prefix=f"{prefix}.layers")
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.tok_embeddings(input_ids)
@@ -236,23 +245,33 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
+        intermediate_tensors: IntermediateTensors = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
-        else:
-            hidden_states = self.tok_embeddings(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.tok_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i],
+                kv_caches[i - self.start_layer],
                 attn_metadata,
                 residual,
             )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

@@ -272,8 +291,12 @@ def __init__(
         self.output = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
+        if self.config.tie_word_embeddings:
+            self.output.weight = self.model.tok_embeddings.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
 
     def forward(
         self,
@@ -284,7 +307,7 @@
         intermediate_tensors: IntermediateTensors,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, intermediate_tensors)
         return hidden_states
 
     def compute_logits(
@@ -321,6 +344,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -329,6 +354,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
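
Taken together, the internlm2.py changes implement the standard pipeline-parallel contract: the first rank embeds tokens, each rank runs only its own slice of layers, non-last ranks ship activations downstream as IntermediateTensors, and the last rank applies the final norm. A self-contained toy sketch of that control flow (plain torch; the rank topology is faked with start/end layer indices, and the residual stream the real code also threads between ranks is omitted):

import torch
from torch import nn


class ToyPipelineStage(nn.Module):
    """Toy stage showing the first/middle/last-rank branching the diff adds.

    Real code gets rank info from get_pp_group(); all layer types here are
    placeholders for decoder layers.
    """

    def __init__(self, vocab_size, hidden_size, start_layer, end_layer,
                 total_layers):
        super().__init__()
        self.is_first_rank = start_layer == 0
        self.is_last_rank = end_layer == total_layers
        # Only the first rank owns the embedding; only the last owns the norm.
        self.embed = (nn.Embedding(vocab_size, hidden_size)
                      if self.is_first_rank else None)
        # Each rank materializes only its slice of the decoder stack.
        self.layers = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size)
            for _ in range(start_layer, end_layer))
        self.norm = nn.LayerNorm(hidden_size) if self.is_last_rank else None

    def forward(self, input_ids=None, intermediate_tensors=None):
        if self.is_first_rank:
            hidden_states = self.embed(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        if not self.is_last_rank:
            # Hand activations to the next rank instead of finishing.
            return {"hidden_states": hidden_states}
        return self.norm(hidden_states)


# Two stages of a four-layer model, chained in-process for illustration:
stage0 = ToyPipelineStage(100, 16, start_layer=0, end_layer=2, total_layers=4)
stage1 = ToyPipelineStage(100, 16, start_layer=2, end_layer=4, total_layers=4)
out = stage1(intermediate_tensors=stage0(torch.tensor([[1, 2, 3]])))
print(out.shape)  # torch.Size([1, 3, 16])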
4 changes: 3 additions & 1 deletion aphrodite/modeling/models/internvl.py
@@ -339,6 +339,8 @@ def __init__(self,
             nn.Linear(llm_hidden_size, llm_hidden_size))
 
         self.img_context_token_id = None
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors)
 
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
@@ -459,7 +461,7 @@ def forward(
             positions,
             kv_caches,
             attn_metadata,
-            None,
+            intermediate_tensors,
             inputs_embeds=inputs_embeds)
         return hidden_states
 
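The InternVL side is the thin-wrapper version of the same change: re-export the language model's empty-tensor factory and pass intermediate_tensors through instead of the previous hard-coded None. A minimal sketch of the delegation pattern (illustrative class, not the real InternVLChatModel):

from torch import nn


class VLWrapper(nn.Module):
    """Illustrative multimodal wrapper that stays PP-compatible by
    delegating all pipeline plumbing to the wrapped language model."""

    def __init__(self, language_model: nn.Module):
        super().__init__()
        self.language_model = language_model
        # Re-export so the engine can allocate placeholder tensors for
        # this model exactly as it would for the bare language model.
        self.make_empty_intermediate_tensors = (
            language_model.make_empty_intermediate_tensors)

    def forward(self, input_ids, positions, kv_caches, attn_metadata,
                intermediate_tensors=None, inputs_embeds=None):
        # Forward intermediate_tensors (previously hard-coded to None) so
        # that non-first pipeline ranks receive upstream activations.
        return self.language_model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds=inputs_embeds)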
16 changes: 16 additions & 0 deletions aphrodite/modeling/models/utils.py
@@ -8,6 +8,7 @@
 
 from aphrodite.common.config import (CacheConfig, LoRAConfig, MultiModalConfig,
                                      SchedulerConfig)
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.common.utils import is_pin_memory_available, progress_bar
 from aphrodite.modeling.model_loader.loader import build_model
 from aphrodite.modeling.models import ModelRegistry
@@ -271,3 +272,18 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
         if name.startswith(missing_layer_name):
             return True
     return False
+
+
+def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
+
+    def make_empty_intermediate_tensors(
+            batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            key: torch.zeros((batch_size, hidden_size),
+                             dtype=dtype,
+                             device=device)
+            for key in keys
+        })
+
+    return make_empty_intermediate_tensors
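
The factory just closes over the key names and hidden size, returning a callable that allocates zero tensors of the right shape. A usage sketch (the call site is not part of this diff; batch size, hidden size, and device below are illustrative):

import torch

# Assumes make_empty_intermediate_tensors_factory from the hunk above is
# importable; all concrete sizes here are made up for illustration.
make_empty = make_empty_intermediate_tensors_factory(
    ["hidden_states", "residual"], hidden_size=4096)
placeholders = make_empty(batch_size=8, dtype=torch.float16,
                          device=torch.device("cpu"))
# placeholders["hidden_states"] and placeholders["residual"] are zero
# tensors of shape (8, 4096): the schema non-first ranks expect to
# receive as IntermediateTensors.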
38 changes: 22 additions & 16 deletions tests/distributed/test_pipeline_parallel.py
@@ -15,23 +15,26 @@
 APHRODITE_MULTI_NODE = os.getenv("APHRODITE_MULTI_NODE", "0") == "1"
 
 
-@pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, "
-                          "MODEL_NAME, DIST_BACKEND"),
-                         [
-                             (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
-                             (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
-                             (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
-                             (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
-                             (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
-                         ])
+@pytest.mark.parametrize(
+    ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, "
+     "MODEL_NAME, DIST_BACKEND"),
+    [
+        (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (2, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray"),
+    ],
+)
 @fork_new_process_for_each_test
-def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
-                    DIST_BACKEND):
+def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
+                    TRUST_REMOTE_CODE, MODEL_NAME, DIST_BACKEND):
     if APHRODITE_MULTI_NODE and DIST_BACKEND == "mp":
         pytest.skip("Skipping multi-node pipeline parallel test for "
                     "multiprocessing distributed backend")
@@ -68,6 +71,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
     if EAGER_MODE:
         pp_args.append("--enforce-eager")
         tp_args.append("--enforce-eager")
+    if TRUST_REMOTE_CODE:
+        pp_args.append("--trust-remote-code")
+        tp_args.append("--trust-remote-code")
     pp_env = None
     if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2
             and CHUNKED_PREFILL):
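
The added row (2, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray") exercises PP=2 x TP=2 with eager mode, chunked prefill, and remote code enabled at once. Selecting just that case should work with something like pytest tests/distributed/test_pipeline_parallel.py -k internlm2, though the exact -k expression depends on the id pytest generates for the parametrized case.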
7 changes: 6 additions & 1 deletion tests/utils.py
@@ -179,7 +179,12 @@ def compare_two_settings(model: str,
         env2: The second set of environment variables to pass to the API server.
     """
 
-    tokenizer = AutoTokenizer.from_pretrained(model)
+    trust_remote_code = "--trust-remote-code"
+    if trust_remote_code in arg1 or trust_remote_code in arg2:
+        tokenizer = AutoTokenizer.from_pretrained(model,
+                                                  trust_remote_code=True)
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(model)
 
     prompt = "Hello, my name is"
     token_ids = tokenizer(prompt)["input_ids"]
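
The detour through the server args exists because InternLM2's tokenizer ships as custom code on the Hugging Face Hub: AutoTokenizer.from_pretrained raises unless trust_remote_code=True, so the helper infers the flag from the arguments already being passed to the servers rather than growing a new parameter.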
