diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 1b6bc2b1848c1..074d1d8307046 100755
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -1,6 +1,7 @@
 """A layer that samples the next tokens from the model's outputs."""
 import itertools
 import math
+import os
 import warnings
 from dataclasses import dataclass
 from importlib.util import find_spec
@@ -192,19 +193,15 @@ def _init_sampling_tensors(
         self._sampling_tensors = None
 
         # Initialize new sampling tensors
-        (sampling_tensors, do_penalties, do_top_p_top_k,
-         do_min_p) = SamplingTensors.from_sampling_metadata(
+        (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p,
+         top_k_scalar) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype)
 
         self._sampling_tensors = sampling_tensors
         self._do_penalties = do_penalties
         self._do_top_p_top_k = do_top_p_top_k
         self._do_min_p = do_min_p
-        self._top_p_scalar = sampling_tensors.top_ps[0]
-        self._top_k_scalar = sampling_tensors.top_ks[0]
-        scalar_p = torch.all(sampling_tensors.top_ps == self._top_p_scalar)
-        scalar_k = torch.all(sampling_tensors.top_ks == self._top_k_scalar)
-        self._scalar_p_and_k = torch.logical_and(scalar_p, scalar_k)
+        self._top_k_scalar = top_k_scalar
 
         self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5)
 
@@ -266,11 +263,11 @@ def forward(
             logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
 
         if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
-            # If we have a scalar p and k, we can use the optimized version.
-            if self._scalar_p_and_k.any():
+            # If we have a scalar k, we can use the optimized version.
+            if self._top_k_scalar is not None:
                 logits = self._apply_top_k_top_p_opt(logits,
-                                                     self._top_p_scalar.item(),
-                                                     self._top_k_scalar.item())
+                                                     sampling_tensors.top_ps,
+                                                     self._top_k_scalar)
             else:
                 logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                             sampling_tensors.top_ks)
@@ -383,23 +380,31 @@ class ApplyToppTopkScalar():
     The main logic of this is in __call__
     This is a class instead of a function, just to keep track of the monotonic
     non-decreasing state _padded_k
+
+    To disable handling of duplicates of the k-th token (ties that fall
+    outside the top-k border), set VLLM_HANDLE_TOPK_DUPLICATES to false.
     """
     _padded_k = 0
+    _handle_duplicates = os.getenv('VLLM_HANDLE_TOPK_DUPLICATES',
+                                   '1').lower() in ['1', 'true']
 
     def __init__(self, increment: int):
         self._increment = increment
 
-    def __call__(self, logits: torch.Tensor, p: float, k: int):
+    def __call__(self, logits: torch.Tensor, p: torch.Tensor, k: int):
         if k > ApplyToppTopkScalar._padded_k:
             ApplyToppTopkScalar._padded_k = min(k + self._increment,
                                                 logits.shape[1])
 
-        vals, idx = torch.topk(logits, k=ApplyToppTopkScalar._padded_k, \
-                               dim=1, sorted=True)
+        vals, idx = torch.topk(logits,
+                               k=ApplyToppTopkScalar._padded_k,
+                               dim=1,
+                               sorted=True)
 
         # this "if" checks if we have bucketed so much that
         # we have padded k upto shape of logits
-        if ApplyToppTopkScalar._padded_k != logits.shape[1]:
+        if self._handle_duplicates and \
+                ApplyToppTopkScalar._padded_k != logits.shape[1]:
             smallest_of_top_k = vals[:, k - 1]
             num_duplicates_of_smallest_of_topk = torch.sum(
                 logits == smallest_of_top_k.unsqueeze(1), 1)
@@ -424,9 +429,10 @@ def __call__(self, logits: torch.Tensor, p: float, k: int):
                     ApplyToppTopkScalar._padded_k + incr, logits.shape[1])
 
                 # recompute topk with expanded padded_k
-                vals, idx = torch.topk(logits, \
-                                       k=ApplyToppTopkScalar._padded_k, \
-                                       dim=1, sorted=True)
+                vals, idx = torch.topk(logits,
+                                       k=ApplyToppTopkScalar._padded_k,
+                                       dim=1,
+                                       sorted=True)
 
         idx = torch.fliplr(idx)
         vals = torch.fliplr(vals)
@@ -438,7 +444,7 @@ def __call__(self, logits: torch.Tensor, p: float, k: int):
         probs_sort = vals.softmax(dim=-1)
         probs_sum = probs_sort.cumsum(dim=-1)
 
-        top_p_mask = probs_sum <= (1 - p)
+        top_p_mask = probs_sum <= (1 - p.unsqueeze(dim=1))
         top_p_mask[:, -1] = False
         vals.masked_fill_(top_p_mask, -float("inf"))
 
diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py
index d4a8024095286..418b266bb67ae 100644
--- a/vllm/model_executor/sampling_metadata.py
+++ b/vllm/model_executor/sampling_metadata.py
@@ -389,7 +389,7 @@ def from_sampling_metadata(
         vocab_size: int,
         device: torch.device,
         dtype: torch.dtype,
-    ) -> Tuple["SamplingTensors", bool, bool, bool]:
+    ) -> Tuple["SamplingTensors", bool, bool, bool, Optional[int]]:
         prompt_tokens: List[array] = []
         output_tokens: List[array] = []
         top_ks: List[int] = []
@@ -476,6 +476,9 @@ def from_sampling_metadata(
             prompt_tokens.append(seq_data.prompt_token_ids_array)
             output_tokens.append(seq_data.output_token_ids_array)
 
+        top_k_scalar = top_ks[0] if do_top_p_top_k and all(
+            k == top_ks[0] for k in top_ks) else None
+
         sampling_tensors = SamplingTensors.from_lists(
             temperatures,
             top_ps,
@@ -490,7 +493,8 @@ def from_sampling_metadata(
             device,
             dtype,
         )
-        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
+        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p,
+                top_k_scalar)
 
     @classmethod
     def from_lists(