Skip to content

Commit

Permalink
Add subtitles translation using EasyNMT and OpusMT libraries
Browse files Browse the repository at this point in the history
  • Loading branch information
Sirozha1337 committed Jan 30, 2024
1 parent 1c0cdb6 commit e76027c
Show file tree
Hide file tree
Showing 15 changed files with 321 additions and 31 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pylint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install pylint
pip install fasttext
pip install -r requirements.txt
- name: Analysing the code with pylint
run: |
Expand Down
23 changes: 23 additions & 0 deletions .github/workflows/setup.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
name: Setup

on: [push]

jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install application
run: |
pip install wheel
pip install -e .
- name: Check that package was installed successfully
run: |
faster_auto_subtitle -h
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ Adding `--task translate` will translate the subtitles into English:

faster_auto_subtitle /path/to/video.mp4 --task translate

Adding `--target_language {2-letter-language-code}` will translate the subtitles into specified language using [Opus-MT](https://github.com/Helsinki-NLP/Opus-MT):

faster_auto_subtitle /path/to/video.mp4 --target_language fr

This will require downloading the appropriate model. If direct translation is not available it will attempt translation from source to english and from english to source.

Run the following to view all available options:

faster_auto_subtitle --help
Expand All @@ -49,7 +55,7 @@ Higher `beam_size` usually leads to greater accuracy, but slows down the process

Setting higher `no_speech_threshold` could be useful for videos with a lot of background noise to stop Whisper from "hallucinating" subtitles for it.

In my experience settings option `condition_on_previous_text` to `False` dramatically increases accurracy for videos like TV Shows with an intro song at the start.
In my experience settings option `condition_on_previous_text` to `False` dramatically increases accurracy for videos like TV Shows with an intro song at the start.

You can use `sample_interval` parameter to generate subtitles for a portion of the video to play around with those parameters:

Expand Down
7 changes: 6 additions & 1 deletion auto_subtitle/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,16 @@ def main():
parser.add_argument("--task", type=str, default="transcribe",
choices=["transcribe", "translate"],
help="whether to perform X->X speech recognition ('transcribe') \
or X->English translation ('translate')")
or X->Language translation ('translate')")
parser.add_argument("--language", type=str, default="auto",
choices=LANGUAGE_CODES,
help="What is the origin language of the video? \
If unset, it is detected automatically.")
parser.add_argument("--target_language", type=str, default="en",
choices=LANGUAGE_CODES,
help="Desired language to translate subtitles to. \
If language is not en, Opus-MT will be used. \
See https://github.com/Helsinki-NLP/Opus-MT.")

args = parser.parse_args().__dict__

Expand Down
78 changes: 60 additions & 18 deletions auto_subtitle/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
import warnings
import tempfile
from .utils.files import filename, write_srt
from .utils.ffmpeg import get_audio, overlay_subtitles
from .utils.whisper import WhisperAI
from .translation.easynmt_utils import EasyNMTWrapper


def process(args: dict):
Expand All @@ -12,52 +12,94 @@ def process(args: dict):
output_srt: bool = args.pop("output_srt")
srt_only: bool = args.pop("srt_only")
language: str = args.pop("language")
sample_interval: str = args.pop("sample_interval")
sample_interval: list = args.pop("sample_interval")
target_language: str = args.pop("target_language")

os.makedirs(output_dir, exist_ok=True)

if model_name.endswith(".en"):
warnings.warn(
f"{model_name} is an English-only model, forcing English detection.")
args["language"] = "en"
language = "en"
# if translate task used and language argument is set, then use it
elif language != "auto":
args["language"] = language

if target_language != 'en':
warnings.warn(
f"{target_language} is not English, Opus-MT will be used to perform translation.")
args['task'] = 'transcribe'

audios = get_audio(args.pop("video"), args.pop(
'audio_channel'), sample_interval)

model_args = {}
model_args["model_size_or_path"] = model_name
model_args["device"] = args.pop("device")
model_args["compute_type"] = args.pop("compute_type")
model_args = {
"model_size_or_path": model_name,
"device": args.pop("device"),
"compute_type": args.pop("compute_type")
}

subtitles = get_subtitles(audios, model_args, args)
print('Subtitles generated.')

if target_language != 'en':
print('Translating subtitles... This might take a while.')
subtitles = translate_subtitles(
subtitles, language, target_language, model_args)

srt_output_dir = output_dir if output_srt or srt_only else tempfile.gettempdir()
subtitles = get_subtitles(audios, srt_output_dir, model_args, args)
if output_srt or srt_only:
print('Saving subtitle files...')
save_subtitles(subtitles, output_dir)

if srt_only:
return

overlay_subtitles(subtitles, output_dir, sample_interval)


def get_subtitles(audio_paths: list, output_dir: str,
model_args: dict, transcribe_args: dict):
def translate_subtitles(subtitles: dict, source_lang: str, target_lang: str, model_args: dict):
model = EasyNMTWrapper(device=model_args['device'])

translated_subtitles = {}
for key, subtitle in subtitles.items():
src_lang = source_lang
if src_lang == '' or src_lang is None:
src_lang = subtitle['language']

translated_segments = model.translate(
subtitle['segments'], src_lang, target_lang)

translated_subtitle = subtitle.copy()
translated_subtitle['segments'] = translated_segments
translated_subtitles[key] = translated_subtitle

return translated_subtitles


def save_subtitles(subtitles: dict, output_dir: str):
for path, subtitle in subtitles.items():
subtitle["output_path"] = os.path.join(
output_dir, f"{filename(path)}.srt")

print(f'Saving to path {subtitle["output_path"]}')
with open(subtitle['output_path'], "w", encoding="utf-8") as srt:
write_srt(subtitle['segments'], file=srt)


def get_subtitles(audio_paths: dict, model_args: dict, transcribe_args: dict):
model = WhisperAI(model_args, transcribe_args)

subtitles_path = {}
subtitles = {}

for path, audio_path in audio_paths.items():
print(
f"Generating subtitles for {filename(path)}... This might take a while."
)
srt_path = os.path.join(output_dir, f"{filename(path)}.srt")

segments = model.transcribe(audio_path)

with open(srt_path, "w", encoding="utf-8") as srt:
write_srt(segments, file=srt)
segments, info = model.transcribe(audio_path)

subtitles_path[path] = srt_path
subtitles[path] = {'segments': list(
segments), 'language': info.language}

return subtitles_path
return subtitles
Empty file.
24 changes: 24 additions & 0 deletions auto_subtitle/translation/easynmt_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from easynmt import EasyNMT
from faster_whisper.transcribe import Segment
from .opusmt_utils import OpusMT


class EasyNMTWrapper:
def __init__(self, device):
self.translator = OpusMT()
self.model = EasyNMT('opus-mt',
translator=self.translator,
device=device if device != 'auto' else None)

def translate(self, segments: list[Segment], source_lang: str, target_lang: str):
source_text = [segment.text for segment in segments]
self.translator.load_available_models()

translated_text = self.model.translate(source_text, target_lang,
source_lang, show_progress_bar=True)
translated_segments = [None] * len(segments)
for index, segment in enumerate(segments):
translated_segments[index] = segment._replace(
text=translated_text[index])

return translated_segments
20 changes: 20 additions & 0 deletions auto_subtitle/translation/languages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import langcodes
from transformers.models.marian.convert_marian_tatoeba_to_pytorch import GROUP_MEMBERS


def to_alpha2_languages(languages):
return set(item for sublist in [__to_alpha2_language(language) for language in languages] for item in sublist)


def __to_alpha2_language(language):
if len(language) == 2:
return [language]

if language in GROUP_MEMBERS:
return set([langcodes.Language.get(x).language for x in GROUP_MEMBERS[language][1]])

return [langcodes.Language.get(language).language]


def to_alpha3_language(language):
return langcodes.Language.get(language).to_alpha3()
149 changes: 149 additions & 0 deletions auto_subtitle/translation/opusmt_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import time
import logging
from typing import List
import torch
from huggingface_hub import list_models, ModelFilter
from transformers import MarianMTModel, MarianTokenizer
from .languages import to_alpha2_languages, to_alpha3_language

logger = logging.getLogger(__name__)

NLP_ROOT = 'Helsinki-NLP'


class OpusMT:
def __init__(self, max_loaded_models: int = 10):
self.models = {}
self.max_loaded_models = max_loaded_models
self.max_length = None

self.available_models = None
self.translations_graph = None

def load_model(self, model_name):
if model_name in self.models:
self.models[model_name]['last_loaded'] = time.time()
return self.models[model_name]['tokenizer'], self.models[model_name]['model']

logger.info("Load model: {}" % model_name)
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)
model.eval()

