From 1259caf0feead63d9c3558634cf2156b0e35cdab Mon Sep 17 00:00:00 2001
From: Tomasz Zielinski
Date: Thu, 31 Oct 2024 17:37:12 +0200
Subject: [PATCH 1/2] Implementation of TP > 1 for multi-step scheduling

---
 vllm/worker/hpu_model_runner.py | 38 +++++++++++++++++++++++++++++----
 1 file changed, 34 insertions(+), 4 deletions(-)

diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index c50e4e244dffe..aaaec3770a48d 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -29,6 +29,7 @@
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
+from vllm.distributed import broadcast_tensor_dict
 from vllm.distributed.parallel_state import get_world_group
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
@@ -97,7 +98,10 @@ def subtuple(obj: object,
     if to_override is None:
         to_override = {}
     fields = set(to_copy) | set(to_override.keys())
-    values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
+    if type(obj) is dict:
+        values = {key: obj[key] for key in fields if key in obj}
+    else:
+        values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
     if typename not in _TYPE_CACHE:
         _TYPE_CACHE[typename] = collections.namedtuple(typename,
                                                        ' '.join(fields))
@@ -2049,7 +2053,9 @@ def execute_model(
                 # not first or last multi-step
                 return []
             # last multi-step
-            output = self._decode_sampler_outputs(model_input)
+            output = self._decode_sampler_outputs(
+                model_input) if self.is_driver_worker else []
+            torch.hpu.synchronize()
         if model_input.is_first_multi_step:
             # first multi-step
             if self.lora_config:
@@ -2110,6 +2116,21 @@ def execute_model(
                 sampling_metadata.skip_sampler_cpu_output = True
                 self.model.model.sampler.include_gpu_probs_tensor = True
             for i in range(num_steps):
+                torch.hpu.synchronize()
+                if i != 0 and not self.is_driver_worker:
+                    broadcast_data = broadcast_tensor_dict(src=0)
+                    if 'early_exit' in broadcast_data and broadcast_data[
+                            'early_exit']:
+                        return [output] if num_steps == 1 else []
+                    execute_model_kwargs.update({
+                        "input_ids":
+                        broadcast_data["input_ids"],
+                        "positions":
+                        broadcast_data["positions"],
+                        "attn_metadata":
+                        self.trim_attn_metadata(
+                            broadcast_data["attn_metadata"])
+                    })
                 with self.profiler.record_event('internal', model_event_name):
                     hidden_states = self.model.forward(
                         **execute_model_kwargs,
@@ -2133,9 +2154,10 @@ def execute_model(
                     logits = self.model.compute_logits(hidden_states,
                                                        sampling_metadata)
                 htorch.core.mark_step()
+                torch.hpu.synchronize()
                 # Only perform sampling in the driver worker.
                 if not self.is_driver_worker:
-                    return []
+                    continue
 
                 if model_input.async_callback is not None:
                     model_input.async_callback()
@@ -2170,6 +2192,8 @@ def execute_model(
                                 dummy_token = (540, )
                                 data.output_token_ids += (dummy_token)
                             else:
+                                broadcast_tensor_dict({'early_exit': True},
+                                                      src=0)
                                 if num_steps == 1:
                                     return [output]
                                 else:
@@ -2185,6 +2209,12 @@ def execute_model(
                         "attn_metadata":
                         self.trim_attn_metadata(result.attn_metadata)
                     })
+                    execute_model_kwargs_update = {
+                        "input_ids": result.input_tokens,
+                        "positions": result.input_positions,
+                        "attn_metadata": vars(result.attn_metadata)
+                    }
+                    broadcast_tensor_dict(execute_model_kwargs_update, src=0)
 
         if self.is_driver_worker and self.profiler.enabled:
             # Stop recording 'execute_model' event
@@ -2199,7 +2229,7 @@ def execute_model(
                 is_prompt=is_prompt)
             self.profiler.record_counter(self.event_start, counters)
         if num_steps == 1:
-            return [output]
+            return [output] if self.is_driver_worker else []
         else:
             return []
         return output if type(output) is list else [output]

From e40bd10039d8a2638412d75f80963e95117dc1ae Mon Sep 17 00:00:00 2001
From: Tomasz Zielinski
Date: Mon, 4 Nov 2024 14:05:02 +0200
Subject: [PATCH 2/2] Removed redundant sync points

---
 vllm/worker/hpu_model_runner.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index aaaec3770a48d..fec5f3d01cff8 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -2116,7 +2116,6 @@ def execute_model(
                 sampling_metadata.skip_sampler_cpu_output = True
                 self.model.model.sampler.include_gpu_probs_tensor = True
             for i in range(num_steps):
-                torch.hpu.synchronize()
                 if i != 0 and not self.is_driver_worker:
                     broadcast_data = broadcast_tensor_dict(src=0)
                     if 'early_exit' in broadcast_data and broadcast_data[
@@ -2154,7 +2153,6 @@ def execute_model(
                     logits = self.model.compute_logits(hidden_states,
                                                        sampling_metadata)
                 htorch.core.mark_step()
-                torch.hpu.synchronize()
                 # Only perform sampling in the driver worker.
                 if not self.is_driver_worker:
                     continue
@@ -2209,12 +2207,12 @@ def execute_model(
                         "attn_metadata":
                         self.trim_attn_metadata(result.attn_metadata)
                     })
-                    execute_model_kwargs_update = {
+                    model_kwargs_broadcast_data = {
                         "input_ids": result.input_tokens,
                         "positions": result.input_positions,
                         "attn_metadata": vars(result.attn_metadata)
                     }
-                    broadcast_tensor_dict(execute_model_kwargs_update, src=0)
+                    broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
 
         if self.is_driver_worker and self.profiler.enabled:
             # Stop recording 'execute_model' event
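
For reference, below is a minimal standalone sketch (not part of the patch series) of the per-step handshake the first patch introduces: after the first step, the driver rank either broadcasts an 'early_exit' flag or the next step's model inputs, and the non-driver ranks consume that broadcast and keep stepping instead of sampling. The Broadcast class and the function parameters here are illustrative placeholders; the real code uses vllm.distributed.broadcast_tensor_dict inside HPUModelRunner.execute_model.

import queue


class Broadcast:
    """Toy single-process stand-in for a driver -> workers broadcast."""

    def __init__(self):
        self._q = queue.Queue()

    def send(self, data):
        # Driver rank publishes a dict for the other tensor-parallel ranks.
        self._q.put(data)

    def recv(self):
        # Non-driver ranks block until the driver publishes the next dict.
        return self._q.get()


def driver_loop(bcast, num_steps, finished, next_inputs):
    """Driver rank: run forward + sampling, then tell workers how to proceed."""
    for i in range(num_steps):
        # ... forward pass and sampling happen here on the driver ...
        if i < num_steps - 1:
            if finished(i):
                bcast.send({'early_exit': True})  # no more work: stop all ranks
                return
            bcast.send(next_inputs(i))  # next input_ids / positions / attn_metadata


def worker_loop(bcast, num_steps, execute_model_kwargs):
    """Non-driver rank: forward passes only; inputs arrive via broadcast."""
    for i in range(num_steps):
        if i != 0:
            data = bcast.recv()
            if data.get('early_exit'):
                return  # driver signalled completion
            execute_model_kwargs.update(data)
        # ... forward pass with execute_model_kwargs; sampling is skipped ...

In the actual change, the first patch also turns the non-driver workers' early "return []" into "continue" so they stay in the multi-step loop, and the second patch then removes the per-step torch.hpu.synchronize() calls as redundant.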