diff --git a/aphrodite/common/config.py b/aphrodite/common/config.py
index 15066de88..4925f3ce7 100644
--- a/aphrodite/common/config.py
+++ b/aphrodite/common/config.py
@@ -48,6 +48,8 @@
     "NemotronForCausalLM",
     "Qwen2ForCausalLM",
     "Qwen2MoeForCausalLM",
+    "InternLM2ForCausalLM",
+    "InternVLChatModel",
 ]
 
 _OPTIMIZED_QUANTS = [
diff --git a/aphrodite/modeling/models/internlm2.py b/aphrodite/modeling/models/internlm2.py
index 0d6c7be9b..c9f8f695a 100644
--- a/aphrodite/modeling/models/internlm2.py
+++ b/aphrodite/modeling/models/internlm2.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -8,7 +8,8 @@
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors
-from aphrodite.distributed import get_tensor_model_parallel_world_size
+from aphrodite.distributed import (get_pp_group,
+                                   get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -23,6 +24,9 @@
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.quantization.base_config import QuantizationConfig
 
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers)
+
 
 class InternLM2MLP(nn.Module):
@@ -212,6 +216,7 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -221,11 +226,15 @@ def __init__(
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList([
-            InternLMDecoderLayer(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: InternLMDecoderLayer(config, cache_config,
+                                                quant_config),
+            prefix=f"{prefix}.layers")
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.tok_embeddings(input_ids)
@@ -236,23 +245,33 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
+        intermediate_tensors: IntermediateTensors = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.tok_embeddings(input_ids)
+            residual = None
         else:
-            hidden_states = self.tok_embeddings(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i],
+                kv_caches[i - self.start_layer],
                 attn_metadata,
                 residual,
             )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -272,8 +291,12 @@ def __init__(
         self.output = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
+        if self.config.tie_word_embeddings:
+            self.output.weight = self.model.tok_embeddings.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
 
     def forward(
         self,
@@ -284,7 +307,7 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: IntermediateTensors,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, intermediate_tensors)
         return hidden_states
 
     def compute_logits(
@@ -321,6 +344,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -329,6 +354,8 @@
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
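Reviewer note on the internlm2.py hunks above: InternLM2Model.forward now follows the
pipeline-parallel contract — only the first rank embeds tokens, later ranks unpack
"hidden_states"/"residual" from the incoming IntermediateTensors, each rank runs only
layers[start_layer:end_layer] (hence the kv_caches[i - self.start_layer] indexing,
since a rank holds caches only for its own layers), and every rank except the last
returns an IntermediateTensors for the next stage. A minimal standalone sketch of
that control flow, with the residual stream omitted for brevity; FakePPGroup and
pp_forward are hypothetical stand-ins for illustration, not the real aphrodite API:

    from typing import Dict, List, Optional, Union

    import torch

    class FakePPGroup:
        # Hypothetical stand-in for get_pp_group(); only the two flags used here.
        def __init__(self, rank: int, world_size: int) -> None:
            self.is_first_rank = rank == 0
            self.is_last_rank = rank == world_size - 1

    def pp_forward(pp: FakePPGroup, embed: torch.nn.Module,
                   layers: List[torch.nn.Module], norm: torch.nn.Module,
                   input_ids: torch.Tensor,
                   intermediate: Optional[Dict[str, torch.Tensor]] = None
                   ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        if pp.is_first_rank:
            hidden = embed(input_ids)        # only the first stage embeds
        else:
            assert intermediate is not None  # later stages resume mid-model
            hidden = intermediate["hidden_states"]
        for layer in layers:                 # this rank's slice of the stack
            hidden = layer(hidden)
        if not pp.is_last_rank:
            return {"hidden_states": hidden}  # handed off to the next stage
        return norm(hidden)                   # only the last stage normalizes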
diff --git a/aphrodite/modeling/models/internvl.py b/aphrodite/modeling/models/internvl.py
index db013c0a4..7e39ad3f2 100644
--- a/aphrodite/modeling/models/internvl.py
+++ b/aphrodite/modeling/models/internvl.py
@@ -339,6 +339,8 @@ def __init__(self,
                                    nn.Linear(llm_hidden_size, llm_hidden_size))
 
         self.img_context_token_id = None
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors)
 
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
@@ -459,7 +461,7 @@ def forward(
             positions,
             kv_caches,
             attn_metadata,
-            None,
+            intermediate_tensors,
             inputs_embeds=inputs_embeds)
         return hidden_states
diff --git a/aphrodite/modeling/models/utils.py b/aphrodite/modeling/models/utils.py
index b51496208..b5c35fdcf 100644
--- a/aphrodite/modeling/models/utils.py
+++ b/aphrodite/modeling/models/utils.py
@@ -8,6 +8,7 @@
 from aphrodite.common.config import (CacheConfig, LoRAConfig,
                                      MultiModalConfig, SchedulerConfig)
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.common.utils import is_pin_memory_available, progress_bar
 from aphrodite.modeling.model_loader.loader import build_model
 from aphrodite.modeling.models import ModelRegistry
@@ -271,3 +272,18 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
         if name.startswith(missing_layer_name):
             return True
     return False
+
+
+def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
+
+    def make_empty_intermediate_tensors(
+            batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            key: torch.zeros((batch_size, hidden_size),
+                             dtype=dtype,
+                             device=device)
+            for key in keys
+        })
+
+    return make_empty_intermediate_tensors
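For context on make_empty_intermediate_tensors_factory above: the closure it returns
hands a rank zeroed (batch_size, hidden_size) buffers for each named key, which
non-first pipeline stages can use as placeholders before the previous stage's real
tensors arrive. A small usage sketch under that assumption — the hidden size 4096
and the CPU device are illustrative only, and whether the runner calls it exactly
this way is an assumption, not something this diff shows:

    import torch

    from aphrodite.modeling.models.utils import (
        make_empty_intermediate_tensors_factory)

    # Same keys the InternLM2Model constructor registers above.
    make_empty = make_empty_intermediate_tensors_factory(
        ["hidden_states", "residual"], hidden_size=4096)

    # Allocate zeroed placeholders for a batch of 8 sequences.
    tensors = make_empty(batch_size=8, dtype=torch.float16,
                         device=torch.device("cpu"))
    print(tensors["hidden_states"].shape)  # torch.Size([8, 4096])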
diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index f69308d6f..2b4e24b15 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -15,23 +15,26 @@
 APHRODITE_MULTI_NODE = os.getenv("APHRODITE_MULTI_NODE", "0") == "1"
 
 
-@pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, "
-                          "MODEL_NAME, DIST_BACKEND"),
-                         [
-                             (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
-                             (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
-                             (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
-                             (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
-                             (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
-                         ])
+@pytest.mark.parametrize(
+    ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, "
+     "MODEL_NAME, DIST_BACKEND"),
+    [
+        (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (2, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray"),
+    ],
+)
 @fork_new_process_for_each_test
-def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
-                    DIST_BACKEND):
+def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
+                    TRUST_REMOTE_CODE, MODEL_NAME, DIST_BACKEND):
     if APHRODITE_MULTI_NODE and DIST_BACKEND == "mp":
         pytest.skip("Skipping multi-node pipeline parallel test for "
                     "multiprocessing distributed backend")
@@ -68,6 +71,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
     if EAGER_MODE:
         pp_args.append("--enforce-eager")
         tp_args.append("--enforce-eager")
+    if TRUST_REMOTE_CODE:
+        pp_args.append("--trust-remote-code")
+        tp_args.append("--trust-remote-code")
     pp_env = None
     if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2
             and CHUNKED_PREFILL):
diff --git a/tests/utils.py b/tests/utils.py
index 393b513d3..ba10e8fa0 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -179,7 +179,12 @@ def compare_two_settings(model: str,
         env2: The second set of environment variables to pass to the API
               server.
     """
-    tokenizer = AutoTokenizer.from_pretrained(model)
+    trust_remote_code = "--trust-remote-code"
+    if trust_remote_code in arg1 or trust_remote_code in arg2:
+        tokenizer = AutoTokenizer.from_pretrained(model,
+                                                  trust_remote_code=True)
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(model)
 
     prompt = "Hello, my name is"
     token_ids = tokenizer(prompt)["input_ids"]
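One usage note on the tests/utils.py hunk: compare_two_settings keys off the literal
"--trust-remote-code" flag appearing in either server's argument list, so the
tokenizer is loaded the same way as the two servers it compares. A hypothetical call
that exercises the new branch — the parallelism flag spellings are assumed for
illustration rather than copied from the test, and this assumes the environment
arguments of compare_two_settings are optional:

    from tests.utils import compare_two_settings

    # The "--trust-remote-code" entries below trigger the new tokenizer branch
    # and are also forwarded to both API server processes.
    compare_two_settings(
        "internlm/internlm2_5-7b-chat",
        ["--pipeline-parallel-size", "2", "--trust-remote-code"],
        ["--tensor-parallel-size", "2", "--trust-remote-code"],
    )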