diff --git a/faster_whisper/audio.py b/faster_whisper/audio.py
index 1f1970aa..1cebb39f 100644
--- a/faster_whisper/audio.py
+++ b/faster_whisper/audio.py
@@ -107,26 +107,3 @@ def _resample_frames(frames, resampler):
     # Add None to flush the resampler.
     for frame in itertools.chain(frames, [None]):
         yield from resampler.resample(frame)
-
-
-def pad_or_trim(array, length: int, *, axis: int = -1):
-    """
-    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
-    """
-    axis = axis % array.ndim
-    if array.shape[axis] > length:
-        idx = [Ellipsis] * axis + [slice(length)] + [Ellipsis] * (array.ndim - axis - 1)
-        return array[idx]
-
-    if array.shape[axis] < length:
-        pad_widths = (
-            [
-                0,
-            ]
-            * array.ndim
-            * 2
-        )
-        pad_widths[2 * axis] = length - array.shape[axis]
-        array = torch.nn.functional.pad(array, tuple(pad_widths[::-1]))
-
-    return array
diff --git a/faster_whisper/feature_extractor.py b/faster_whisper/feature_extractor.py
index 6371d5ef..7cc0e200 100644
--- a/faster_whisper/feature_extractor.py
+++ b/faster_whisper/feature_extractor.py
@@ -20,7 +20,7 @@ def __init__(
         self.hop_length = hop_length
         self.chunk_length = chunk_length
         self.n_samples = chunk_length * sampling_rate
-        self.nb_max_frames = self.n_samples // hop_length
+        self.nb_max_frames = (30 * sampling_rate) // hop_length
         self.time_per_frame = hop_length / sampling_rate
         self.sampling_rate = sampling_rate
         self.mel_filters = self.get_mel_filters(
@@ -82,7 +82,6 @@ def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False):
 
         if chunk_length is not None:
             self.n_samples = chunk_length * self.sampling_rate
-            self.nb_max_frames = self.n_samples // self.hop_length
 
         if waveform.dtype is not torch.float32:
             waveform = waveform.to(torch.float32)
diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index 7611237a..4efb3312 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -17,7 +17,7 @@
 
 from tqdm import tqdm
 
-from faster_whisper.audio import decode_audio, pad_or_trim
+from faster_whisper.audio import decode_audio
 from faster_whisper.feature_extractor import FeatureExtractor
 from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
 from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
@@ -239,7 +239,7 @@ def transcribe(
         vad_filter: bool = True,
         vad_parameters: Optional[Union[dict, VadOptions]] = None,
         max_new_tokens: Optional[int] = None,
-        chunk_length: Optional[int] = None,
+        chunk_length: Optional[int] = 30,
         clip_timestamps: Optional[List[dict]] = None,
         batch_size: int = 16,
         hotwords: Optional[str] = None,
@@ -663,7 +663,7 @@ def transcribe(
         vad_filter: bool = False,
         vad_parameters: Optional[Union[dict, VadOptions]] = None,
         max_new_tokens: Optional[int] = None,
-        chunk_length: Optional[int] = None,
+        chunk_length: Optional[int] = 30,
         clip_timestamps: Union[str, List[float]] = "0",
         hallucination_silence_threshold: Optional[float] = None,
         hotwords: Optional[str] = None,
@@ -753,6 +753,7 @@ def transcribe(
         """
 
         sampling_rate = self.feature_extractor.sampling_rate
+        self.feature_extractor.n_samples = chunk_length * sampling_rate
 
         if isinstance(audio, np.ndarray):
             audio = torch.from_numpy(audio)
@@ -797,11 +798,6 @@ def transcribe(
         else:
             speech_chunks = None
 
-        to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
-        features = self.feature_extractor(
-            audio, chunk_length=chunk_length, to_cpu=to_cpu
-        )
-
         encoder_output = None
         all_language_probs = None
 
@@ -828,12 +824,10 @@ def transcribe(
                     if isinstance(clip_timestamps, str)
                     else clip_timestamps[0]
                 )
-                content_frames = (
-                    features.shape[-1] - self.feature_extractor.nb_max_frames
-                )
+                content_frames = audio.shape[0]
                 seek = (
-                    int(start_timestamp * self.frames_per_second)
-                    if start_timestamp * self.frames_per_second < content_frames
+                    int(start_timestamp * sampling_rate)
+                    if start_timestamp * sampling_rate < content_frames
                     else 0
                 )
                 end_frames = min(
@@ -844,9 +838,9 @@ def transcribe(
                 )
                 detected_language_info = {}
                 while seek <= end_frames:
-                    segment = features[
-                        :, seek : seek + self.feature_extractor.nb_max_frames
-                    ]
+                    segment = self.feature_extractor(
+                        audio[seek : seek + self.feature_extractor.n_samples]
+                    )[:, : self.feature_extractor.nb_max_frames]
                     encoder_output = self.encode(segment)
                     # results is a list of tuple[str, float] with language names and
                     # probabilities.
@@ -865,7 +859,7 @@ def transcribe(
                     detected_language_info.setdefault(language, []).append(
                         language_probability
                     )
-                    seek += segment.shape[-1]
+                    seek += self.feature_extractor.n_samples
             else:
                 # If no language detected for all segments, the majority vote of the highest
                 # projected languages for all segments is used to determine the language.
@@ -934,7 +928,9 @@ def transcribe(
             hotwords=hotwords,
         )
 
-        segments = self.generate_segments(features, tokenizer, options, encoder_output)
+        segments = self.generate_segments(
+            audio, chunk_length, tokenizer, options, encoder_output
+        )
 
         if speech_chunks:
             segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
@@ -1005,7 +1001,11 @@ def _split_segments_by_timestamps(
                 last_timestamp_position = (
                     tokens[last_slice - 1] - tokenizer.timestamp_begin
                 )
-                seek += last_timestamp_position * self.input_stride
+                seek += (
+                    last_timestamp_position
+                    * self.input_stride
+                    * self.feature_extractor.hop_length
+                )
 
         else:
             duration = segment_duration
@@ -1031,13 +1031,14 @@ def _split_segments_by_timestamps(
 
     def generate_segments(
         self,
-        features: torch.Tensor,
+        audio: torch.Tensor,
+        chunk_length: float,
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
         encoder_output: Optional[ctranslate2.StorageView] = None,
     ) -> Iterable[Segment]:
-        content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
-        content_duration = float(content_frames * self.feature_extractor.time_per_frame)
+        content_frames = audio.shape[0]
+        content_duration = audio.shape[0] / self.feature_extractor.sampling_rate
 
         if isinstance(options.clip_timestamps, str):
             options = options._replace(
@@ -1051,7 +1052,8 @@ def generate_segments(
                 ]
             )
         seek_points: List[int] = [
-            round(ts * self.frames_per_second) for ts in options.clip_timestamps
+            round(ts * self.feature_extractor.sampling_rate)
+            for ts in options.clip_timestamps
         ]
         if len(seek_points) == 0:
             seek_points.append(0)
@@ -1093,19 +1095,20 @@ def generate_segments(
                 if clip_idx < len(seek_clips):
                     seek = seek_clips[clip_idx][0]
                 continue
-            time_offset = seek * self.feature_extractor.time_per_frame
-            window_end_time = float(
-                (seek + self.feature_extractor.nb_max_frames)
-                * self.feature_extractor.time_per_frame
-            )
+            time_offset = seek / self.feature_extractor.sampling_rate
+            window_end_time = (
+                seek + self.feature_extractor.n_samples
+            ) / self.feature_extractor.sampling_rate
+
             segment_size = min(
-                self.feature_extractor.nb_max_frames,
+                self.feature_extractor.n_samples,
                 content_frames - seek,
                 seek_clip_end - seek,
            )
-            segment = features[:, seek : seek + segment_size]
-            segment_duration = segment_size * self.feature_extractor.time_per_frame
-            segment = pad_or_trim(segment, self.feature_extractor.nb_max_frames)
+            segment = self.feature_extractor(
+                audio[seek : seek + segment_size], chunk_length=chunk_length
+            )[:, : self.feature_extractor.nb_max_frames]
+            segment_duration = segment_size / self.feature_extractor.sampling_rate
 
             if self.logger.isEnabledFor(logging.DEBUG):
                 self.logger.debug(
@@ -1233,7 +1236,7 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
                     [current_segments],
                     tokenizer,
                     encoder_output,
-                    segment_size,
+                    segment_size // self.feature_extractor.hop_length,
                     options.prepend_punctuations,
                     options.append_punctuations,
                     last_speech_timestamp=last_speech_timestamp,
@@ -1241,7 +1244,9 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
                 if not single_timestamp_ending:
                     last_word_end = get_end(current_segments)
                     if last_word_end is not None and last_word_end > time_offset:
-                        seek = round(last_word_end * self.frames_per_second)
+                        seek = round(
+                            last_word_end * self.feature_extractor.sampling_rate
+                        )
 
                 # skip silence before possible hallucinations
                 if options.hallucination_silence_threshold is not None:
@@ -1252,7 +1257,9 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
                     if first_segment is not None and is_segment_anomaly(first_segment):
                         gap = first_segment["start"] - time_offset
                         if gap > threshold:
-                            seek = previous_seek + round(gap * self.frames_per_second)
+                            seek = previous_seek + round(
+                                gap * self.feature_extractor.sampling_rate
+                            )
                             continue
 
                     # skip silence before any possible hallucination that is surrounded
@@ -1283,7 +1290,7 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
                            if silence_before and silence_after:
                                seek = round(
                                    max(time_offset + 1, segment["start"])
-                                    * self.frames_per_second
+                                    * self.feature_extractor.sampling_rate
                                )
                                if content_duration - segment["end"] < threshold:
                                    seek = content_frames
@@ -1849,15 +1856,8 @@ def detect_language_multi_segment(
         if duration < 1.0:
             return {"language_code": None, "language_confidence": 1.0}
 
-        # number of feature frames in 30 seconds of audio is 3000
-        nb_max_frames = self.feature_extractor.nb_max_frames
-
-        # extract features from audio with padding (default)
-        to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
-        features = self.feature_extractor(audio, to_cpu=to_cpu)
-
         # number of segments in the audio
-        num_segments = features.shape[-1] // nb_max_frames
+        num_segments = ceil(audio.shape[0] / self.feature_extractor.n_samples)
         # more number of segments than possible with the duration of file
         if num_detection_segments > num_segments:
             logging.warning(
@@ -1893,7 +1893,13 @@ def detect_language_multi_segment(
         # We need to get sufficient number of confident predictions per language, not in total.
 
         for i in indices:
-            segment_features = features[:, i * nb_max_frames : (i + 1) * nb_max_frames]
+            segment_features = self.feature_extractor(
+                audio[
+                    i
+                    * self.feature_extractor.n_samples : (i + 1)
+                    * self.feature_extractor.n_samples
+                ]
+            )[:, : self.feature_extractor.nb_max_frames]
             try:
                 encoder_output = self.encode(segment_features)
                 results = self.model.detect_language(encoder_output)[0]
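
A minimal sketch of the unit arithmetic the patch switches to, assuming Whisper's usual front-end constants (16 kHz sampling rate, hop length 160, 30 s windows); these constants are not stated in the diff itself. `seek` now indexes raw audio samples rather than mel frames, so times come from dividing by the sampling rate, while word-timestamp alignment still counts mel frames (hence `segment_size // hop_length` and the extra `hop_length` factor next to `input_stride`):

# Illustrative constants only (standard Whisper defaults, not taken from this diff).
SAMPLING_RATE = 16000  # audio samples per second
HOP_LENGTH = 160       # audio samples per mel frame
CHUNK_LENGTH = 30      # seconds per encoder window

N_SAMPLES = CHUNK_LENGTH * SAMPLING_RATE            # 480000 samples per window
NB_MAX_FRAMES = (30 * SAMPLING_RATE) // HOP_LENGTH  # 3000 mel frames per window


def seek_to_seconds(seek: int) -> float:
    # seek is an offset into the waveform, so seconds = samples / sampling rate.
    return seek / SAMPLING_RATE


def samples_to_mel_frames(num_samples: int) -> int:
    # Word-timestamp alignment works in mel frames, one frame per hop.
    return num_samples // HOP_LENGTH


assert N_SAMPLES == 480000 and NB_MAX_FRAMES == 3000
assert seek_to_seconds(N_SAMPLES) == 30.0
assert samples_to_mel_frames(N_SAMPLES) == NB_MAX_FRAMES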