pad audio instead of spectrogram features
MahmoudAshraf97 committed Oct 24, 2024
1 parent 2dbca5e commit a483549
Showing 3 changed files with 52 additions and 70 deletions.
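This commit moves the 30-second windowing step from the spectrogram domain to the audio domain: instead of featurizing the whole file once and then slicing and padding mel-frame windows, each window of raw samples is sliced first and featurized on its own. As a result, all seek bookkeeping changes units, from mel frames (100 per second, given Whisper's hop length of 160 at 16 kHz) to raw samples (16000 per second). A minimal sketch of the unit change, assuming Whisper's standard constants; the helper names are illustrative, not part of the repository:

SAMPLING_RATE = 16000  # audio samples per second
HOP_LENGTH = 160  # samples per mel frame, i.e. 100 frames per second
CHUNK_LENGTH = 30  # seconds of audio per encoder window
N_SAMPLES = CHUNK_LENGTH * SAMPLING_RATE  # 480000 samples per window
NB_MAX_FRAMES = N_SAMPLES // HOP_LENGTH  # 3000 mel frames per window


def seconds_to_samples(t: float) -> int:
    # new convention: seek positions index raw samples
    return round(t * SAMPLING_RATE)


def seconds_to_frames(t: float) -> int:
    # old convention: seek positions indexed mel frames
    return round(t * SAMPLING_RATE / HOP_LENGTH)


assert seconds_to_samples(1.5) == 24000
assert seconds_to_frames(1.5) == 150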
23 changes: 0 additions & 23 deletions faster_whisper/audio.py
@@ -107,26 +107,3 @@ def _resample_frames(frames, resampler):
    # Add None to flush the resampler.
    for frame in itertools.chain(frames, [None]):
        yield from resampler.resample(frame)
-
-
-def pad_or_trim(array, length: int, *, axis: int = -1):
-    """
-    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
-    """
-    axis = axis % array.ndim
-    if array.shape[axis] > length:
-        idx = [Ellipsis] * axis + [slice(length)] + [Ellipsis] * (array.ndim - axis - 1)
-        return array[idx]
-
-    if array.shape[axis] < length:
-        pad_widths = (
-            [
-                0,
-            ]
-            * array.ndim
-            * 2
-        )
-        pad_widths[2 * axis] = length - array.shape[axis]
-        array = torch.nn.functional.pad(array, tuple(pad_widths[::-1]))
-
-    return array
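With pad_or_trim removed, short final windows are no longer padded in the feature domain. Padding now happens on the raw waveform before the STFT; a minimal sketch of that idea, assuming the extractor zero-pads each segment out to the window size (pad_waveform is an illustrative helper, not this repository's API):

import torch


def pad_waveform(segment: torch.Tensor, n_samples: int) -> torch.Tensor:
    # zero-pad a 1-D audio segment up to the window size, then trim
    if segment.shape[0] < n_samples:
        segment = torch.nn.functional.pad(segment, (0, n_samples - segment.shape[0]))
    return segment[:n_samples]


last_window = torch.randn(120000)  # 7.5 seconds of 16 kHz audio
padded = pad_waveform(last_window, 480000)  # zero-padded to a 30-second window
assert padded.shape[0] == 480000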
3 changes: 1 addition & 2 deletions faster_whisper/feature_extractor.py
@@ -20,7 +20,7 @@ def __init__(
self.hop_length = hop_length
self.chunk_length = chunk_length
self.n_samples = chunk_length * sampling_rate
-        self.nb_max_frames = self.n_samples // hop_length
+        self.nb_max_frames = (30 * sampling_rate) // hop_length
self.time_per_frame = hop_length / sampling_rate
self.sampling_rate = sampling_rate
self.mel_filters = self.get_mel_filters(
@@ -82,7 +82,6 @@ def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False):

if chunk_length is not None:
self.n_samples = chunk_length * self.sampling_rate
-            self.nb_max_frames = self.n_samples // self.hop_length

if waveform.dtype is not torch.float32:
waveform = waveform.to(torch.float32)
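nb_max_frames is now pinned to the 30-second encoder width instead of tracking n_samples, so a custom chunk_length changes how much audio is read per window but not the width of the features handed to the encoder. Worked numbers under Whisper's standard constants (the padding behavior in the comment is an assumption about the extractor, not shown in this diff):

sampling_rate, hop_length = 16000, 160

nb_max_frames = (30 * sampling_rate) // hop_length  # 3000, the fixed encoder width

chunk_length = 15  # a caller-supplied override
n_samples = chunk_length * sampling_rate  # 240000 samples read per window
# assuming the extractor zero-pads each window before the STFT, the mel output
# still covers 3000 frames, which the [:, : nb_max_frames] slices then keep
assert nb_max_frames == 3000 and n_samples == 240000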
96 changes: 51 additions & 45 deletions faster_whisper/transcribe.py
@@ -17,7 +17,7 @@

from tqdm import tqdm

-from faster_whisper.audio import decode_audio, pad_or_trim
+from faster_whisper.audio import decode_audio
from faster_whisper.feature_extractor import FeatureExtractor
from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
@@ -239,7 +239,7 @@ def transcribe(
vad_filter: bool = True,
vad_parameters: Optional[Union[dict, VadOptions]] = None,
max_new_tokens: Optional[int] = None,
-        chunk_length: Optional[int] = None,
+        chunk_length: Optional[int] = 30,
clip_timestamps: Optional[List[dict]] = None,
batch_size: int = 16,
hotwords: Optional[str] = None,
@@ -663,7 +663,7 @@ def transcribe(
vad_filter: bool = False,
vad_parameters: Optional[Union[dict, VadOptions]] = None,
max_new_tokens: Optional[int] = None,
-        chunk_length: Optional[int] = None,
+        chunk_length: Optional[int] = 30,
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
hotwords: Optional[str] = None,
@@ -753,6 +753,7 @@ def transcribe(
"""

sampling_rate = self.feature_extractor.sampling_rate
+        self.feature_extractor.n_samples = chunk_length * sampling_rate

if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio)
Expand Down Expand Up @@ -797,11 +798,6 @@ def transcribe(
else:
speech_chunks = None

-        to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
-        features = self.feature_extractor(
-            audio, chunk_length=chunk_length, to_cpu=to_cpu
-        )
-
encoder_output = None
all_language_probs = None

@@ -828,12 +824,10 @@
if isinstance(clip_timestamps, str)
else clip_timestamps[0]
)
-                content_frames = (
-                    features.shape[-1] - self.feature_extractor.nb_max_frames
-                )
+                content_frames = audio.shape[0]
seek = (
-                    int(start_timestamp * self.frames_per_second)
-                    if start_timestamp * self.frames_per_second < content_frames
+                    int(start_timestamp * sampling_rate)
+                    if start_timestamp * sampling_rate < content_frames
else 0
)
end_frames = min(
@@ -844,9 +838,9 @@
)
detected_language_info = {}
while seek <= end_frames:
-                    segment = features[
-                        :, seek : seek + self.feature_extractor.nb_max_frames
-                    ]
+                    segment = self.feature_extractor(
+                        audio[seek : seek + self.feature_extractor.n_samples]
+                    )[:, : self.feature_extractor.nb_max_frames]
encoder_output = self.encode(segment)
# results is a list of tuple[str, float] with language names and
# probabilities.
@@ -865,7 +859,7 @@
detected_language_info.setdefault(language, []).append(
language_probability
)
-                    seek += segment.shape[-1]
+                    seek += self.feature_extractor.n_samples
else:
# If no language detected for all segments, the majority vote of the highest
# projected languages for all segments is used to determine the language.
@@ -934,7 +928,9 @@ def transcribe(
hotwords=hotwords,
)

-        segments = self.generate_segments(features, tokenizer, options, encoder_output)
+        segments = self.generate_segments(
+            audio, chunk_length, tokenizer, options, encoder_output
+        )

if speech_chunks:
segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
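chunk_length now defaults to 30 rather than None, and the extractor's n_samples is set from it up front, so callers get 30-second windows unless they override it. A hedged usage sketch (the model size and file name are placeholders):

from faster_whisper import WhisperModel

model = WhisperModel("small", device="cpu")
# 15-second windows: half as much audio per encoder pass
segments, info = model.transcribe("audio.wav", chunk_length=15)
for segment in segments:
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")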
@@ -1005,7 +1001,11 @@ def _split_segments_by_timestamps(
last_timestamp_position = (
tokens[last_slice - 1] - tokenizer.timestamp_begin
)
-                seek += last_timestamp_position * self.input_stride
+                seek += (
+                    last_timestamp_position
+                    * self.input_stride
+                    * self.feature_extractor.hop_length
+                )

else:
duration = segment_duration
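Because seek is now tracked in samples, advancing by a token timestamp needs an extra hop_length factor: input_stride converts timestamp positions to mel frames, and hop_length converts mel frames to samples. A worked example, assuming Whisper's usual input_stride of 2 mel frames per timestamp position (0.02-second resolution):

input_stride, hop_length, sampling_rate = 2, 160, 16000

last_timestamp_position = 150  # the decoder says speech ends at 3.00 seconds
advance = last_timestamp_position * input_stride * hop_length
assert advance == 48000  # 48000 samples / 16000 per second == 3.0 seconds
# before this commit the same advance was 150 * 2 == 300 mel frames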
@@ -1031,13 +1031,14 @@

def generate_segments(
self,
-        features: torch.Tensor,
+        audio: torch.Tensor,
+        chunk_length: float,
tokenizer: Tokenizer,
options: TranscriptionOptions,
encoder_output: Optional[ctranslate2.StorageView] = None,
) -> Iterable[Segment]:
-        content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
-        content_duration = float(content_frames * self.feature_extractor.time_per_frame)
+        content_frames = audio.shape[0]
+        content_duration = audio.shape[0] / self.feature_extractor.sampling_rate

if isinstance(options.clip_timestamps, str):
options = options._replace(
@@ -1051,7 +1052,8 @@ def generate_segments(
]
)
seek_points: List[int] = [
-            round(ts * self.frames_per_second) for ts in options.clip_timestamps
+            round(ts * self.feature_extractor.sampling_rate)
+            for ts in options.clip_timestamps
]
if len(seek_points) == 0:
seek_points.append(0)
@@ -1093,19 +1095,20 @@
if clip_idx < len(seek_clips):
seek = seek_clips[clip_idx][0]
continue
-            time_offset = seek * self.feature_extractor.time_per_frame
-            window_end_time = float(
-                (seek + self.feature_extractor.nb_max_frames)
-                * self.feature_extractor.time_per_frame
-            )
+            time_offset = seek / self.feature_extractor.sampling_rate
+            window_end_time = (
+                seek + self.feature_extractor.n_samples
+            ) / self.feature_extractor.sampling_rate

segment_size = min(
-                self.feature_extractor.nb_max_frames,
+                self.feature_extractor.n_samples,
content_frames - seek,
seek_clip_end - seek,
)
-            segment = features[:, seek : seek + segment_size]
-            segment_duration = segment_size * self.feature_extractor.time_per_frame
-            segment = pad_or_trim(segment, self.feature_extractor.nb_max_frames)
+            segment = self.feature_extractor(
+                audio[seek : seek + segment_size], chunk_length=chunk_length
+            )[:, : self.feature_extractor.nb_max_frames]
+            segment_duration = segment_size / self.feature_extractor.sampling_rate

if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(
@@ -1233,15 +1236,17 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
[current_segments],
tokenizer,
encoder_output,
-                    segment_size,
+                    segment_size // self.feature_extractor.hop_length,
options.prepend_punctuations,
options.append_punctuations,
last_speech_timestamp=last_speech_timestamp,
)
if not single_timestamp_ending:
last_word_end = get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
-                        seek = round(last_word_end * self.frames_per_second)
+                        seek = round(
+                            last_word_end * self.feature_extractor.sampling_rate
+                        )

# skip silence before possible hallucinations
if options.hallucination_silence_threshold is not None:
@@ -1252,7 +1257,9 @@
if first_segment is not None and is_segment_anomaly(first_segment):
gap = first_segment["start"] - time_offset
if gap > threshold:
-                        seek = previous_seek + round(gap * self.frames_per_second)
+                        seek = previous_seek + round(
+                            gap * self.feature_extractor.sampling_rate
+                        )
continue

# skip silence before any possible hallucination that is surrounded
@@ -1283,7 +1290,7 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
if silence_before and silence_after:
seek = round(
max(time_offset + 1, segment["start"])
-                                * self.frames_per_second
+                                * self.feature_extractor.sampling_rate
)
if content_duration - segment["end"] < threshold:
seek = content_frames
@@ -1849,15 +1856,8 @@ def detect_language_multi_segment(
if duration < 1.0:
return {"language_code": None, "language_confidence": 1.0}

-        # number of feature frames in 30 seconds of audio is 3000
-        nb_max_frames = self.feature_extractor.nb_max_frames
-
-        # extract features from audio with padding (default)
-        to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
-        features = self.feature_extractor(audio, to_cpu=to_cpu)
-
# number of segments in the audio
-        num_segments = features.shape[-1] // nb_max_frames
+        num_segments = ceil(audio.shape[0] / self.feature_extractor.n_samples)
# more number of segments than possible with the duration of file
if num_detection_segments > num_segments:
logging.warning(
@@ -1893,7 +1893,13 @@ def detect_language_multi_segment(
# We need to get sufficient number of confident predictions per language, not in total.

for i in indices:
segment_features = features[:, i * nb_max_frames : (i + 1) * nb_max_frames]
segment_features = self.feature_extractor(
audio[
i
* self.feature_extractor.n_samples : (i + 1)
* self.feature_extractor.n_samples
]
)[:, : self.feature_extractor.nb_max_frames]
try:
encoder_output = self.encode(segment_features)
results = self.model.detect_language(encoder_output)[0]
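The segment count for multi-segment language detection is likewise derived from raw samples now, rounding up so a trailing partial window still counts, and the whole file no longer has to be featurized up front. A small check, assuming the standard 30-second window at 16 kHz:

from math import ceil

n_samples = 480000  # 30 seconds at 16 kHz
audio_len = 70 * 16000  # 70 seconds of audio
num_segments = ceil(audio_len / n_samples)  # two full windows plus a partial
assert num_segments == 3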
