diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index d3d2bdf7..bf091981 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -57,7 +57,7 @@ class Segment:
     compression_ratio: float
     no_speech_prob: float
     words: Optional[List[Word]]
-    temperature: Optional[float] = 1.0
+    temperature: Optional[float]
 
     def _asdict(self):
         warn(
@@ -68,7 +68,6 @@ def _asdict(self):
         return asdict(self)
 
 
-# Added additional parameters for multilingual videos and fixes below
 @dataclass
 class TranscriptionOptions:
     beam_size: int
@@ -112,34 +111,17 @@ class TranscriptionInfo:
     vad_options: VadOptions
 
 
-# The code below is originally from HF pipeline and is used in whisper-x
-# (https://github.com/m-bain/whisperX) and adapted for faster_whisper
-
-
 class BatchedInferencePipeline:
-    """
-    Huggingface Pipeline wrapper for WhisperModel.
-    Copyright (c) 2022, Max Bain
-    All rights reserved.
-    Modified by Mobius Labs GmbH
-    """
-
     def __init__(
         self,
         model,
-        options: Optional[TranscriptionOptions] = None,
-        tokenizer=None,
-        language: Optional[str] = None,
     ):
         self.model: WhisperModel = model
-        self.tokenizer = tokenizer
-        self.options = options
-        self.preset_language = language
         self.last_speech_timestamp = 0.0
 
-    def forward(self, features, chunks_metadata, **forward_params):
-        encoder_output, outputs = self.model.generate_segment_batched(
-            features, self.tokenizer, forward_params
+    def forward(self, features, tokenizer, chunks_metadata, options):
+        encoder_output, outputs = self.generate_segment_batched(
+            features, tokenizer, options
         )
 
         segmented_outputs = []
@@ -153,7 +135,7 @@ def forward(self, features, chunks_metadata, **forward_params):
                     seek,
                     single_timestamp_ending,
                 ) = self.model._split_segments_by_timestamps(
-                    tokenizer=self.tokenizer,
+                    tokenizer=tokenizer,
                     tokens=output["tokens"],
                     time_offset=chunk_metadata["start_time"],
                     segment_size=segment_size,
@@ -163,14 +145,14 @@ def forward(self, features, chunks_metadata, **forward_params):
                 segmented_outputs.append(
                     [
                         dict(
-                            text=self.tokenizer.decode(subsegment["tokens"]),
+                            text=tokenizer.decode(subsegment["tokens"]),
                             avg_logprob=output["avg_logprob"],
                             no_speech_prob=output["no_speech_prob"],
                             tokens=subsegment["tokens"],
                             start=subsegment["start"],
                             end=subsegment["end"],
                             compression_ratio=get_compression_ratio(
-                                self.tokenizer.decode(subsegment["tokens"])
+                                tokenizer.decode(subsegment["tokens"])
                             ),
                             seek=int(
                                 chunk_metadata["start_time"] * self.model.frames_per_second
@@ -179,19 +161,88 @@ def forward(self, features, chunks_metadata, **forward_params):
                         for subsegment in subsegments
                     ]
                 )
-        if forward_params["word_timestamps"]:
+        if options.word_timestamps:
             self.last_speech_timestamp = self.model.add_word_timestamps(
                 segmented_outputs,
-                self.tokenizer,
+                tokenizer,
                 encoder_output,
                 segment_sizes,
-                forward_params["prepend_punctuations"],
-                forward_params["append_punctuations"],
+                options.prepend_punctuations,
+                options.append_punctuations,
                 self.last_speech_timestamp,
             )
         return segmented_outputs
 
+    def generate_segment_batched(
+        self,
+        features: np.ndarray,
+        tokenizer: Tokenizer,
+        options: TranscriptionOptions,
+    ):
+        batch_size = features.shape[0]
+
+        prompt = self.model.get_prompt(
+            tokenizer,
+            previous_tokens=(
+                tokenizer.encode(options.initial_prompt)
+                if options.initial_prompt is not None
+                else []
+            ),
+            without_timestamps=options.without_timestamps,
+            hotwords=options.hotwords,
+        )
+
+        if options.max_new_tokens is not None:
+            max_length = len(prompt) + options.max_new_tokens
+        else:
+            max_length = self.model.max_length
+
+        if max_length > self.model.max_length:
+            raise ValueError(
+                f"The length of the prompt is {len(prompt)}, and the `max_new_tokens` "
+                f"{max_length - len(prompt)}. Thus, the combined length of the prompt "
+                f"and `max_new_tokens` is: {max_length}. This exceeds the "
+                f"`max_length` of the Whisper model: {self.model.max_length}. "
+                "You should either reduce the length of your prompt, or "
+                "reduce the value of `max_new_tokens`, "
+                f"so that their combined length is less than {self.model.max_length}."
+            )
+
+        encoder_output = self.model.encode(features)
+
+        results = self.model.model.generate(
+            encoder_output,
+            [prompt] * batch_size,
+            beam_size=options.beam_size,
+            patience=options.patience,
+            length_penalty=options.length_penalty,
+            max_length=max_length,
+            suppress_blank=options.suppress_blank,
+            suppress_tokens=options.suppress_tokens,
+            return_scores=True,
+            return_no_speech_prob=True,
+            sampling_temperature=options.temperatures[0],
+            repetition_penalty=options.repetition_penalty,
+            no_repeat_ngram_size=options.no_repeat_ngram_size,
+        )
+
+        output = []
+        for result in results:
+            # CTranslate2 returns a length-normalized score, so undo the
+            # normalization before computing the average log probability.
+            seq_len = len(result.sequences_ids[0])
+            cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
+
+            output.append(
+                dict(
+                    avg_logprob=cum_logprob / (seq_len + 1),
+                    no_speech_prob=result.no_speech_prob,
+                    tokens=result.sequences_ids[0],
+                )
+            )
+
+        return encoder_output, output
+
     def transcribe(
         self,
         audio: Union[str, BinaryIO, np.ndarray],
@@ -216,20 +267,26 @@ def transcribe(
         log_prob_threshold: Optional[float] = -1.0,
         log_prob_low_threshold: Optional[float] = None,
         no_speech_threshold: Optional[float] = 0.6,
+        condition_on_previous_text: bool = True,
+        prompt_reset_on_temperature: float = 0.5,
         initial_prompt: Optional[Union[str, Iterable[int]]] = None,
         prefix: Optional[str] = None,
         suppress_blank: bool = True,
         suppress_tokens: Optional[List[int]] = [-1],
         without_timestamps: bool = True,
+        max_initial_timestamp: float = 1.0,
         word_timestamps: bool = False,
         prepend_punctuations: str = "\"'“¿([{-",
         append_punctuations: str = "\"'.。,,!!??::”)]}、",
+        multilingual: bool = False,
+        output_language: Optional[str] = None,
         vad_filter: bool = True,
         vad_parameters: Optional[Union[dict, VadOptions]] = None,
         max_new_tokens: Optional[int] = None,
         chunk_length: Optional[int] = None,
         clip_timestamps: Optional[List[dict]] = None,
-        batch_size: int = 16,
+        hallucination_silence_threshold: Optional[float] = None,
+        batch_size: int = 8,
         hotwords: Optional[str] = None,
         language_detection_threshold: Optional[float] = 0.5,
         language_detection_segments: int = 1,
@@ -250,22 +307,10 @@ def transcribe(
             repetition_penalty: Penalty applied to the score of previously generated tokens
                 (set > 1 to penalize).
             no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
-            temperature: Temperature for sampling. It can be a tuple of temperatures,
-                which will be successively used upon failures according to either
-                `compression_ratio_threshold` or `log_prob_threshold`.
-            compression_ratio_threshold: If the gzip compression ratio is above this value,
-                treat as failed.
-            log_prob_threshold: If the average log probability over sampled tokens is
-                below this value, treat as failed.
-            log_prob_low_threshold: This parameter alone is sufficient to skip an output text,
-                whereas log_prob_threshold also looks for appropriate no_speech_threshold value.
-                This value should be less than log_prob_threshold.
-            no_speech_threshold: If the no_speech probability is higher than this value AND
-                the average log probability over sampled tokens is below `log_prob_threshold`,
-                consider the segment as silent.
+            temperature: Temperature for sampling. If a list or tuple is passed,
+                only the first value is used.
             initial_prompt: Optional text string or iterable of token ids to provide as a
-                prompt for the first window.
-            prefix: Optional text to provide as a prefix for the first window.
+                prompt for each window.
             suppress_blank: Suppress blank outputs at the beginning of the sampling.
             suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
                 of symbols as defined in `tokenizer.non_speech_tokens()`.
@@ -296,29 +341,32 @@ def transcribe(
                 higher than this value, the language is detected.
             language_detection_segments: Number of segments to consider for the language
                 detection.
-            Static params: (Fixed for batched version)
-            max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
-            multilingual: If True, perform transcription on multilingual videos. Set as False.
-            output_language: Valid only if multilingual is set to True.
-                Specifies the string representing the output language. One of
-                'en' (English) or 'hybrid' (code-switched transcription). set as None.
+            Unused Arguments:
+            compression_ratio_threshold: If the gzip compression ratio is above this value,
+                treat as failed.
+            log_prob_threshold: If the average log probability over sampled tokens is
+                below this value, treat as failed.
+            log_prob_low_threshold: This parameter alone is sufficient to skip an output text,
+                whereas log_prob_threshold also looks for an appropriate no_speech_threshold
+                value. This value should be less than log_prob_threshold.
+            no_speech_threshold: If the no_speech probability is higher than this value AND
+                the average log probability over sampled tokens is below `log_prob_threshold`,
+                consider the segment as silent.
             condition_on_previous_text: If True, the previous output of the model is provided
                 as a prompt for the next window; disabling may make the text inconsistent across
                 windows, but the model becomes less prone to getting stuck in a failure loop,
                 such as repetition looping or timestamps going out of sync. Set as False
             prompt_reset_on_temperature: Resets prompt if temperature is above this value.
                 Arg has effect only if condition_on_previous_text is True. Set at 0.5
-            #TODO: support "hallucination_silence_threshold" when "word_timestamps=True"
+            prefix: Optional text to provide as a prefix at the beginning of each window.
+            max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
+            multilingual: If True, perform transcription on multilingual videos. Set as False.
+            output_language: Valid only if multilingual is set to True.
+                Specifies the string representing the output language. One of
+                'en' (English) or 'hybrid' (code-switched transcription). Set as None.
+            hallucination_silence_threshold: Optional[float]
+                When word_timestamps is True, skip silent periods longer than this threshold
+                (in seconds) when a possible hallucination is detected. Set as None.
-
-            unused:
-            language_detection_threshold: If the maximum probability of the language tokens is
-                higher than this value, the language is detected.
-            language_detection_segments: Number of segments to consider for the language
-                detection.
-
-
 
         Returns:
           A tuple with:
 
@@ -410,7 +458,7 @@ def transcribe(
                 language_probability = 1
 
-        self.tokenizer = Tokenizer(
+        tokenizer = Tokenizer(
             self.model.hf_tokenizer,
             self.model.model.is_multilingual,
             task=task,
@@ -421,8 +469,7 @@ def transcribe(
             np.stack([pad_or_trim(feature) for feature in features]) if features else []
         )
 
-        # batched options: see the difference with default options in WhisperModel
-        batched_options = TranscriptionOptions(
+        options = TranscriptionOptions(
             beam_size=beam_size,
             best_of=best_of,
             patience=patience,
@@ -434,12 +481,14 @@ def transcribe(
             no_speech_threshold=no_speech_threshold,
             compression_ratio_threshold=compression_ratio_threshold,
             temperatures=(
-                temperature if isinstance(temperature, (list, tuple)) else [temperature]
+                temperature[:1]
+                if isinstance(temperature, (list, tuple))
+                else [temperature]
             ),
             initial_prompt=initial_prompt,
             prefix=prefix,
             suppress_blank=suppress_blank,
-            suppress_tokens=get_suppressed_tokens(self.tokenizer, suppress_tokens),
+            suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
             prepend_punctuations=prepend_punctuations,
             append_punctuations=append_punctuations,
             max_new_tokens=max_new_tokens,
@@ -447,7 +496,7 @@ def transcribe(
             word_timestamps=word_timestamps,
             hallucination_silence_threshold=None,
             condition_on_previous_text=False,
-            clip_timestamps="0",
+            clip_timestamps=clip_timestamps,
             prompt_reset_on_temperature=0.5,
             multilingual=False,
             output_language=None,
@@ -460,31 +509,33 @@ def transcribe(
             language_probability=language_probability,
             duration=duration,
             duration_after_vad=duration_after_vad,
-            transcription_options=batched_options,
-            vad_options=None,
+            transcription_options=options,
+            vad_options=vad_parameters,
             all_language_probs=all_language_probs,
         )
 
         segments = self._batched_segments_generator(
             features,
+            tokenizer,
             chunks_metadata,
             batch_size,
-            batched_options,
+            options,
             log_progress,
         )
 
         return segments, info
 
     def _batched_segments_generator(
-        self, features, chunks_metadata, batch_size, options, log_progress
+        self, features, tokenizer, chunks_metadata, batch_size, options, log_progress
    ):
         pbar = tqdm(total=len(features), disable=not log_progress, position=0)
         seg_idx = 0
         for i in range(0, len(features), batch_size):
             results = self.forward(
                 features[i : i + batch_size],
+                tokenizer,
                 chunks_metadata[i : i + batch_size],
-                **asdict(options),
+                options,
             )
 
             for result in results:
@@ -505,6 +556,7 @@ def _batched_segments_generator(
                         avg_logprob=segment["avg_logprob"],
                         no_speech_prob=segment["no_speech_prob"],
                         compression_ratio=segment["compression_ratio"],
+                        temperature=options.temperatures[0],
                     )
 
             pbar.update(1)
@@ -1689,57 +1741,6 @@ def find_alignment(
             )
         return return_list
 
-    def generate_segment_batched(
-        self,
-        features: np.ndarray,
-        tokenizer: Tokenizer,
-        options: dict,
-    ):
-        batch_size = features.shape[0]
-        all_tokens = []
-        prompt_reset_since = 0
-
-        if options["initial_prompt"] is not None:
-            initial_prompt = " " + options["initial_prompt"].strip()
-            initial_prompt_tokens = tokenizer.encode(initial_prompt)
-            all_tokens.extend(initial_prompt_tokens)
-        previous_tokens = all_tokens[prompt_reset_since:]
-        prompt = self.get_prompt(
-            tokenizer,
-            previous_tokens,
-            without_timestamps=options["without_timestamps"],
-            prefix=options["prefix"],
-        )
-
-        encoder_output = self.encode(features)
-
-        result = self.model.generate(
-            encoder_output,
-            [prompt] * batch_size,
-            beam_size=options["beam_size"],
-            patience=options["patience"],
-            length_penalty=options["length_penalty"],
-            max_length=self.max_length,
-            suppress_blank=options["suppress_blank"],
-            suppress_tokens=options["suppress_tokens"],
-            return_scores=True,
-            return_no_speech_prob=True,
-        )
-
-        output = []
-        for res in result:
-            output.append({})
-            # return scores
-            seq_len = len(res.sequences_ids[0])
-            cum_logprob = res.scores[0] * (seq_len ** options["length_penalty"])
-            output[-1]["avg_logprob"] = cum_logprob / (seq_len + 1)
-
-            # return no speech prob
-            output[-1]["no_speech_prob"] = res.no_speech_prob
-            output[-1]["tokens"] = res.sequences_ids[0]
-
-        return encoder_output, output
-
     def detect_language(
         self,
         audio: Optional[np.ndarray] = None,
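
A note on the score bookkeeping in the new generate_segment_batched (the same arithmetic
appears in the removed WhisperModel method above): CTranslate2 reports a length-normalized
score, so the code multiplies by seq_len ** length_penalty to recover the cumulative log
probability before averaging. A minimal sketch of that calculation, assuming the
normalization just described (the "+ 1" appears to account for the end-of-text token,
which is not included in sequences_ids):

    def average_logprob(score: float, seq_len: int, length_penalty: float) -> float:
        # Undo the normalization: score = cum_logprob / seq_len**length_penalty.
        cum_logprob = score * (seq_len**length_penalty)
        # Average over all produced tokens, counting the end-of-text token.
        return cum_logprob / (seq_len + 1)

    # Example: score = -0.25 over 20 tokens with length_penalty = 1.0 gives
    # cum_logprob = -5.0 and avg_logprob = -5.0 / 21 ≈ -0.238.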
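
Usage sketch (not part of the diff): after this refactor, BatchedInferencePipeline no
longer stores a tokenizer, options, or preset language; it wraps a WhisperModel and
receives everything else through transcribe(). A minimal example of the new API, assuming
BatchedInferencePipeline is exported from the package root and that "audio.wav" and the
model size are placeholders:

    from faster_whisper import BatchedInferencePipeline, WhisperModel

    # The pipeline is now constructed from the model alone.
    model = WhisperModel("large-v3", device="cuda", compute_type="float16")
    pipeline = BatchedInferencePipeline(model)

    # batch_size defaults to 8 after this change (previously 16); if a list of
    # temperatures is passed, only the first value is used in the batched path.
    segments, info = pipeline.transcribe("audio.wav", batch_size=8, word_timestamps=True)

    for segment in segments:
        print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")

Each yielded Segment now also records the temperature that produced it, which is why the
default value was dropped from the temperature field on the Segment dataclass.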