Skip to content

Commit 2e7dc23

Browse files
committed
Update cuda in pythainlp.translate
1 parent f2aaa7e commit 2e7dc23

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

pythainlp/translate/en_th.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def __init__(self, use_gpu: bool = False):
8787
),
8888
)
8989
if use_gpu:
90-
self._model.cuda()
90+
self._model = self._model.cuda()
9191

9292
def translate(self, text: str) -> str:
9393
"""

pythainlp/translate/th_fr.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(
4949
self.tokenizer_thzh = AutoTokenizer.from_pretrained(pretrained)
5050
self.model_thzh = AutoModelForSeq2SeqLM.from_pretrained(pretrained)
5151
if use_gpu:
52-
self.model_thzh.cuda()
52+
self.model_thzh = self.model_thzh.cuda()
5353

5454
def translate(self, text: str) -> str:
5555
"""

pythainlp/translate/zh_th.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(
4343
self.tokenizer_thzh = AutoTokenizer.from_pretrained(pretrained)
4444
self.model_thzh = AutoModelForSeq2SeqLM.from_pretrained(pretrained)
4545
if use_gpu:
46-
self.model_thzh.cuda()
46+
self.model_thzh = self.model_thzh.cuda()
4747

4848
def translate(self, text: str) -> str:
4949
"""

0 commit comments

Comments
 (0)