if len(self.models) >= self.max_loaded_models:
oldest_time = time.time()
oldest_model = None
for loaded_model_name in self.models.keys():
if self.models[loaded_model_name]['last_loaded'] <= oldest_time:
oldest_model = loaded_model_name
oldest_time = self.models[loaded_model_name]['last_loaded']
del self.models[oldest_model]

self.models[model_name] = {
'tokenizer': tokenizer, 'model': model, 'last_loaded': time.time()}
return tokenizer, model

def load_available_models(self):
if self.available_models is not None:
return

print('Loading a list of available language models from OPUS-NT')
model_list = list_models(
filter=ModelFilter(
author=NLP_ROOT
)
)

suffix = [x.modelId.split("/")[1] for x in model_list
if x.modelId.startswith(f'{NLP_ROOT}/opus-mt') and 'tc' not in x.modelId]

models = [DownloadableModel(f"{NLP_ROOT}/{s}")
for s in suffix if s == s.lower()]

self.available_models = {}
for model in models:
for src in model.source_languages:
for tgt in model.target_languages:
key = f'{src}-{tgt}'
if key not in self.available_models:
self.available_models[key] = model
elif self.available_models[key].language_count > model.language_count:
self.available_models[key] = model

def determine_required_translations(self, source_lang, target_lang):
direct_key = f'{source_lang}-{target_lang}'
if direct_key in self.available_models:
print(
f'Found direct translation from {source_lang} to {target_lang}.')
return [(source_lang, target_lang, direct_key)]

