Commit

Use top_p as a scalar
kdamaszk committed Nov 6, 2024
1 parent 44c82dc commit 6981936
Showing 2 changed files with 12 additions and 8 deletions.
13 changes: 7 additions & 6 deletions vllm/model_executor/layers/sampler.py
@@ -194,14 +194,15 @@ def _init_sampling_tensors(

         # Initialize new sampling tensors
         (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p,
-         top_k_scalar) = SamplingTensors.from_sampling_metadata(
+         top_k_scalar, top_p_scalar) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype)

         self._sampling_tensors = sampling_tensors
         self._do_penalties = do_penalties
         self._do_top_p_top_k = do_top_p_top_k
         self._do_min_p = do_min_p
         self._top_k_scalar = top_k_scalar
+        self._top_p_scalar = top_p_scalar

         self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5)

@@ -263,10 +264,10 @@ def forward(
         logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

         if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
-            # If we have a scalar k, we can use the optimized version.
-            if self._top_k_scalar is not None:
+            # If we have a scalar p and k, we can use the optimized version.
+            if self._top_k_scalar and self._top_p_scalar:
                 logits = self._apply_top_k_top_p_opt(logits,
-                                                     sampling_tensors.top_ps,
+                                                     self._top_p_scalar,
                                                      self._top_k_scalar)
             else:
                 logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
@@ -391,7 +392,7 @@ class ApplyToppTopkScalar():
     def __init__(self, increment: int):
         self._increment = increment

-    def __call__(self, logits: torch.Tensor, p: torch.Tensor, k: int):
+    def __call__(self, logits: torch.Tensor, p: float, k: int):
         if k > ApplyToppTopkScalar._padded_k:
             ApplyToppTopkScalar._padded_k = min(k + self._increment,
                                                 logits.shape[1])
@@ -444,7 +445,7 @@ def __call__(self, logits: torch.Tensor, p: torch.Tensor, k: int):

         probs_sort = vals.softmax(dim=-1)
         probs_sum = probs_sort.cumsum(dim=-1)
-        top_p_mask = probs_sum <= (1 - p.unsqueeze(dim=1))
+        top_p_mask = probs_sum <= (1 - p)
         top_p_mask[:, -1] = False
         vals.masked_fill_(top_p_mask, -float("inf"))
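With p now a plain float, the top-p cutoff inside ApplyToppTopkScalar becomes a scalar comparison instead of a per-row tensor broadcast. The sketch below is a simplified, hypothetical rendering of that scalar path, not the class from this commit: the helper name apply_top_k_top_p_scalar is invented here, the real __call__ additionally pads and caches k and handles duplicates of the k-th value, and 2-D logits of shape [num_seqs, vocab_size] are assumed.

import torch

def apply_top_k_top_p_scalar(logits: torch.Tensor, p: float, k: int) -> torch.Tensor:
    # Keep the k largest logits per row; k is one Python int shared by all rows.
    vals, idx = torch.topk(logits, k, dim=1)
    # Flip to ascending order so the cumulative sum runs from the least to the
    # most probable of the kept tokens, matching vLLM's top-p masking convention.
    vals = torch.flip(vals, dims=[1])
    idx = torch.flip(idx, dims=[1])

    probs_sort = vals.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    # Scalar p: a single float comparison, no p.unsqueeze(dim=1) broadcast needed.
    top_p_mask = probs_sum <= (1 - p)
    top_p_mask[:, -1] = False  # always keep the most probable token
    vals = vals.masked_fill(top_p_mask, -float("inf"))

    # Scatter the surviving logits back into a full-vocabulary tensor.
    out = torch.full_like(logits, -float("inf"))
    out.scatter_(1, idx, vals)
    return out

For example, a batch in which every request uses top_p=0.9 and top_k=50 needs only one topk call and one float comparison for the whole batch.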
7 changes: 5 additions & 2 deletions vllm/model_executor/sampling_metadata.py
@@ -389,7 +389,8 @@ def from_sampling_metadata(
         vocab_size: int,
         device: torch.device,
         dtype: torch.dtype,
-    ) -> Tuple["SamplingTensors", bool, bool, bool, Optional[int]]:
+    ) -> Tuple["SamplingTensors", bool, bool, bool, Optional[int],
+               Optional[float]]:
         prompt_tokens: List[array] = []
         output_tokens: List[array] = []
         top_ks: List[int] = []
@@ -478,6 +479,8 @@ def from_sampling_metadata(

         top_k_scalar = top_ks[0] if do_top_p_top_k and all(
             k == top_ks[0] for k in top_ks) else None
+        top_p_scalar = top_ps[0] if do_top_p_top_k and all(
+            p == top_ps[0] for p in top_ps) else None

         sampling_tensors = SamplingTensors.from_lists(
             temperatures,
@@ -494,7 +497,7 @@ def from_sampling_metadata(
             dtype,
         )
         return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p,
-                top_k_scalar)
+                top_k_scalar, top_p_scalar)

     @classmethod
     def from_lists(
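The new tuple slot follows the same pattern already used for top_k_scalar: the per-request top_p list collapses to a scalar only when every sequence in the batch agrees on the value. A standalone illustration of that collapse logic follows; the helper name collapse_to_scalar is hypothetical and not part of vLLM.

from typing import List, Optional

def collapse_to_scalar(values: List[float]) -> Optional[float]:
    # Return the shared value when the whole batch uses the same setting,
    # otherwise None so the sampler falls back to the per-row tensor path.
    return values[0] if values and all(v == values[0] for v in values) else None

# collapse_to_scalar([0.9, 0.9, 0.9]) -> 0.9   (optimized scalar path)
# collapse_to_scalar([0.9, 0.8, 0.9]) -> None  (tensor fallback path)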
