Skip to content

Commit

Permalink
Merge pull request #815 from PyThaiNLP/add-small100
Browse files Browse the repository at this point in the history
Add small100 to pythainlp.translate
  • Loading branch information
wannaphong authored Jul 14, 2023
2 parents bf74de6 + b6212d9 commit e4f9c9b
Show file tree
Hide file tree
Showing 4 changed files with 460 additions and 6 deletions.
30 changes: 25 additions & 5 deletions pythainlp/translate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,26 @@ class Translate:
"""

def __init__(
self, src_lang: str, target_lang: str, use_gpu: bool = False
self, src_lang: str, target_lang: str, engine: str="default", use_gpu: bool = False
) -> None:
"""
:param str src_lang: source language
:param str target_lang: target language
:param str engine: Machine Translation engine
:param bool use_gpu: load model to gpu (Default is False)
**Options for engine*
* *default* - The engine default by each a language.
* *small100* - A multilingual machine translation model (covering 100 languages)
**Options for source & target language**
* *th* - *en* - Thai to English
* *en* - *th* - English to Thai
* *th* - *zh* - Thai to Chinese
* *zh* - *th* - Chinese to Thai
* *th* - *fr* - Thai to French
* *th* - *xx* - Thai to xx (xx is language code). It uses small100 model.
* *xx* - *th* - xx to Thai (xx is language code). It uses small100 model.
:Example:
Expand All @@ -66,10 +73,21 @@ def __init__(
# output: I love cat.
"""
self.model = None
self.load_model(src_lang, target_lang, use_gpu)

def load_model(self, src_lang: str, target_lang: str, use_gpu: bool):
if src_lang == "th" and target_lang == "en":
self.engine = engine
self.src_lang = src_lang
self.use_gpu = use_gpu
self.target_lang = target_lang
self.load_model()

def load_model(self):
src_lang = self.src_lang
target_lang = self.target_lang
use_gpu = self.use_gpu
if self.engine == "small100":
from .small100 import Small100Translator

self.model = Small100Translator(use_gpu)
elif src_lang == "th" and target_lang == "en":
from pythainlp.translate.en_th import ThEnTranslator

self.model = ThEnTranslator(use_gpu)
Expand Down Expand Up @@ -100,4 +118,6 @@ def translate(self, text) -> str:
:return: translated text in target language
:rtype: str
"""
if self.engine == "small100":
return self.model.translate(text, tgt_lang=self.target_lang)
return self.model.translate(text)
60 changes: 60 additions & 0 deletions pythainlp/translate/small100.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from transformers import M2M100ForConditionalGeneration
from .tokenization_small100 import SMALL100Tokenizer

class Small100Translator:
"""
Machine Translation with small100 model
- Huggingface https://huggingface.co/alirezamsh/small100
:param bool use_gpu : load model to gpu (Default is False)
"""

def __init__(
self,
use_gpu: bool = False,
pretrained: str = "alirezamsh/small100",
) -> None:
self.pretrained = pretrained
self.model = M2M100ForConditionalGeneration.from_pretrained(self.pretrained)
self.tgt_lang = None
if use_gpu:
self.model = self.model.cuda()

def translate(self, text: str, tgt_lang: str="en") -> str:
"""
Translate text from X to X
:param str text: input text in source language
:param str tgt_lang: target language
:return: translated text in target language
:rtype: str
:Example:
::
from pythainlp.translate.small100 import Small100Translator
mt = Small100Translator()
# Translate text from Thai to English
mt.translate("ทดสอบระบบ", tgt_lang="en")
# output: 'Testing system'
# Translate text from Thai to Chinese
mt.translate("ทดสอบระบบ", tgt_lang="zh")
# output: '系统测试'
# Translate text from Thai to French
mt.translate("ทดสอบระบบ", tgt_lang="fr")
# output: 'Test du système'
"""
if tgt_lang!=self.tgt_lang:
self.tokenizer = SMALL100Tokenizer.from_pretrained(self.pretrained, tgt_lang=tgt_lang)
self.tgt_lang = tgt_lang
self.translated = self.model.generate(
**self.tokenizer(text, return_tensors="pt")
)
return self.tokenizer.batch_decode(self.translated, skip_special_tokens=True)[0]
Loading

0 comments on commit e4f9c9b

Please sign in to comment.