print(
f'No direct translation from {source_lang} to {target_lang}. Trying to translate through en.')

to_en_key = f'{source_lang}-en'
if to_en_key not in self.available_models:
print(f'No translation from {source_lang} to en.')
return []

from_en_key = f'en-{target_lang}'
if from_en_key not in self.available_models:
print(f'No translation from en to {target_lang}.')
return []

return [(source_lang, 'en', to_en_key), ('en', target_lang, from_en_key)]

def translate_sentences(self, sentences: List[str], source_lang: str, target_lang: str, device: str, beam_size: int = 5, **kwargs):
self.load_available_models()

translations = self.determine_required_translations(
source_lang, target_lang)

if len(translations) == 0:
return sentences

intermediate = sentences
for _, tgt_lang, key in translations:
model_data = self.available_models[key]
model_name = model_data.name
tokenizer, model = self.load_model(model_name)
model.to(device)

if model_data.multilanguage:
alpha3 = to_alpha3_language(tgt_lang)
prefix = next(
x for x in tokenizer.supported_language_codes if alpha3 in x)
intermediate = [f'{prefix} {x}' for x in intermediate]

inputs = tokenizer(intermediate, truncation=True, padding=True,
max_length=self.max_length, return_tensors="pt")

for key in inputs:
inputs[key] = inputs[key].to(device)

with torch.no_grad():
translated = model.generate(
**inputs, num_beams=beam_size, **kwargs)
intermediate = [tokenizer.decode(
t, skip_special_tokens=True) for t in translated]

return intermediate


class DownloadableModel:
def __init__(self, name):
self.name = name
source_languages, target_languages = self.parse_languages(name)
self.source_languages = source_languages
self.target_languages = target_languages
self.multilanguage = len(self.target_languages) > 1
self.language_count = len(
self.source_languages) + len(self.target_languages)

@staticmethod
def parse_languages(name):
parts = name.split('-')
if len(parts) > 5:
return set(), set()

src, tgt = parts[3], parts[4]
return to_alpha2_languages(src.split('_')), to_alpha2_languages(tgt.split('_'))
Loading

0 comments on commit e76027c

Please sign in to comment.