diff --git a/examples/offline_inference_spec_decode.py b/examples/offline_inference_spec_decode.py
index 03543ff47de69..22daecfcca070 100644
--- a/examples/offline_inference_spec_decode.py
+++ b/examples/offline_inference_spec_decode.py
@@ -15,7 +15,8 @@ def time_generation(llm: LLM, prompts: List[str],
     start = time.time()
     outputs = llm.generate(prompts, sampling_params)
     end = time.time()
-    latency_per_token = (end - start) / sum([len(o.outputs[0].token_ids) for o in outputs])
+    latency_per_token = (end - start) / sum(
+        [len(o.outputs[0].token_ids) for o in outputs])
     # Print the outputs.
     ret = []
     for output in outputs:
@@ -36,7 +37,8 @@ def time_generation(llm: LLM, prompts: List[str],
 
     print("==============Without speculation==================")
     llm = LLM(model="facebook/opt-6.7b")
-    ret_non_spec,latency_per_token_non_spec = time_generation(llm, prompts, sampling_params)
+    ret_non_spec, latency_per_token_non_spec = time_generation(
+        llm, prompts, sampling_params)
 
     del llm
     gc.collect()
@@ -46,19 +48,20 @@ def time_generation(llm: LLM, prompts: List[str],
     llm = LLM(
         model="facebook/opt-6.7b",
         speculative_model="facebook/opt-125m",
-        num_speculative_tokens = 5,
+        num_speculative_tokens=5,
         # These are currently required for MLPSpeculator decoding
         use_v2_block_manager=True,
     )
 
-    ret_spec,latency_per_token_spec = time_generation(llm, prompts, sampling_params)
+    ret_spec, latency_per_token_spec = time_generation(llm, prompts,
+                                                       sampling_params)
 
     del llm
     gc.collect()
     print("================= Summary =====================")
     print("input is ", prompts, "\n")
-    print("Non Spec Decode - latency_per_token is ", latency_per_token_non_spec)
+    print("Non Spec Decode - latency_per_token is ",
+          latency_per_token_non_spec)
     print("Generated Text is :", ret_non_spec, "\n")
     print("Spec Decode - latency_per_token is ", latency_per_token_spec)
     print("Generated Text is :", ret_spec)
-    
\ No newline at end of file
diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py
index fac59894d2c2d..4c18521e4455a 100644
--- a/tests/samplers/test_rejection_sampler.py
+++ b/tests/samplers/test_rejection_sampler.py
@@ -23,16 +23,17 @@ def mock_causal_accepted_tensor(
     """
     batch_size = last_accepted_indices.shape[0]
 
-    accepted = (torch.arange(k).expand(batch_size, k) <=
-                last_accepted_indices.unsqueeze(-1).broadcast_to(
+    accepted = (torch.arange(k).expand(batch_size, k)
+                <= last_accepted_indices.unsqueeze(-1).broadcast_to(
                     batch_size, k))
 
     # Sprinkle accepted values after the contiguous initial accepted values.
     # This replicates the behavior of rejection sampling, which may "accept"
     # a token that cannot be accepted because of causality.
-    sprinkle_candidates = (
-        torch.arange(k).expand(batch_size, k) >
-        last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) + 1)
+    sprinkle_candidates = (torch.arange(k).expand(
+        batch_size,
+        k) > last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) +
+                           1)
     sprinkle = torch.rand(batch_size, k) > 0.5
     accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates]
     return accepted
@@ -382,8 +383,8 @@ def test_rejection_sampling_approximates_target_distribution(
                        distance_wrt_reference)
 
     expected_improvement_multiplier = 20
-    assert (relative_change_in_distance_wrt_target >
-            relative_change_in_distance_wrt_reference *
+    assert (relative_change_in_distance_wrt_target
+            > relative_change_in_distance_wrt_reference *
             expected_improvement_multiplier)
 
 
diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index a0ca01ef960de..1197afd15d016 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -25,7 +25,7 @@ class HPUAttentionBackend(AttentionBackend):
     @staticmethod
     def get_name() -> str:
         return "hpu-attn"
-    
+
     @staticmethod
     def get_impl_cls() -> Type["HPUAttentionImpl"]:
         return HPUAttentionImpl
diff --git a/vllm/executor/hpu_executor.py b/vllm/executor/hpu_executor.py
index 37ddd8c381c5e..a5dffd9b8c6b5 100644
--- a/vllm/executor/hpu_executor.py
+++ b/vllm/executor/hpu_executor.py
@@ -144,13 +144,13 @@ def execute_model(
             with gc_ctx as gc_local_metric, \
                     cpu_fallback_ctx as cpu_fallback_local_metric:
                 output = self.driver_worker.execute_model(execute_model_req)
-            if (log_graph_compilation and gc_local_metric.stats()[0][1] > 0
-                ) or log_graph_compilation_all:
+            if (log_graph_compilation and gc_local_metric.stats()[0][1]
+                    > 0) or log_graph_compilation_all:
                 msg = ("VLLM_HPU_STEP_GRAPH_COMPILATION: "
                        f"{gc_local_metric.stats()}, {input_stats}")
                 logger.warning(msg)
-            if (log_cpu_fallbacks and cpu_fallback_local_metric.stats()[0][1] >
-                    0) or log_cpu_fallbacks_all:
+            if (log_cpu_fallbacks and cpu_fallback_local_metric.stats()[0][1]
+                    > 0) or log_cpu_fallbacks_all:
                 msg = ("VLLM_HPU_STEP_CPU_FALLBACK: "
                        f"{cpu_fallback_local_metric.stats()}, {input_stats}")
                 logger.warning(msg)
diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py
index f65639413b4f4..e3b76538f3941 100644
--- a/vllm/model_executor/layers/spec_decode_base_sampler.py
+++ b/vllm/model_executor/layers/spec_decode_base_sampler.py
@@ -30,7 +30,9 @@ def __init__(self, strict_mode: bool = False):
         self.num_emitted_tokens: Optional[torch.Tensor] = None
         self.num_draft_tokens: int = 0
 
-    def init_tensors(self, device: Union[int, str], device_type: str = 'cuda') -> None:
+    def init_tensors(self,
+                     device: Union[int, str],
+                     device_type: str = 'cuda') -> None:
         assert self.num_accepted_tokens is None
         if isinstance(device, int):
             device = f"{device_type}:{device}"
diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py
index 852053ab21290..8892893dc4a62 100644
--- a/vllm/spec_decode/batch_expansion.py
+++ b/vllm/spec_decode/batch_expansion.py
@@ -231,12 +231,17 @@ def _contract_batch_all_spec(
 
         # of shape [batch_size * k + 1] back to [batch_size, k + 1].
         contracted_bs, k = proposals.proposal_token_ids.shape
-        (target_sampler_output.sampled_token_ids,
-         target_sampler_output.sampled_token_probs,
-         target_sampler_output.logprobs,
-         target_sampler_output.hidden_states,
-         _, _, _, _,) = self._split_scoring_output(
-             target_sampler_output, num_scoring_tokens)
+        (
+            target_sampler_output.sampled_token_ids,
+            target_sampler_output.sampled_token_probs,
+            target_sampler_output.logprobs,
+            target_sampler_output.hidden_states,
+            _,
+            _,
+            _,
+            _,
+        ) = self._split_scoring_output(target_sampler_output,
+                                       num_scoring_tokens)
 
         # Reshape tensors to original batch size
         target_token_ids = target_sampler_output.sampled_token_ids.reshape(
diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py
index b3ead209b9d66..aae0d7decdbab 100644
--- a/vllm/spec_decode/draft_model_runner.py
+++ b/vllm/spec_decode/draft_model_runner.py
@@ -2,13 +2,12 @@
 
 import torch
 
-from vllm.forward_context import set_forward_context
-from vllm.model_executor.layers.sampler import SamplerOutput
-
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
+from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.multimodal import MultiModalInputs
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
@@ -23,10 +22,9 @@
     # vllm_flash_attn is not installed, try the ROCm FA metadata
     from vllm.attention.backends.rocm_flash_attn import (
         ROCmFlashAttentionMetadata as FlashAttentionMetadata)
-except:
-    logger.warning(
-        "Draft model speculative decoding currently only supports"
-        "CUDA and ROCm flash attention backend.")
+except Exception as e:
+    logger.warning("Draft model speculative decoding currently only supports "
+                   "CUDA and ROCm flash attention backend: %s", e)
 
 # A flag to enable debug prints for the updated input tensors
 # before each step.
diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py
index 1e227973e84b1..583f830b6be89 100644
--- a/vllm/spec_decode/metrics.py
+++ b/vllm/spec_decode/metrics.py
@@ -6,8 +6,8 @@
 
 from vllm.model_executor.layers.spec_decode_base_sampler import (
     SpecDecodeBaseSampler)
-from vllm.utils import is_pin_memory_available
 from vllm.platforms import current_platform
+from vllm.utils import is_pin_memory_available
 
 
 class SpecDecodeWorkerMetrics(
@@ -78,9 +78,9 @@ def __init__(self,
         self._rejsample_metrics_collect_interval_s = collect_interval_s
         self._last_metrics_collect_time = self._timer()
 
-    def init_tensors(self, rank: int, device: str) -> None:
+    def init_tensors(self, rank: int, device: torch.device) -> None:
         self._rank = rank
-        if 'hpu' == device.type:
+        if device.type == 'hpu':
             import habana_frameworks.torch as htorch
             self._copy_stream = htorch.hpu.Stream()
         else:
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
index 38fe170775575..780bcda5fbd7c 100644
--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -5,6 +5,7 @@
 import torch
 
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
                            SequenceGroupMetadata)
 from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
@@ -12,8 +13,8 @@
                                          SpeculativeProposer)
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
-from vllm.platforms import current_platform
 from vllm.utils import is_neuron, is_openvino, is_xpu
+
 if is_neuron():
     from vllm.worker.neuron_worker import NeuronWorker as WorkerBaseCls
 elif current_platform.is_hpu():
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index ce853bb4ba95b..1dd07fc074ca0 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -35,8 +35,8 @@
                                     get_all_num_logprobs,
                                     get_sampled_token_logprobs, nvtx_range,
                                     split_batch_by_proposal_len)
-from vllm.worker.worker import Worker
 from vllm.worker.selector import init_worker
+from vllm.worker.worker import Worker
 from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
 
 logger = init_logger(__name__)
@@ -461,8 +461,8 @@ def _should_disable_all_speculation(
             self, execute_model_req: ExecuteModelRequest) -> bool:
         # When the batch size is too large, disable speculative decoding
         # to stop trading off throughput for latency.
-        return (execute_model_req.running_queue_size >=
-                self.disable_by_batch_size)
+        return (execute_model_req.running_queue_size
+                >= self.disable_by_batch_size)
 
     def _maybe_disable_speculative_tokens(
             self, disable_all_speculation: bool,
diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py
index 782fbb48f1118..c88820ab27b69 100644
--- a/vllm/spec_decode/util.py
+++ b/vllm/spec_decode/util.py
@@ -5,10 +5,10 @@
 import torch
 
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            PromptLogprobs, SequenceGroupMetadata,
                            SequenceOutput)
-from vllm.platforms import current_platform
 
 SeqId = int
 
@@ -40,13 +40,15 @@ def get_sampled_token_logprobs(
     """
     num_steps, batch_size, vocab_size = logprob_tensor.shape
-    selected_logprobs = logprob_tensor[torch.arange(num_steps).unsqueeze(1),
-                                       torch.arange(batch_size),
-                                       sampled_token_ids, ]
+    selected_logprobs = logprob_tensor[
+        torch.arange(num_steps).unsqueeze(1),
+        torch.arange(batch_size),
+        sampled_token_ids,
+    ]
     expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand(
         -1, -1, vocab_size)
-    sampled_token_ids_ranks = (logprob_tensor >
-                               expanded_selected_logprobs).sum(-1).add_(1)
+    sampled_token_ids_ranks = (logprob_tensor
+                               > expanded_selected_logprobs).sum(-1).add_(1)
 
     return sampled_token_ids_ranks, selected_logprobs
 
 
diff --git a/vllm/worker/selector.py b/vllm/worker/selector.py
index b06122f9139c2..eed4d73999d00 100644
--- a/vllm/worker/selector.py
+++ b/vllm/worker/selector.py
@@ -1,5 +1,6 @@
 from vllm.config import DeviceConfig
 
+
 def init_worker(*args, **kwargs):
     device_config: DeviceConfig = kwargs.get("device_config")
     if device_config.device_type == 'neuron':
@@ -22,4 +23,4 @@ def init_worker(*args, **kwargs):
         return XPUWorker(*args, **kwargs)
     else:
         from vllm.worker.worker import Worker
-        return Worker(*args, **kwargs)
\ No newline at end of file
+        return Worker(*args, **kwargs)
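
For reference, below is a condensed, self-contained sketch of the speculative decoding timing harness that the patched examples/offline_inference_spec_decode.py exercises. It reuses only what is visible in the diff (facebook/opt-6.7b target, facebook/opt-125m draft, num_speculative_tokens=5, use_v2_block_manager=True); the prompt and sampling settings are illustrative, and argument availability may differ across vLLM versions, so treat this as a sketch rather than a drop-in replacement for the example.

# Sketch only: mirrors the configuration in the patched example above.
import time
from typing import List, Tuple

from vllm import LLM, SamplingParams


def time_generation(llm: LLM, prompts: List[str],
                    sampling_params: SamplingParams) -> Tuple[List[str], float]:
    # Measure wall-clock latency per generated token across the batch.
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    latency_per_token = (end - start) / sum(
        len(o.outputs[0].token_ids) for o in outputs)
    return [o.outputs[0].text for o in outputs], latency_per_token


if __name__ == "__main__":
    prompts = ["The future of AI is"]  # illustrative prompt
    sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

    # Target model speculating with a small draft model, as in the example.
    llm = LLM(
        model="facebook/opt-6.7b",
        speculative_model="facebook/opt-125m",
        num_speculative_tokens=5,
        use_v2_block_manager=True,
    )
    texts, latency = time_generation(llm, prompts, sampling_params)
    print("Spec Decode - latency_per_token is ", latency)
    print("Generated Text is :", texts)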