Add option to disable duplicates in topk
kdamaszk committed Nov 6, 2024
1 parent c3c0e90 commit 44c82dc
Showing 2 changed files with 31 additions and 21 deletions.
44 changes: 25 additions & 19 deletions vllm/model_executor/layers/sampler.py
@@ -1,6 +1,7 @@
"""A layer that samples the next tokens from the model's outputs."""
import itertools
import math
import os
import warnings
from dataclasses import dataclass
from importlib.util import find_spec
@@ -192,19 +193,15 @@ def _init_sampling_tensors(
         self._sampling_tensors = None

         # Initialize new sampling tensors
-        (sampling_tensors, do_penalties, do_top_p_top_k,
-         do_min_p) = SamplingTensors.from_sampling_metadata(
+        (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p,
+         top_k_scalar) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype)

         self._sampling_tensors = sampling_tensors
         self._do_penalties = do_penalties
         self._do_top_p_top_k = do_top_p_top_k
         self._do_min_p = do_min_p
-        self._top_p_scalar = sampling_tensors.top_ps[0]
-        self._top_k_scalar = sampling_tensors.top_ks[0]
-        scalar_p = torch.all(sampling_tensors.top_ps == self._top_p_scalar)
-        scalar_k = torch.all(sampling_tensors.top_ks == self._top_k_scalar)
-        self._scalar_p_and_k = torch.logical_and(scalar_p, scalar_k)
+        self._top_k_scalar = top_k_scalar

         self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5)

@@ -266,11 +263,11 @@ def forward(
         logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

         if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
-            # If we have a scalar p and k, we can use the optimized version.
-            if self._scalar_p_and_k.any():
+            # If we have a scalar k, we can use the optimized version.
+            if self._top_k_scalar is not None:
                 logits = self._apply_top_k_top_p_opt(logits,
-                                                     self._top_p_scalar.item(),
-                                                     self._top_k_scalar.item())
+                                                     sampling_tensors.top_ps,
+                                                     self._top_k_scalar)
             else:
                 logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                             sampling_tensors.top_ks)
@@ -383,23 +380,31 @@ class ApplyToppTopkScalar():
     The main logic of this is in __call__
     This is a class instead of a function, just to keep track of
     the monotonic non-decreasing state _padded_k
+    To disable the duplicates that are outside of kth border,
+    set VLLM_HANDLE_TOPK_DUPLICATES to false.
     """
     _padded_k = 0
+    _handle_duplicates = os.getenv('VLLM_HANDLE_TOPK_DUPLICATES',
+                                   '1').lower() in ['1', 'true']

     def __init__(self, increment: int):
         self._increment = increment

-    def __call__(self, logits: torch.Tensor, p: float, k: int):
+    def __call__(self, logits: torch.Tensor, p: torch.Tensor, k: int):
         if k > ApplyToppTopkScalar._padded_k:
             ApplyToppTopkScalar._padded_k = min(k + self._increment,
                                                 logits.shape[1])

-        vals, idx = torch.topk(logits, k=ApplyToppTopkScalar._padded_k, \
-                    dim=1, sorted=True)
+        vals, idx = torch.topk(logits,
+                               k=ApplyToppTopkScalar._padded_k,
+                               dim=1,
+                               sorted=True)

         # this "if" checks if we have bucketed so much that
         # we have padded k upto shape of logits
-        if ApplyToppTopkScalar._padded_k != logits.shape[1]:
+        if self._handle_duplicates and \
+            ApplyToppTopkScalar._padded_k != logits.shape[1]:
             smallest_of_top_k = vals[:, k - 1]
             num_duplicates_of_smallest_of_topk = torch.sum(
                 logits == smallest_of_top_k.unsqueeze(1), 1)
@@ -424,9 +429,10 @@ def __call__(self, logits: torch.Tensor, p: float, k: int):
                     ApplyToppTopkScalar._padded_k + incr, logits.shape[1])

                 # recompute topk with expanded padded_k
-                vals, idx = torch.topk(logits, \
-                            k=ApplyToppTopkScalar._padded_k, \
-                            dim=1, sorted=True)
+                vals, idx = torch.topk(logits,
+                                       k=ApplyToppTopkScalar._padded_k,
+                                       dim=1,
+                                       sorted=True)

         idx = torch.fliplr(idx)
         vals = torch.fliplr(vals)
@@ -438,7 +444,7 @@ def __call__(self, logits: torch.Tensor, p: float, k: int):

         probs_sort = vals.softmax(dim=-1)
         probs_sum = probs_sort.cumsum(dim=-1)
-        top_p_mask = probs_sum <= (1 - p)
+        top_p_mask = probs_sum <= (1 - p.unsqueeze(dim=1))
         top_p_mask[:, -1] = False
         vals.masked_fill_(top_p_mask, -float("inf"))

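For context, a minimal sketch (not part of the commit) of how the new VLLM_HANDLE_TOPK_DUPLICATES toggle behaves. The parsing below mirrors the os.getenv call added above; because _handle_duplicates is a class attribute, the variable must be set before the sampler module is imported. Everything beyond the variable name itself is illustrative.

# Minimal sketch, assuming only what the diff above shows: "1"/"true"
# (case-insensitive) keep duplicate handling on (the default); any other
# value skips the duplicate check, so logits tied with the k-th largest
# value that fall outside the padded top-k window may be dropped.
import os

os.environ["VLLM_HANDLE_TOPK_DUPLICATES"] = "false"  # set before importing vllm

handle_duplicates = os.getenv("VLLM_HANDLE_TOPK_DUPLICATES",
                              "1").lower() in ["1", "true"]
print(handle_duplicates)  # False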
8 changes: 6 additions & 2 deletions vllm/model_executor/sampling_metadata.py
@@ -389,7 +389,7 @@ def from_sampling_metadata(
         vocab_size: int,
         device: torch.device,
         dtype: torch.dtype,
-    ) -> Tuple["SamplingTensors", bool, bool, bool]:
+    ) -> Tuple["SamplingTensors", bool, bool, bool, Optional[int]]:
         prompt_tokens: List[array] = []
         output_tokens: List[array] = []
         top_ks: List[int] = []
@@ -476,6 +476,9 @@ def from_sampling_metadata(
                         prompt_tokens.append(seq_data.prompt_token_ids_array)
                         output_tokens.append(seq_data.output_token_ids_array)

+        top_k_scalar = top_ks[0] if do_top_p_top_k and all(
+            k == top_ks[0] for k in top_ks) else None
+
         sampling_tensors = SamplingTensors.from_lists(
             temperatures,
             top_ps,
@@ -490,7 +493,8 @@ def from_lists(
             device,
             dtype,
         )
-        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
+        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p,
+                top_k_scalar)

     @classmethod
     def from_lists(
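To illustrate the new return value, a hedged sketch of the scalar-k detection added above; detect_top_k_scalar is a hypothetical helper that simply factors out the expression from the diff.

# Hypothetical helper (illustration only) wrapping the expression added to
# from_sampling_metadata: if top-k/top-p is active and every sequence in the
# batch requests the same top-k, that k is returned so Sampler.forward can
# take the ApplyToppTopkScalar fast path; otherwise None selects the generic
# _apply_top_k_top_p path.
from typing import List, Optional

def detect_top_k_scalar(top_ks: List[int], do_top_p_top_k: bool) -> Optional[int]:
    return top_ks[0] if do_top_p_top_k and all(
        k == top_ks[0] for k in top_ks) else None

print(detect_top_k_scalar([40, 40, 40], True))  # 40 -> optimized scalar path
print(detect_top_k_scalar([40, 20, 40], True))  # None -> generic top-k/top-p kernel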
