Skip to content

Commit abfbf02

Browse files
authored
Merge pull request #865 from pavaris-pm/update-pos-tag-transformers
Update `pos_tag_transformers` function
2 parents a319d08 + 5574ce3 commit abfbf02

File tree

2 files changed

+59
-26
lines changed

2 files changed

+59
-26
lines changed

pythainlp/tag/pos_tag.py

+52-22
Original file line numberDiff line numberDiff line change
@@ -180,13 +180,36 @@ def pos_tag_sents(
180180

181181

182182
def pos_tag_transformers(
183-
words: str, engine: str = "bert-base-th-cased-blackboard"
184-
):
183+
sentence: str,
184+
engine: str = "bert",
185+
corpus: str = "blackboard",
186+
)->List[List[Tuple[str, str]]]:
185187
"""
186-
"wangchanberta-ud-thai-pud-upos",
187-
"mdeberta-v3-ud-thai-pud-upos",
188-
"bert-base-th-cased-blackboard",
188+
Marks sentences with part-of-speech (POS) tags.
189+
190+
:param str sentence: the sentence (a string) to be tagged
191+
:param str engine:
192+
* *bert* - BERT: Bidirectional Encoder Representations from Transformers (default)
193+
* *wangchanberta* - fine-tuned version of airesearch/wangchanberta-base-att-spm-uncased on pud corpus (supports PUD corpus only)
194+
* *mdeberta* - mDeBERTa: Multilingual Decoding-enhanced BERT with disentangled attention (supports PUD corpus only)
195+
:param str corpus: the corpus that is used to create the language model for tagger
196+
* *blackboard* - `blackboard treebank (supports bert engine only) <https://bitbucket.org/kaamanita/blackboard-treebank/src/master/>`_
197+
* *pud* - `Parallel Universal Dependencies (PUD)\
198+
<https://github.com/UniversalDependencies/UD_Thai-PUD>`_ \
199+
treebanks, natively use Universal POS tags (supports wangchanberta and mdeberta engines)
200+
:return: a list of lists of tuples (word, POS tag)
201+
:rtype: list[list[tuple[str, str]]]
189202
203+
:Example:
204+
205+
Labels POS for given sentence::
206+
207+
from pythainlp.tag import pos_tag_transformers
208+
209+
sentences = "แมวทำอะไรตอนห้าโมงเช้า"
210+
pos_tag_transformers(sentences, engine="bert", corpus='blackboard')
211+
# output:
212+
# [[('แมว', 'NOUN'), ('ทําอะไร', 'VERB'), ('ตอนห้าโมงเช้า', 'NOUN')]]
190213
"""
191214

192215
try:
@@ -196,28 +219,35 @@ def pos_tag_transformers(
196219
raise ImportError(
197220
"Not found transformers! Please install transformers by pip install transformers")
198221

199-
if not words:
222+
if not sentence:
200223
return []
201224

202-
if engine == "wangchanberta-ud-thai-pud-upos":
203-
model = AutoModelForTokenClassification.from_pretrained(
204-
"Pavarissy/wangchanberta-ud-thai-pud-upos")
205-
tokenizer = AutoTokenizer.from_pretrained("Pavarissy/wangchanberta-ud-thai-pud-upos")
206-
elif engine == "mdeberta-v3-ud-thai-pud-upos":
207-
model = AutoModelForTokenClassification.from_pretrained(
208-
"Pavarissy/mdeberta-v3-ud-thai-pud-upos")
209-
tokenizer = AutoTokenizer.from_pretrained("Pavarissy/mdeberta-v3-ud-thai-pud-upos")
210-
elif engine == "bert-base-th-cased-blackboard":
211-
model = AutoModelForTokenClassification.from_pretrained("lunarlist/pos_thai")
212-
tokenizer = AutoTokenizer.from_pretrained("lunarlist/pos_thai")
225+
_blackboard_support_engine = {
226+
"bert" : "lunarlist/pos_thai",
227+
}
228+
229+
_pud_support_engine = {
230+
"wangchanberta" : "Pavarissy/wangchanberta-ud-thai-pud-upos",
231+
"mdeberta" : "Pavarissy/mdeberta-v3-ud-thai-pud-upos",
232+
}
233+
234+
if corpus == 'blackboard' and engine in _blackboard_support_engine.keys():
235+
base_model = _blackboard_support_engine.get(engine)
236+
model = AutoModelForTokenClassification.from_pretrained(base_model)
237+
tokenizer = AutoTokenizer.from_pretrained(base_model)
238+
elif corpus == 'pud' and engine in _pud_support_engine.keys():
239+
base_model = _pud_support_engine.get(engine)
240+
model = AutoModelForTokenClassification.from_pretrained(base_model)
241+
tokenizer = AutoTokenizer.from_pretrained(base_model)
213242
else:
214243
raise ValueError(
215-
"pos_tag_transformers not support {0} engine.".format(
216-
engine
244+
"pos_tag_transformers not support {0} engine or {1} corpus.".format(
245+
engine, corpus
217246
)
218247
)
219248

220-
pipeline = TokenClassificationPipeline(model=model, tokenizer=tokenizer, grouped_entities=True)
249+
pipeline = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="simple")
221250

222-
outputs = pipeline(words)
223-
return outputs
251+
outputs = pipeline(sentence)
252+
word_tags = [[(tag['word'], tag['entity_group']) for tag in outputs]]
253+
return word_tags

tests/test_tag.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -367,10 +367,13 @@ def test_NNER_class(self):
367367

368368
def test_pos_tag_transformers(self):
369369
self.assertIsNotNone(pos_tag_transformers(
370-
words="แมวทำอะไรตอนห้าโมงเช้า", engine="bert-base-th-cased-blackboard"))
370+
words="แมวทำอะไรตอนห้าโมงเช้า", engine="bert", corpus="blackboard"))
371371
self.assertIsNotNone(pos_tag_transformers(
372-
words="แมวทำอะไรตอนห้าโมงเช้า", engine="mdeberta-v3-ud-thai-pud-upos"))
372+
words="แมวทำอะไรตอนห้าโมงเช้า", engine="mdeberta", corpus="pud"))
373373
self.assertIsNotNone(pos_tag_transformers(
374-
words="แมวทำอะไรตอนห้าโมงเช้า", engine="wangchanberta-ud-thai-pud-upos"))
374+
words="แมวทำอะไรตอนห้าโมงเช้า", engine="wangchanberta", corpus="pud"))
375375
with self.assertRaises(ValueError):
376-
pos_tag_transformers(words="แมวทำอะไรตอนห้าโมงเช้า", engine="non-existing-engine")
376+
pos_tag_transformers(words="แมวทำอะไรตอนห้าโมงเช้า", engine="non-existing-engine")
377+
with self.assertRaises(ValueError):
378+
pos_tag_transformers(words="แมวทำอะไรตอนห้าโมงเช้า", engine="bert",
379+
corpus="non-existing corpus")

0 commit comments

Comments
 (0)