1
1
# -*- coding: utf-8 -*-
2
+
3
+ """
4
+ Code by https://github.com/cstorm125/thai2fit/
5
+ """
6
+ import re
7
+ import numpy as np
8
+ import dill as pickle
9
+
10
+ #fastai
11
+ from fastai import *
12
+ from fastai .text .transform import *
13
+
14
+ #pytorch
15
import torch
# Prefer the GPU when one is visible; `document_vector` moves its input
# tensor to this device before running the encoder.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+
18
+ #pythainlp
19
+ from pythainlp .corpus import download , get_file
20
+ from pythainlp .tokenize import word_tokenize
21
+ from pythainlp .util import normalize as normalize_char_order
22
+
23
# pythainlp-corpus file names resolved by `get_path` below:
# the pretrained ULMFiT language model and its index-to-string vocab.
MODEL_NAME = "thai2fit_lm"
ITOS_NAME = "thai2fit_itos"
25
+
26
+ #custom fastai tokenizer
27
class ThaiTokenizer(BaseTokenizer):
    """fastai `BaseTokenizer` adapter over pythainlp's frozen newmm engine.

    `BaseTokenizer`, `List` and `Collection` are expected to come from the
    fastai star-imports at module level.
    """

    def __init__(self, lang: str = 'th'):
        self.lang = lang

    def tokenizer(self, t: str) -> List[str]:
        """Tokenize text `t` with the frozen newmm ('ulmfit') engine.

        :param str t: text to tokenize
        :return: list of tokens
        """
        return word_tokenize(t, engine='ulmfit')

    def add_special_cases(self, toks: Collection[str]):
        """No-op: the frozen engine accepts no extra special cases."""
        pass
42
+
43
+ #special rules for thai
44
def replace_rep_after(t: str) -> str:
    """Replace character-level repetitions in `t`, marking the run length
    *after* the repeated character.

    Any run of 4 or more identical non-whitespace characters ``cccc...``
    becomes `` c {TK_REP} {run length} `` (``TK_REP`` comes from the fastai
    star-import).

    :param str t: text to process
    :return: text with repetition runs collapsed and annotated
    """
    # Fix: the callback receives a regex match object, not a Collection[str].
    def _replace_rep(m: re.Match) -> str:
        # group 1 is the repeated char, group 2 the extra repetitions
        c, cc = m.groups()
        return f' {c} {TK_REP} {len(cc)+1} '
    # one non-space char followed by 3+ copies of itself => a run of 4+
    re_rep = re.compile(r'(\S)(\1{3,})')
    return re_rep.sub(_replace_rep, t)
51
+
52
def rm_useless_newlines(t: str) -> str:
    """Collapse any run of 2+ newlines and/or spaces in `t` into one space.

    Note: the character class matches mixed runs of ``'\\n'`` and ``' '``,
    so double spaces are collapsed too — not only repeated newlines.

    :param str t: text to clean
    :return: cleaned text
    """
    # raw string per regex convention; behavior unchanged
    return re.sub(r'[\n ]{2,}', ' ', t)
55
+
56
def rm_brackets(t: str) -> str:
    """Remove all empty bracket pairs ``()``, ``{}`` and ``[]`` from `t`.

    The three passes run in order () -> {} -> [], so pairs emptied by an
    earlier pass are removed too (e.g. ``"[()]"`` -> ``""``).

    :param str t: text to clean
    :return: text without empty bracket pairs
    """
    # Fix: '\(' etc. in non-raw strings are invalid escape sequences
    # (SyntaxWarning on modern CPython); raw strings keep behavior identical.
    new_line = re.sub(r'\(\)', '', t)
    new_line = re.sub(r'\{\}', '', new_line)
    new_line = re.sub(r'\[\]', '', new_line)
    return new_line
62
+
63
# in case we want to add more specific rules for thai
# Pre-tokenization processing rules. `replace_rep_after`, `rm_useless_newlines`
# and `rm_brackets` are defined in this module; `normalize_char_order` is
# pythainlp's `normalize`; the rest (fix_html, deal_caps, spec_add_spaces,
# rm_useless_spaces) presumably come from the fastai star-imports — confirm
# against fastai.text.transform.
thai_rules = [fix_html, deal_caps, replace_rep_after, normalize_char_order,
              spec_add_spaces, rm_useless_spaces, rm_useless_newlines, rm_brackets]
66
+
67
+ # Download pretrained models
68
def get_path(fname):
    """Resolve `fname` in the pythainlp corpus, downloading it on first use.

    :param str fname: corpus file name
    :return: local path to the file, as reported by `pythainlp.corpus.get_file`
    """
    path = get_file(fname)
    if path:
        return path
    # not cached locally yet: fetch it, then look the path up again
    download(fname)
    return get_file(fname)
79
+
80
# pretrained paths
# NOTE(review): runs at import time and may trigger corpus downloads via
# get_path. The [:-4] slice presumably strips a 4-char file extension
# (e.g. ".pth") before handing the paths to fastai — TODO confirm.
THWIKI = [get_path(MODEL_NAME)[:-4], get_path(ITOS_NAME)[:-4]]
# module-level tokenizer shared by document_vector
tt = ThaiTokenizer()
83
+
84
def document_vector(ss, learn, data):
    """Get a document vector for `ss` from a pretrained ULMFiT model.

    :param str ss: sentence to embed
    :param learn: fastai language-model learner; the encoder is `learn.model[0]`
    :param data: fastai data bunch providing `data.vocab` for numericalization
    :return: `numpy.array` document vector (sized 400 for the thai2fit encoder)
    """
    s = tt.tokenizer(ss)
    # shape (seq_len, 1): a batch containing one document, on the module device
    t = torch.tensor(data.vocab.numericalize(s), requires_grad=False)[:, None].to(device)
    m = learn.model[0]
    m.reset()  # clear recurrent hidden state between documents
    pred, _ = m(t)
    # last layer, last time step. Fix: move to CPU before .numpy() —
    # calling .numpy() directly on a CUDA tensor raises a TypeError.
    res = pred[-1][-1, :, :].squeeze().detach().cpu().numpy()
    return res