Feat: correct type hints #1150

Open · wants to merge 1 commit into base: master
66 changes: 37 additions & 29 deletions faster_whisper/transcribe.py
@@ -7,7 +7,7 @@
 from dataclasses import asdict, dataclass
 from inspect import signature
 from math import ceil
-from typing import BinaryIO, Iterable, List, Optional, Tuple, Union
+from typing import Any, BinaryIO, Iterable, List, Optional, Tuple, Union
 from warnings import warn

 import ctranslate2
@@ -82,11 +82,11 @@ class TranscriptionOptions:
     compression_ratio_threshold: Optional[float]
     condition_on_previous_text: bool
     prompt_reset_on_temperature: float
-    temperatures: List[float]
+    temperatures: Union[List[float], Tuple[float, ...]]
     initial_prompt: Optional[Union[str, Iterable[int]]]
     prefix: Optional[str]
     suppress_blank: bool
-    suppress_tokens: Optional[List[int]]
+    suppress_tokens: Union[List[int], Tuple[int, ...]]
     without_timestamps: bool
     max_initial_timestamp: float
     word_timestamps: bool
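
The widened `Union` types mirror how these options are actually constructed: `transcribe()` defaults `temperature` to a tuple and `get_suppressed_tokens` (below) now returns a `Tuple[int, ...]`, so the list-only annotations were too narrow. A minimal sketch of the coercion the new annotation documents (the helper name is illustrative, not from this PR):

```python
from typing import List, Tuple, Union

def normalize_temperatures(
    temperature: Union[float, List[float], Tuple[float, ...]]
) -> List[float]:
    # transcribe() accepts a bare float or a sequence; coerce to a list
    # so the fallback loop can iterate uniformly
    if isinstance(temperature, (list, tuple)):
        return list(temperature)
    return [temperature]

assert normalize_temperatures(0.2) == [0.2]
assert normalize_temperatures((0.0, 0.2, 0.4)) == [0.0, 0.2, 0.4]
```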
@@ -108,7 +108,7 @@ class TranscriptionInfo:
     duration_after_vad: float
     all_language_probs: Optional[List[Tuple[str, float]]]
     transcription_options: TranscriptionOptions
-    vad_options: VadOptions
+    vad_options: Optional[VadOptions]


 class BatchedInferencePipeline:
@@ -123,7 +123,6 @@ def forward(self, features, tokenizer, chunks_metadata, options):
         encoder_output, outputs = self.generate_segment_batched(
             features, tokenizer, options
         )
-
         segmented_outputs = []
         segment_sizes = []
         for chunk_metadata, output in zip(chunks_metadata, outputs):
@@ -132,8 +131,8 @@ def forward(self, features, tokenizer, chunks_metadata, options):
             segment_sizes.append(segment_size)
             (
                 subsegments,
-                seek,
-                single_timestamp_ending,
+                _,
+                _,
             ) = self.model._split_segments_by_timestamps(
                 tokenizer=tokenizer,
                 tokens=output["tokens"],
@@ -288,7 +287,7 @@ def transcribe(
         hallucination_silence_threshold: Optional[float] = None,
         batch_size: int = 8,
         hotwords: Optional[str] = None,
-        language_detection_threshold: Optional[float] = 0.5,
+        language_detection_threshold: float = 0.5,
         language_detection_segments: int = 1,
     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
         """transcribe audio in chunks in batched fashion and return with language info.
@@ -576,7 +575,7 @@ def __init__(
         num_workers: int = 1,
         download_root: Optional[str] = None,
         local_files_only: bool = False,
-        files: dict = None,
+        files: Optional[dict] = None,
         **model_kwargs,
     ):
         """Initializes the Whisper model.
@@ -729,7 +728,7 @@ def transcribe(
         clip_timestamps: Union[str, List[float]] = "0",
         hallucination_silence_threshold: Optional[float] = None,
         hotwords: Optional[str] = None,
-        language_detection_threshold: Optional[float] = 0.5,
+        language_detection_threshold: float = 0.5,
         language_detection_segments: int = 1,
     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
         """Transcribes an input file.
@@ -833,7 +832,7 @@ def transcribe(
             elif isinstance(vad_parameters, dict):
                 vad_parameters = VadOptions(**vad_parameters)
             speech_chunks = get_speech_timestamps(audio, vad_parameters)
-            audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
+            audio_chunks, _ = collect_chunks(audio, speech_chunks)
             audio = np.concatenate(audio_chunks, axis=0)
             duration_after_vad = audio.shape[0] / sampling_rate

@@ -933,7 +932,7 @@ def transcribe(
             condition_on_previous_text=condition_on_previous_text,
             prompt_reset_on_temperature=prompt_reset_on_temperature,
             temperatures=(
-                temperature if isinstance(temperature, (list, tuple)) else [temperature]
+                temperature if isinstance(temperature, (List, Tuple)) else [temperature]
             ),
             initial_prompt=initial_prompt,
             prefix=prefix,
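
One caveat on this hunk: bare `typing.List` and `typing.Tuple` do work inside `isinstance()` (they forward to `list` and `tuple`), but the typing aliases are deprecated in favor of the builtins since Python 3.9, and their subscripted forms are rejected outright, so the original builtin spelling was arguably safer. A quick demonstration:

```python
from typing import List, Tuple

# bare typing aliases forward to the builtin classes in isinstance checks
assert isinstance([0.0], List)
assert isinstance((0.0,), Tuple)

# subscripted generics, however, cannot be used with instance checks
try:
    isinstance([0.0], List[float])
except TypeError as exc:
    print(f"TypeError: {exc}")
```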
@@ -962,7 +961,8 @@ def transcribe(

         if speech_chunks:
             segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
-
+        if isinstance(vad_parameters, dict):
+            vad_parameters = VadOptions(**vad_parameters)
         info = TranscriptionInfo(
             language=language,
             language_probability=language_probability,
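
The added coercion guarantees that `TranscriptionInfo.vad_options` holds a `VadOptions` instance even when the caller passed a plain dict. A standalone sketch of the pattern (the dataclass below is a hypothetical stand-in, not the real `VadOptions`):

```python
from dataclasses import dataclass

@dataclass
class VadOptionsSketch:  # hypothetical stand-in for faster_whisper.vad.VadOptions
    onset: float = 0.5

def coerce(vad_parameters):
    # normalize dict input to the dataclass before storing it on the info object
    if isinstance(vad_parameters, dict):
        vad_parameters = VadOptionsSketch(**vad_parameters)
    return vad_parameters

assert coerce({"onset": 0.6}).onset == 0.6
assert coerce(VadOptionsSketch()).onset == 0.5
```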
@@ -982,7 +982,7 @@ def _split_segments_by_timestamps(
         segment_size: int,
         segment_duration: float,
         seek: int,
-    ) -> List[List[int]]:
+    ) -> Tuple[List[Any], int, bool]:
         current_segments = []
         single_timestamp_ending = (
             len(tokens) >= 2 and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1]
@@ -1550,8 +1550,8 @@ def add_word_timestamps(
         num_frames: int,
         prepend_punctuations: str,
         append_punctuations: str,
-        last_speech_timestamp: float,
-    ) -> float:
+        last_speech_timestamp: Union[float, None],
+    ) -> Optional[float]:
         if len(segments) == 0:
             return

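The `Optional[float]` return type is justified by the early exit visible above: a bare `return` yields `None`, so `-> float` was inaccurate whenever `segments` is empty. A minimal illustration (the function body is invented for the example):

```python
from typing import Optional

def last_timestamp(segments: list) -> Optional[float]:
    if len(segments) == 0:
        return  # a bare return implicitly yields None
    return float(len(segments))

assert last_timestamp([]) is None
assert last_timestamp([1, 2]) == 2.0
```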
@@ -1698,9 +1698,11 @@ def find_alignment(
         text_indices = np.array([pair[0] for pair in alignments])
         time_indices = np.array([pair[1] for pair in alignments])

-        words, word_tokens = tokenizer.split_to_word_tokens(
-            text_token + [tokenizer.eot]
-        )
+        if isinstance(text_token, int):
+            tokens = [text_token] + [tokenizer.eot]
+        else:
+            tokens = text_token + [tokenizer.eot]
+        words, word_tokens = tokenizer.split_to_word_tokens(tokens)
         if len(word_tokens) <= 1:
             # return on eot only
             # >>> np.pad([], (1, 0))
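
The new branch guards against `text_token` arriving as a bare `int`: concatenating a list onto an int raises `TypeError`, so the scalar is wrapped first. A self-contained sketch of the same guard (the names and EOT id are illustrative):

```python
from typing import List, Union

EOT = 50257  # assumed end-of-transcript token id, for illustration only

def with_eot(text_token: Union[int, List[int]]) -> List[int]:
    # an int cannot be concatenated with a list, so wrap scalars first
    if isinstance(text_token, int):
        return [text_token, EOT]
    return list(text_token) + [EOT]

assert with_eot(42) == [42, 50257]
assert with_eot([1, 2]) == [1, 2, 50257]
```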
@@ -1746,7 +1748,7 @@ def detect_language(
         audio: Optional[np.ndarray] = None,
         features: Optional[np.ndarray] = None,
         vad_filter: bool = False,
-        vad_parameters: Union[dict, VadOptions] = None,
+        vad_parameters: Optional[Union[dict, VadOptions]] = None,
         language_detection_segments: int = 1,
         language_detection_threshold: float = 0.5,
     ) -> Tuple[str, float, List[Tuple[str, float]]]:
@@ -1778,18 +1780,24 @@ def detect_language(
         if audio is not None:
             if vad_filter:
                 speech_chunks = get_speech_timestamps(audio, vad_parameters)
-                audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
+                audio_chunks, _ = collect_chunks(audio, speech_chunks)
                 audio = np.concatenate(audio_chunks, axis=0)

+            assert (
+                audio is not None
+            ), "Audio is None after concatenating the audio_chunks"
             audio = audio[
                 : language_detection_segments * self.feature_extractor.n_samples
             ]
             features = self.feature_extractor(audio)

+        assert (
+            features is not None
+        ), "No features extracted from the audio file"
         features = features[
             ..., : language_detection_segments * self.feature_extractor.nb_max_frames
         ]

+        assert (
+            features is not None
+        ), "No features extracted when detecting language in audio segments"
         detected_language_info = {}
         for i in range(0, features.shape[-1], self.feature_extractor.nb_max_frames):
             encoder_output = self.encode(
@@ -1859,13 +1867,13 @@ def get_compression_ratio(text: str) -> float:

 def get_suppressed_tokens(
     tokenizer: Tokenizer,
-    suppress_tokens: Tuple[int],
-) -> Optional[List[int]]:
-    if -1 in suppress_tokens:
+    suppress_tokens: Optional[List[int]],
+) -> Tuple[int, ...]:
+    if suppress_tokens is None or len(suppress_tokens) == 0:
+        suppress_tokens = []  # interpret empty string as an empty list
+    elif -1 in suppress_tokens:
         suppress_tokens = [t for t in suppress_tokens if t >= 0]
         suppress_tokens.extend(tokenizer.non_speech_tokens)
-    elif suppress_tokens is None or len(suppress_tokens) == 0:
-        suppress_tokens = []  # interpret empty string as an empty list
     else:
         assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"

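The reordering in `get_suppressed_tokens` above is more than cosmetic: the `None`/empty check must run before the `-1` membership test, because `-1 in None` raises `TypeError`. A simplified standalone sketch of the new control flow (the trailing `tuple(sorted(set(...)))` is assumed from the declared return type, and the non-speech ids are placeholders):

```python
from typing import List, Optional, Tuple

NON_SPEECH = [87, 93]  # placeholder for tokenizer.non_speech_tokens

def suppressed(tokens: Optional[List[int]]) -> Tuple[int, ...]:
    if tokens is None or len(tokens) == 0:
        tokens = []
    elif -1 in tokens:  # safe: tokens is a non-empty list on this branch
        tokens = [t for t in tokens if t >= 0]
        tokens.extend(NON_SPEECH)
    return tuple(sorted(set(tokens)))

assert suppressed(None) == ()
assert suppressed([-1, 5]) == (5, 87, 93)
```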
7 changes: 4 additions & 3 deletions faster_whisper/vad.py
@@ -3,7 +3,7 @@
 import os

 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np

@@ -44,7 +44,7 @@ class VadOptions:

 def get_speech_timestamps(
     audio: np.ndarray,
-    vad_options: Optional[VadOptions] = None,
+    vad_options: Optional[Union[dict, VadOptions]] = None,
     sampling_rate: int = 16000,
     **kwargs,
 ) -> List[dict]:
@@ -61,7 +61,8 @@ def get_speech_timestamps(
     """
     if vad_options is None:
         vad_options = VadOptions(**kwargs)
-
+    if isinstance(vad_options, dict):
+        vad_options = VadOptions(**vad_options)
     onset = vad_options.onset
     min_speech_duration_ms = vad_options.min_speech_duration_ms
     max_speech_duration_s = vad_options.max_speech_duration_s
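
With this change `get_speech_timestamps` accepts either a `VadOptions` instance or a plain dict of the same fields. A usage sketch under that assumption (run against silence, both calls should yield the same, empty, list of chunks):

```python
import numpy as np
from faster_whisper.vad import VadOptions, get_speech_timestamps

audio = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz

chunks_a = get_speech_timestamps(audio, VadOptions(min_speech_duration_ms=100))
chunks_b = get_speech_timestamps(audio, {"min_speech_duration_ms": 100})
assert chunks_a == chunks_b
```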