Skip to content

Commit

Permalink
Add progress bar to WhisperModel.transcribe (#1138)
Browse files Browse the repository at this point in the history
  • Loading branch information
MahmoudAshraf97 authored Nov 14, 2024
1 parent 3e0ba86 commit 85e61ea
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion faster_whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,7 @@ def transcribe(
audio: Union[str, BinaryIO, np.ndarray],
language: Optional[str] = None,
task: str = "transcribe",
log_progress: bool = False,
beam_size: int = 5,
best_of: int = 5,
patience: float = 1,
Expand Down Expand Up @@ -695,6 +696,7 @@ def transcribe(
as "en" or "fr". If not set, the language will be detected in the first 30 seconds
of audio.
task: Task to execute (transcribe or translate).
log_progress: Whether to show a progress bar during transcription.
beam_size: Beam size to use for decoding.
best_of: Number of candidates when sampling with non-zero temperature.
patience: Beam search patience factor.
Expand Down Expand Up @@ -941,7 +943,9 @@ def transcribe(
hotwords=hotwords,
)

segments = self.generate_segments(features, tokenizer, options, encoder_output)
segments = self.generate_segments(
features, tokenizer, options, log_progress, encoder_output
)

if speech_chunks:
segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
Expand Down Expand Up @@ -1041,6 +1045,7 @@ def generate_segments(
features: np.ndarray,
tokenizer: Tokenizer,
options: TranscriptionOptions,
log_progress,
encoder_output: Optional[ctranslate2.StorageView] = None,
) -> Iterable[Segment]:
content_frames = features.shape[-1] - 1
Expand Down Expand Up @@ -1083,6 +1088,7 @@ def generate_segments(
else:
all_tokens.extend(options.initial_prompt)

pbar = tqdm(total=content_duration, unit="seconds", disable=not log_progress)
last_speech_timestamp = 0.0
# NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
Expand Down Expand Up @@ -1341,6 +1347,12 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:

prompt_reset_since = len(all_tokens)

pbar.update(
(min(content_frames, seek) - previous_seek)
* self.feature_extractor.time_per_frame,
)
pbar.close()

def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
# When the model is running on multiple GPUs, the encoder output should be moved
# to the CPU since we don't know which GPU will handle the next job.
Expand Down

0 comments on commit 85e61ea

Please sign in to comment.