-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add subtitles translation using EasyNMT and OpusMT libraries
- Loading branch information
1 parent
1c0cdb6
commit e76027c
Showing
15 changed files
with
321 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
name: Setup | ||
|
||
on: [push] | ||
|
||
jobs: | ||
build: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
matrix: | ||
python-version: ["3.9"] | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Install application | ||
run: | | ||
pip install wheel | ||
pip install -e . | ||
- name: Check that package was installed successfully | ||
run: | | ||
faster_auto_subtitle -h |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from easynmt import EasyNMT | ||
from faster_whisper.transcribe import Segment | ||
from .opusmt_utils import OpusMT | ||
|
||
|
||
class EasyNMTWrapper: | ||
def __init__(self, device): | ||
self.translator = OpusMT() | ||
self.model = EasyNMT('opus-mt', | ||
translator=self.translator, | ||
device=device if device != 'auto' else None) | ||
|
||
def translate(self, segments: list[Segment], source_lang: str, target_lang: str): | ||
source_text = [segment.text for segment in segments] | ||
self.translator.load_available_models() | ||
|
||
translated_text = self.model.translate(source_text, target_lang, | ||
source_lang, show_progress_bar=True) | ||
translated_segments = [None] * len(segments) | ||
for index, segment in enumerate(segments): | ||
translated_segments[index] = segment._replace( | ||
text=translated_text[index]) | ||
|
||
return translated_segments |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import langcodes | ||
from transformers.models.marian.convert_marian_tatoeba_to_pytorch import GROUP_MEMBERS | ||
|
||
|
||
def to_alpha2_languages(languages): | ||
return set(item for sublist in [__to_alpha2_language(language) for language in languages] for item in sublist) | ||
|
||
|
||
def __to_alpha2_language(language): | ||
if len(language) == 2: | ||
return [language] | ||
|
||
if language in GROUP_MEMBERS: | ||
return set([langcodes.Language.get(x).language for x in GROUP_MEMBERS[language][1]]) | ||
|
||
return [langcodes.Language.get(language).language] | ||
|
||
|
||
def to_alpha3_language(language): | ||
return langcodes.Language.get(language).to_alpha3() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
import time | ||
import logging | ||
from typing import List | ||
import torch | ||
from huggingface_hub import list_models, ModelFilter | ||
from transformers import MarianMTModel, MarianTokenizer | ||
from .languages import to_alpha2_languages, to_alpha3_language | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
NLP_ROOT = 'Helsinki-NLP' | ||
|
||
|
||
class OpusMT: | ||
def __init__(self, max_loaded_models: int = 10): | ||
self.models = {} | ||
self.max_loaded_models = max_loaded_models | ||
self.max_length = None | ||
|
||
self.available_models = None | ||
self.translations_graph = None | ||
|
||
def load_model(self, model_name): | ||
if model_name in self.models: | ||
self.models[model_name]['last_loaded'] = time.time() | ||
return self.models[model_name]['tokenizer'], self.models[model_name]['model'] | ||
|
||
logger.info("Load model: {}" % model_name) | ||
tokenizer = MarianTokenizer.from_pretrained(model_name) | ||
model = MarianMTModel.from_pretrained(model_name) | ||
model.eval() | ||
|
||
if len(self.models) >= self.max_loaded_models: | ||
oldest_time = time.time() | ||
oldest_model = None | ||
for loaded_model_name in self.models.keys(): | ||
if self.models[loaded_model_name]['last_loaded'] <= oldest_time: | ||
oldest_model = loaded_model_name | ||
oldest_time = self.models[loaded_model_name]['last_loaded'] | ||
del self.models[oldest_model] | ||
|
||
self.models[model_name] = { | ||
'tokenizer': tokenizer, 'model': model, 'last_loaded': time.time()} | ||
return tokenizer, model | ||
|
||
def load_available_models(self): | ||
if self.available_models is not None: | ||
return | ||
|
||
print('Loading a list of available language models from OPUS-NT') | ||
model_list = list_models( | ||
filter=ModelFilter( | ||
author=NLP_ROOT | ||
) | ||
) | ||
|
||
suffix = [x.modelId.split("/")[1] for x in model_list | ||
if x.modelId.startswith(f'{NLP_ROOT}/opus-mt') and 'tc' not in x.modelId] | ||
|
||
models = [DownloadableModel(f"{NLP_ROOT}/{s}") | ||
for s in suffix if s == s.lower()] | ||
|
||
self.available_models = {} | ||
for model in models: | ||
for src in model.source_languages: | ||
for tgt in model.target_languages: | ||
key = f'{src}-{tgt}' | ||
if key not in self.available_models: | ||
self.available_models[key] = model | ||
elif self.available_models[key].language_count > model.language_count: | ||
self.available_models[key] = model | ||
|
||
def determine_required_translations(self, source_lang, target_lang): | ||
direct_key = f'{source_lang}-{target_lang}' | ||
if direct_key in self.available_models: | ||
print( | ||
f'Found direct translation from {source_lang} to {target_lang}.') | ||
return [(source_lang, target_lang, direct_key)] | ||
|
||
print( | ||
f'No direct translation from {source_lang} to {target_lang}. Trying to translate through en.') | ||
|
||
to_en_key = f'{source_lang}-en' | ||
if to_en_key not in self.available_models: | ||
print(f'No translation from {source_lang} to en.') | ||
return [] | ||
|
||
from_en_key = f'en-{target_lang}' | ||
if from_en_key not in self.available_models: | ||
print(f'No translation from en to {target_lang}.') | ||
return [] | ||
|
||
return [(source_lang, 'en', to_en_key), ('en', target_lang, from_en_key)] | ||
|
||
def translate_sentences(self, sentences: List[str], source_lang: str, target_lang: str, device: str, beam_size: int = 5, **kwargs): | ||
self.load_available_models() | ||
|
||
translations = self.determine_required_translations( | ||
source_lang, target_lang) | ||
|
||
if len(translations) == 0: | ||
return sentences | ||
|
||
intermediate = sentences | ||
for _, tgt_lang, key in translations: | ||
model_data = self.available_models[key] | ||
model_name = model_data.name | ||
tokenizer, model = self.load_model(model_name) | ||
model.to(device) | ||
|
||
if model_data.multilanguage: | ||
alpha3 = to_alpha3_language(tgt_lang) | ||
prefix = next( | ||
x for x in tokenizer.supported_language_codes if alpha3 in x) | ||
intermediate = [f'{prefix} {x}' for x in intermediate] | ||
|
||
inputs = tokenizer(intermediate, truncation=True, padding=True, | ||
max_length=self.max_length, return_tensors="pt") | ||
|
||
for key in inputs: | ||
inputs[key] = inputs[key].to(device) | ||
|
||
with torch.no_grad(): | ||
translated = model.generate( | ||
**inputs, num_beams=beam_size, **kwargs) | ||
intermediate = [tokenizer.decode( | ||
t, skip_special_tokens=True) for t in translated] | ||
|
||
return intermediate | ||
|
||
|
||
class DownloadableModel: | ||
def __init__(self, name): | ||
self.name = name | ||
source_languages, target_languages = self.parse_languages(name) | ||
self.source_languages = source_languages | ||
self.target_languages = target_languages | ||
self.multilanguage = len(self.target_languages) > 1 | ||
self.language_count = len( | ||
self.source_languages) + len(self.target_languages) | ||
|
||
@staticmethod | ||
def parse_languages(name): | ||
parts = name.split('-') | ||
if len(parts) > 5: | ||
return set(), set() | ||
|
||
src, tgt = parts[3], parts[4] | ||
return to_alpha2_languages(src.split('_')), to_alpha2_languages(tgt.split('_')) |
Oops, something went wrong.