Skip to content

Commit a1cce77

Browse files
committed
Add small100 to pythainlp.translate.Translate
1 parent c5f7c9d commit a1cce77

File tree

2 files changed

+38
-6
lines changed

2 files changed

+38
-6
lines changed

pythainlp/translate/core.py

+25-5
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,26 @@ class Translate:
4141
"""
4242

4343
def __init__(
44-
self, src_lang: str, target_lang: str, use_gpu: bool = False
44+
self, src_lang: str, target_lang: str, engine: str="default", use_gpu: bool = False
4545
) -> None:
4646
"""
4747
:param str src_lang: source language
4848
:param str target_lang: target language
49+
:param str engine: Machine Translation engine
4950
:param bool use_gpu: load model to gpu (Default is False)
5051
52+
**Options for engine**
53+
* *default* - The default engine for each language.
54+
* *small100* - A multilingual machine translation model (covering 100 languages)
55+
5156
**Options for source & target language**
5257
* *th* - *en* - Thai to English
5358
* *en* - *th* - English to Thai
5459
* *th* - *zh* - Thai to Chinese
5560
* *zh* - *th* - Chinese to Thai
5661
* *th* - *fr* - Thai to French
62+
* *th* - *xx* - Thai to xx (xx is language code). It uses small100 model.
63+
* *xx* - *th* - xx to Thai (xx is language code). It uses small100 model.
5764
5865
:Example:
5966
@@ -66,10 +73,21 @@ def __init__(
6673
# output: I love cat.
6774
"""
6875
self.model = None
69-
self.load_model(src_lang, target_lang, use_gpu)
70-
71-
def load_model(self, src_lang: str, target_lang: str, use_gpu: bool):
72-
if src_lang == "th" and target_lang == "en":
76+
self.engine = engine
77+
self.src_lang = src_lang
78+
self.use_gpu = use_gpu
79+
self.target_lang = target_lang
80+
self.load_model()
81+
82+
def load_model(self):
83+
src_lang = self.src_lang
84+
target_lang = self.target_lang
85+
use_gpu = self.use_gpu
86+
if self.engine == "small100":
87+
from .small100 import Small100Translator
88+
89+
self.model = Small100Translator(use_gpu)
90+
elif src_lang == "th" and target_lang == "en":
7391
from pythainlp.translate.en_th import ThEnTranslator
7492

7593
self.model = ThEnTranslator(use_gpu)
@@ -100,4 +118,6 @@ def translate(self, text) -> str:
100118
:return: translated text in target language
101119
:rtype: str
102120
"""
121+
if self.engine == "small100":
122+
return self.model.translate(text, tgt_lang=self.target_lang)
103123
return self.model.translate(text)

tests/test_translate.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -73,5 +73,17 @@ def test_translate(self):
7373
"ทดสอบระบบ",
7474
)
7575
)
76+
self.th_fr_translator = Translate('th', 'fr', engine="small100")
77+
self.assertIsNotNone(
78+
self.th_fr_translator.translate(
79+
"ทดสอบระบบ",
80+
)
81+
)
82+
self.th_fr_translator = Translate('th', 'ja')
83+
self.assertIsNotNone(
84+
self.th_fr_translator.translate(
85+
"ทดสอบระบบ",
86+
)
87+
)
7688
with self.assertRaises(ValueError):
77-
self.th_cat_translator = Translate('th', 'cat')
89+
self.th_cat_translator = Translate('th', 'cat', engine="fkfj")

0 commit comments

Comments
 (0)