From 0c56d23ecea55ab328ce41cd70d9d50960562436 Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Tue, 24 Dec 2024 09:34:06 -0800 Subject: [PATCH] tpu: fix outputs by correcting the next_token_ids shape (#986) --- aphrodite/worker/tpu_model_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/aphrodite/worker/tpu_model_runner.py b/aphrodite/worker/tpu_model_runner.py index fdc706d18..d9d988269 100644 --- a/aphrodite/worker/tpu_model_runner.py +++ b/aphrodite/worker/tpu_model_runner.py @@ -602,7 +602,7 @@ def _execute_model(*args): batch_idx += 1 else: for seq_id in seq_ids: - next_token_id = next_token_ids[batch_idx][0] + next_token_id = next_token_ids[batch_idx] seq_outputs.append( SequenceOutput(seq_id, next_token_id, {next_token_id: zero_logprob})) @@ -723,6 +723,9 @@ def forward( sampled_token_ids = torch.multinomial(probs, num_samples, replacement=True) + if num_samples == 1: + argmax_token_ids = argmax_token_ids.squeeze(dim=-1) + sampled_token_ids = sampled_token_ids.squeeze(dim=-1) next_token_ids = torch.where(t != 0, sampled_token_ids, argmax_token_ids) return next_token_ids