Try to fix pylint issues, add more typings
Sirozha1337 committed Feb 2, 2024
1 parent 7f5fdba commit e44fb1c
Showing 8 changed files with 79 additions and 59 deletions.
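The diff below applies two recurring patterns: logging calls switch from f-string interpolation to lazy %-formatting (what pylint's logging-fstring-interpolation check, W1203, asks for), and function signatures gain explicit type hints. A minimal sketch of both patterns, using illustrative names that are not taken from this repository:

import logging
from typing import Optional

logger = logging.getLogger(__name__)

def load_lines(path: str, language: Optional[str] = None) -> list[str]:
    # Lazy %-formatting: the message is only built if the record is actually
    # emitted, and pylint no longer flags W1203 on the call.
    logger.info("Loading subtitles from %s (language=%s)", path, language)
    with open(path, encoding="utf-8") as srt:
        return srt.read().splitlines()

Built-in generics used in the new annotations (list[int], dict[str, str], tuple[...]) are evaluated when the function is defined, so they assume Python 3.9 or later; Optional still comes from typing.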
51 changes: 30 additions & 21 deletions auto_subtitle/main.py
@@ -1,6 +1,7 @@
import os
import warnings
import logging
from typing import Optional
from .models.Subtitles import Subtitles
from .utils.files import filename, write_srt
from .utils.ffmpeg import get_audio, add_subtitles
@@ -12,15 +13,11 @@

def process(args: dict):
model_name: str = args.pop("model")
output_dir: str = args.pop("output_dir")
output_type: str = args.pop("output_type")
subtitle_type: str = args.pop("subtitle_type")
language: str = args.pop("language")
sample_interval: list = args.pop("sample_interval")
target_language: str = args.pop("target_language")

logging.basicConfig(encoding='utf-8', level=logging.INFO)
os.makedirs(output_dir, exist_ok=True)

if model_name.endswith(".en"):
warnings.warn(
@@ -32,11 +29,18 @@ def process(args: dict):
args["language"] = language

if target_language != 'en':
logger.info(f"{target_language} is not English, Opus-MT will be used to perform translation.")
logger.info("%s is not English, Opus-MT will be used to perform translation.",
target_language)
# warnings.warn(
# f"{target_language} is not English, Opus-MT will be used to perform translation.")
args['task'] = 'transcribe'

output_args = {
"output_dir": args.pop("output_dir"),
"output_type": args.pop("output_type"),
"subtitle_type": args.pop("subtitle_type")
}

videos = args.pop('video')
audio_channel = args.pop('audio_channel')
model_args = {
@@ -45,32 +49,37 @@
"compute_type": args.pop("compute_type")
}
transcribe_model = WhisperAI(model_args, args)
translate_model = EasyNMTWrapper(device=model_args['device']) if target_language != 'en' else None
translate_model = EasyNMTWrapper(
device=model_args['device']) if target_language != 'en' else None

os.makedirs(output_args["output_dir"], exist_ok=True)
for video in videos:
audio = get_audio(video, audio_channel, sample_interval)

transcribed, translated = perform_task(video, audio, language, target_language, transcribe_model,
translate_model)
transcribed, translated = perform_task(video, audio, language, target_language,
transcribe_model, translate_model)

save_result(video, transcribed, translated, output_dir, output_type, sample_interval, subtitle_type)
save_result(video, transcribed, translated, sample_interval, output_args)


def save_result(video, transcribed, translated, output_dir, output_type, sample_interval, subtitle_type):
if output_type == 'all' or output_type == 'srt':
def save_result(video: str, transcribed: Subtitles, translated: Subtitles, sample_interval: list,
output_args: dict[str, str]) -> None:
if output_args["output_type"] == 'all' or output_args["output_type"] == 'srt':
logger.info('Saving subtitle files...')
save_subtitles(video, transcribed, output_dir, translated is not None)
save_subtitles(video, transcribed, output_args["output_dir"], translated is not None)

if translated is not None:
save_subtitles(video, translated, output_dir, translated is not None)
save_subtitles(video, translated, output_args["output_dir"], translated is not None)

if output_type == 'srt':
if output_args["output_type"] == 'srt':
return

add_subtitles(video, transcribed, translated, output_dir, sample_interval, subtitle_type)
add_subtitles(video, transcribed, translated, sample_interval, output_args)


def perform_task(video, audio, language, target_language, transcribe_model, translate_model):
def perform_task(video: str, audio: str, language: str, target_language: str,
transcribe_model: WhisperAI,
translate_model: Optional[EasyNMTWrapper]) -> tuple[Subtitles, Optional[Subtitles]]:
transcribed = get_subtitles(video, audio, transcribe_model)
translated = None

@@ -95,23 +104,23 @@ def translate_subtitles(subtitles: Subtitles, source_lang: str, target_lang: str
return Subtitles(translated_segments, target_lang)


def save_subtitles(path: str, subtitles: Subtitles, output_dir: str, use_language_in_output: bool):
def save_subtitles(path: str, subtitles: Subtitles, output_dir: str,
use_language_in_output: bool) -> None:
if use_language_in_output:
subtitles.output_path = os.path.join(
output_dir, f"{filename(path)}.{subtitles.language}.srt")
else:
subtitles.output_path = os.path.join(
output_dir, f"{filename(path)}.srt")

logger.info(f'Saving to path {subtitles.output_path}')
logger.info('Saving to path %s', subtitles.output_path)
with open(subtitles.output_path, "w", encoding="utf-8") as srt:
write_srt(subtitles.segments, file=srt)


def get_subtitles(source_path: str, audio_path: str, model: WhisperAI) -> Subtitles:
logger.info(
f"Generating subtitles for {filename(source_path)}... This might take a while."
)
logger.info(f"Generating subtitles for %s... This might take a while.",
filename(source_path))

segments, info = model.transcribe(audio_path)

5 changes: 3 additions & 2 deletions auto_subtitle/translation/easynmt_utils.py
@@ -10,13 +10,14 @@ def __init__(self, device: str):
translator=self.translator,
device=device if device != 'auto' else None)

def translate(self, segments: list[Segment], source_lang: str, target_lang: str) -> list[Segment]:
def translate(self, segments: list[Segment], source_lang: str,
target_lang: str) -> list[Segment]:
source_text = [segment.text for segment in segments]
self.translator.load_available_models()

translated_text = self.model.translate(source_text, target_lang,
source_lang, show_progress_bar=True)
translated_segments = list()
translated_segments = []
for segment, translation in zip(segments, translated_text):
translated_segments.append(segment._replace(text=translation))

10 changes: 5 additions & 5 deletions auto_subtitle/translation/opusmt_utils.py
@@ -25,7 +25,7 @@ def load_model(self, model_name):
self.models[model_name]['last_loaded'] = time.time()
return self.models[model_name]['tokenizer'], self.models[model_name]['model']

logger.info("Load model: %s" % model_name)
logger.info("Load model: %s", model_name)
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)
model.eval()
@@ -74,20 +74,20 @@ def determine_required_translations(self, source_lang, target_lang):
direct_key = f'{source_lang}-{target_lang}'
if direct_key in self.available_models:
logger.info(
f'Found direct translation from {source_lang} to {target_lang}.')
'Found direct translation from %s to %s.', source_lang, target_lang)
return [(source_lang, target_lang, direct_key)]

logger.info(
f'No direct translation from {source_lang} to {target_lang}. Trying to translate through en.')
'No direct translation from %s to %s. Trying to translate through en.', source_lang, target_lang)

to_en_key = f'{source_lang}-en'
if to_en_key not in self.available_models:
logger.info(f'No translation from {source_lang} to en.')
logger.info('No translation from %s to en.', source_lang)
return []

from_en_key = f'en-{target_lang}'
if from_en_key not in self.available_models:
logger.info(f'No translation from en to {target_lang}.')
logger.info('No translation from en to %s.', target_lang)
return []

return [(source_lang, 'en', to_en_key), ('en', target_lang, from_en_key)]
13 changes: 7 additions & 6 deletions auto_subtitle/utils/convert.py
@@ -1,7 +1,8 @@
from datetime import datetime, timedelta
from typing import Optional


def str2bool(string: str):
def str2bool(string: str) -> bool:
string = string.lower()
str2val = {"true": True, "false": False}

@@ -12,7 +13,7 @@ def str2bool(string: str):
f"Expected one of {set(str2val.keys())}, got {string}")


def str2timeinterval(string: str):
def str2timeinterval(string: str) -> Optional[list[int]]:
if string is None:
return None

@@ -34,7 +35,7 @@ def str2timeinterval(string: str):
return [start, end]


def time_to_timestamp(string: str):
def time_to_timestamp(string: str) -> int:
split_time = string.split(':')
if len(split_time) == 0 or len(split_time) > 3 or not all(x.isdigit() for x in split_time):
raise ValueError(
@@ -49,7 +50,7 @@ def time_to_timestamp(string: str):
return int(split_time[0]) * 60 * 60 + int(split_time[1]) * 60 + int(split_time[2])


def try_parse_timestamp(string: str):
def try_parse_timestamp(string: str) -> int:
timestamp = parse_timestamp(string, '%H:%M:%S')
if timestamp is not None:
return timestamp
@@ -61,7 +62,7 @@ def try_parse_timestamp(string: str):
return parse_timestamp(string, '%S')


def parse_timestamp(string: str, pattern: str):
def parse_timestamp(string: str, pattern: str) -> Optional[int]:
try:
date = datetime.strptime(string, pattern)
delta = timedelta(
@@ -71,7 +72,7 @@ def parse_timestamp(string: str, pattern: str):
return None


def format_timestamp(seconds: float, always_include_hours: bool = False):
def format_timestamp(seconds: float, always_include_hours: bool = False) -> str:
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)

37 changes: 21 additions & 16 deletions auto_subtitle/utils/ffmpeg.py
@@ -1,20 +1,21 @@
import os
import tempfile
import ffmpeg
import logging
from typing import Optional
import ffmpeg
from .tempfile import SubtitlesTempFile
from .files import filename
from ..models.Subtitles import Subtitles
from typing import Optional

logger = logging.getLogger(__name__)


def get_audio(path: str, audio_channel_index: int, sample_interval: list):
def get_audio(path: str, audio_channel_index: int, sample_interval: list) -> str:
temp_dir = tempfile.gettempdir()

logger.info(f"Extracting audio from {filename(path)}...")
output_path = os.path.join(temp_dir, f"{filename(path)}.wav")
file_name = filename(path)
logger.info("Extracting audio from %s...", file_name)
output_path = os.path.join(temp_dir, f"{file_name}.wav")

ffmpeg_input_args = {}
if sample_interval is not None:
@@ -39,10 +40,11 @@ def get_audio(path: str, audio_channel_index: int, sample_interval: list):


def add_subtitles(path: str, transcribed: Subtitles, translated: Optional[Subtitles],
output_dir: str, sample_interval: list, subtitle_type: str):
out_path = os.path.join(output_dir, f"{filename(path)}.mp4")
sample_interval: list, output_args: dict[str, str]) -> None:
file_name = filename(path)
out_path = os.path.join(output_args["output_dir"], f"{file_name}.mp4")

logger.info(f"Adding subtitles to {filename(path)}...")
logger.info("Adding subtitles to %s...", file_name)

ffmpeg_input_args = {}
if sample_interval is not None:
@@ -56,19 +58,22 @@ def add_subtitles(path: str, transcribed: Subtitles, translated: Optional[Subtit
# HACK: On Windows it's impossible to use absolute subtitle file path with ffmpeg
# so we use temp copy instead
# see: https://github.com/kkroening/ffmpeg-python/issues/745
with SubtitlesTempFile(transcribed) as transcribed_tmp, SubtitlesTempFile(translated) as translated_tmp:
with SubtitlesTempFile(transcribed) as transcribed_tmp, SubtitlesTempFile(
translated) as translated_tmp:

if subtitle_type == 'hard':
hard_subtitles(path, out_path, transcribed_tmp, translated_tmp, ffmpeg_input_args, ffmpeg_output_args)
elif subtitle_type == 'soft':
soft_subtitles(path, out_path, transcribed_tmp, translated_tmp, ffmpeg_input_args, ffmpeg_output_args)
if output_args["subtitle_type"] == 'hard':
hard_subtitles(path, out_path, transcribed_tmp, translated_tmp, ffmpeg_input_args,
ffmpeg_output_args)
elif output_args["subtitle_type"] == 'soft':
soft_subtitles(path, out_path, transcribed_tmp, translated_tmp, ffmpeg_input_args,
ffmpeg_output_args)

logger.info(f"Saved subtitled video to {os.path.abspath(out_path)}.")
logger.info("Saved subtitled video to %s.", os.path.abspath(out_path))


def hard_subtitles(input_path: str, output_path: str,
transcribed: SubtitlesTempFile, translated: SubtitlesTempFile,
input_args: dict, output_args: dict):
input_args: dict, output_args: dict) -> None:
video = ffmpeg.input(input_path, **input_args)
audio = video.audio

@@ -88,7 +93,7 @@ def hard_subtitles(input_path: str, output_path: str,

def soft_subtitles(input_path: str, output_path: str,
transcribed: SubtitlesTempFile, translated: SubtitlesTempFile,
input_args: dict, output_args: dict):
input_args: dict, output_args: dict) -> None:
output_args['c'] = 'copy'
output_args['c:s'] = 'mov_text'
output_args['metadata:s:s:0'] = f'language={transcribed.subtitles.language}'
4 changes: 2 additions & 2 deletions auto_subtitle/utils/files.py
@@ -4,7 +4,7 @@
from .convert import format_timestamp


def write_srt(transcript: list[Segment], file: TextIO):
def write_srt(transcript: list[Segment], file: TextIO) -> None:
for i, segment in enumerate(transcript, start=1):
print(
f"{i}\n"
@@ -16,5 +16,5 @@ def write_srt(transcript: list[Segment], file: TextIO):
)


def filename(path: str):
def filename(path: str) -> str:
return os.path.splitext(os.path.basename(path))[0]
7 changes: 4 additions & 3 deletions auto_subtitle/utils/tempfile.py
@@ -1,6 +1,7 @@
import tempfile
import os
import shutil
from typing import TextIO, cast
from auto_subtitle.models.Subtitles import Subtitles
from auto_subtitle.utils.files import write_srt

@@ -21,12 +22,12 @@ def __enter__(self):
if self.subtitles.output_path is not None and os.path.isfile(self.subtitles.output_path):
shutil.copyfile(self.subtitles.output_path, self.tmp_file_path)
else:
write_srt(self.subtitles.segments, self.tmp_file)
write_srt(self.subtitles.segments, cast(TextIO, self.tmp_file))
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
if self.subtitles is None:
return self
return

self.tmp_file.close()
if os.path.isfile(self.tmp_file_path):
11 changes: 7 additions & 4 deletions auto_subtitle/utils/whisper.py
@@ -1,5 +1,7 @@
import warnings
import faster_whisper
from typing import Iterable
from faster_whisper import WhisperModel
from faster_whisper.transcribe import Segment, TranscriptionInfo
from tqdm import tqdm


@@ -38,10 +40,10 @@ class WhisperAI:
"""

def __init__(self, model_args: dict, transcribe_args: dict):
self.model = faster_whisper.WhisperModel(**model_args)
self.model = WhisperModel(**model_args)
self.transcribe_args = transcribe_args

def transcribe(self, audio_path: str):
def transcribe(self, audio_path: str) -> tuple[Iterable[Segment], TranscriptionInfo]:
"""
Transcribes the specified audio file and yields the resulting segments.
@@ -59,7 +61,8 @@ def transcribe(self, audio_path: str):
return self.subtitles_iterator(segments, info), info

@staticmethod
def subtitles_iterator(segments, info):
def subtitles_iterator(segments: Iterable[Segment],
info: TranscriptionInfo) -> Iterable[Segment]:
# Same precision as the Whisper timestamps.
total_duration = round(info.duration, 2)
