diff --git a/aphrodite/common/config.py b/aphrodite/common/config.py index e0e0ff210..c50d65f0e 100644 --- a/aphrodite/common/config.py +++ b/aphrodite/common/config.py @@ -35,6 +35,7 @@ APHRODITE_USE_MODELSCOPE = envs.APHRODITE_USE_MODELSCOPE _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 +_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096 _PP_SUPPORTED_MODELS = [ "AquilaModel", @@ -722,6 +723,10 @@ def is_embedding_model(self) -> bool: """Extract the embedding model flag.""" return self.embedding_mode + @property + def is_multimodal_model(self) -> bool: + return self.multimodal_config is not None + class CacheConfig: """Configuration for the KV cache. @@ -1118,28 +1123,35 @@ def __init__(self, num_lookahead_slots: int = 0, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, - embedding_mode: Optional[bool] = False, + embedding_mode: bool = False, + is_multimodal_model: bool = False, preemption_mode: Optional[str] = None, num_scheduler_steps: int = 1, send_delta_data: bool = False, single_user_mode: bool = False) -> None: - if max_num_batched_tokens is not None: - self.max_num_batched_tokens = max_num_batched_tokens - else: + if max_num_batched_tokens is None: if enable_chunked_prefill: - if not HAS_TRITON: - raise ValueError("Triton is not installed, " - "chunked prefill will not work.") - # For chunked prefill, choose the well-tuned batch size. - self.max_num_batched_tokens = 768 - elif embedding_mode: - # For embedding, choose specific value for higher throughput - self.max_num_batched_tokens = max( - max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS) + # It is the values that have the best balance between ITL + # and TTFT on A100. Note it is not optimized for throughput. + max_num_batched_tokens = 512 else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. - self.max_num_batched_tokens = max(max_model_len, 2048) + max_num_batched_tokens = max(max_model_len, 2048) + if embedding_mode: + # For embedding, choose specific value for higher throughput + max_num_batched_tokens = max( + max_num_batched_tokens, + _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + if is_multimodal_model: + # The value needs to be at least the number of multimodal tokens + max_num_batched_tokens = max( + max_num_batched_tokens, + _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + self.max_num_batched_tokens = max_num_batched_tokens + if enable_chunked_prefill: logger.info( "Chunked prefill is enabled with " diff --git a/aphrodite/engine/aphrodite_engine.py b/aphrodite/engine/aphrodite_engine.py index 6d17e46aa..541da79a7 100644 --- a/aphrodite/engine/aphrodite_engine.py +++ b/aphrodite/engine/aphrodite_engine.py @@ -1849,7 +1849,7 @@ def _validate_model_inputs(self, inputs: Union[LLMInputs, prompt_ids = inputs.get("prompt_token_ids") if prompt_ids is None or len(prompt_ids) == 0: raise ValueError("Prompt cannot be empty") - if self.model_config.multimodal_config is not None: + if self.model_config.is_multimodal_model: max_prompt_len = self.model_config.max_model_len if len(prompt_ids) > max_prompt_len: raise ValueError( @@ -1859,6 +1859,9 @@ def _validate_model_inputs(self, inputs: Union[LLMInputs, "number of text tokens plus multimodal tokens. For image " "inputs, the number of image tokens depends on the number " "of images, and possibly their aspect ratios as well.") + # TODO: Find out how many placeholder tokens are there so we can + # check that chunked prefill does not truncate them + # max_batch_len = self.scheduler_config.max_num_batched_tokens setup_logger() diff --git a/aphrodite/engine/args_tools.py b/aphrodite/engine/args_tools.py index 93fd7b58f..ecbf9b61e 100644 --- a/aphrodite/engine/args_tools.py +++ b/aphrodite/engine/args_tools.py @@ -1061,6 +1061,7 @@ def create_engine_config(self, ) -> EngineConfig: delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, embedding_mode=model_config.embedding_mode, + is_multimodal_model=model_config.is_multimodal_model, preemption_mode=self.preemption_mode, num_scheduler_steps=self.num_scheduler_steps, send_delta_data=(APHRODITE_USE_RAY_SPMD_WORKER and diff --git a/aphrodite/task_handler/utils.py b/aphrodite/task_handler/utils.py index e953fd287..ce514734f 100644 --- a/aphrodite/task_handler/utils.py +++ b/aphrodite/task_handler/utils.py @@ -39,7 +39,7 @@ def assert_enc_dec_mr_supported_scenario( raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP']) - if enc_dec_mr.model_config.multimodal_config is not None: + if enc_dec_mr.model_config.is_multimodal_model: raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM'])