diff --git a/pythainlp/tag/pos_tag.py b/pythainlp/tag/pos_tag.py
index abdfe5fc2..ee2a2b478 100644
--- a/pythainlp/tag/pos_tag.py
+++ b/pythainlp/tag/pos_tag.py
@@ -180,13 +180,36 @@ def pos_tag_sents(
def pos_tag_transformers(
- words: str, engine: str = "bert-base-th-cased-blackboard"
-):
+ sentence: str,
+ engine: str = "bert",
+ corpus: str = "blackboard",
+) -> List[List[Tuple[str, str]]]:
"""
- "wangchanberta-ud-thai-pud-upos",
- "mdeberta-v3-ud-thai-pud-upos",
- "bert-base-th-cased-blackboard",
+    Marks words in the given sentence with part-of-speech (POS) tags.
+
+    :param str sentence: the sentence to be tagged
+ :param str engine:
+ * *bert* - BERT: Bidirectional Encoder Representations from Transformers (default)
+    * *wangchanberta* - a fine-tuned version of airesearch/wangchanberta-base-att-spm-uncased (supports the PUD corpus only)
+    * *mdeberta* - mDeBERTa: Multilingual Decoding-enhanced BERT with disentangled attention (supports the PUD corpus only)
+    :param str corpus: the corpus used to train the tagging model
+    * *blackboard* - Blackboard treebank (supports the bert engine only)
+    * *pud* - Parallel Universal Dependencies (PUD) treebanks, which natively use Universal POS tags (supports the wangchanberta and mdeberta engines)
+ :return: a list of lists of tuples (word, POS tag)
+ :rtype: list[list[tuple[str, str]]]
+ :Example:
+
+    Labels POS for a given sentence::
+
+ from pythainlp.tag import pos_tag_transformers
+
+    sentence = "แมวทำอะไรตอนห้าโมงเช้า"
+    pos_tag_transformers(sentence, engine="bert", corpus="blackboard")
+ # output:
+ # [[('แมว', 'NOUN'), ('ทําอะไร', 'VERB'), ('ตอนห้าโมงเช้า', 'NOUN')]]
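+
+    # The PUD-trained engines are called the same way; shown here only as
+    # a sketch, since the exact tags depend on the selected model:
+    pos_tag_transformers("แมวทำอะไรตอนห้าโมงเช้า", engine="wangchanberta", corpus="pud")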
"""
try:
@@ -196,28 +219,35 @@ def pos_tag_transformers(
raise ImportError(
"Not found transformers! Please install transformers by pip install transformers")
- if not words:
+ if not sentence:
return []
- if engine == "wangchanberta-ud-thai-pud-upos":
- model = AutoModelForTokenClassification.from_pretrained(
- "Pavarissy/wangchanberta-ud-thai-pud-upos")
- tokenizer = AutoTokenizer.from_pretrained("Pavarissy/wangchanberta-ud-thai-pud-upos")
- elif engine == "mdeberta-v3-ud-thai-pud-upos":
- model = AutoModelForTokenClassification.from_pretrained(
- "Pavarissy/mdeberta-v3-ud-thai-pud-upos")
- tokenizer = AutoTokenizer.from_pretrained("Pavarissy/mdeberta-v3-ud-thai-pud-upos")
- elif engine == "bert-base-th-cased-blackboard":
- model = AutoModelForTokenClassification.from_pretrained("lunarlist/pos_thai")
- tokenizer = AutoTokenizer.from_pretrained("lunarlist/pos_thai")
+    _blackboard_support_engine = {
+        "bert": "lunarlist/pos_thai",
+    }
+
+    _pud_support_engine = {
+        "wangchanberta": "Pavarissy/wangchanberta-ud-thai-pud-upos",
+        "mdeberta": "Pavarissy/mdeberta-v3-ud-thai-pud-upos",
+    }
+
+    if corpus == "blackboard" and engine in _blackboard_support_engine:
+        base_model = _blackboard_support_engine.get(engine)
+        model = AutoModelForTokenClassification.from_pretrained(base_model)
+        tokenizer = AutoTokenizer.from_pretrained(base_model)
+    elif corpus == "pud" and engine in _pud_support_engine:
+        base_model = _pud_support_engine.get(engine)
+        model = AutoModelForTokenClassification.from_pretrained(base_model)
+        tokenizer = AutoTokenizer.from_pretrained(base_model)
else:
raise ValueError(
- "pos_tag_transformers not support {0} engine.".format(
- engine
+ "pos_tag_transformers not support {0} engine or {1} corpus.".format(
+ engine, corpus
)
)
- pipeline = TokenClassificationPipeline(model=model, tokenizer=tokenizer, grouped_entities=True)
+ pipeline = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="simple")
- outputs = pipeline(words)
- return outputs
\ No newline at end of file
+ outputs = pipeline(sentence)
+    word_tags = [[(tag["word"], tag["entity_group"]) for tag in outputs]]
+ return word_tags
\ No newline at end of file
diff --git a/tests/test_tag.py b/tests/test_tag.py
index 8d1755b18..b5529ec5b 100644
--- a/tests/test_tag.py
+++ b/tests/test_tag.py
@@ -367,10 +367,13 @@ def test_NNER_class(self):
def test_pos_tag_transformers(self):
self.assertIsNotNone(pos_tag_transformers(
- words="แมวทำอะไรตอนห้าโมงเช้า", engine="bert-base-th-cased-blackboard"))
+ words="แมวทำอะไรตอนห้าโมงเช้า", engine="bert", corpus="blackboard"))
self.assertIsNotNone(pos_tag_transformers(
- words="แมวทำอะไรตอนห้าโมงเช้า", engine="mdeberta-v3-ud-thai-pud-upos"))
+ words="แมวทำอะไรตอนห้าโมงเช้า", engine="mdeberta", corpus="pud"))
self.assertIsNotNone(pos_tag_transformers(
- words="แมวทำอะไรตอนห้าโมงเช้า", engine="wangchanberta-ud-thai-pud-upos"))
+ words="แมวทำอะไรตอนห้าโมงเช้า", engine="wangchanberta", corpus="pud"))
with self.assertRaises(ValueError):
- pos_tag_transformers(words="แมวทำอะไรตอนห้าโมงเช้า", engine="non-existing-engine")
\ No newline at end of file
+            pos_tag_transformers(sentence="แมวทำอะไรตอนห้าโมงเช้า", engine="non-existing-engine")
+ with self.assertRaises(ValueError):
+            pos_tag_transformers(sentence="แมวทำอะไรตอนห้าโมงเช้า", engine="bert",
+ corpus="non-existing corpus")
\ No newline at end of file