@@ -17,7 +17,7 @@
 from pathlib import Path
 from shutil import copyfile

-from transformers import PrismTokenizer
+from transformers import PrismTokenizer, is_torch_available
 from transformers.testing_utils import (
     get_tests_dir,
     nested_simplify,
@@ -36,7 +36,9 @@

 if is_sentencepiece_available():
     SAMPLE_SP = get_tests_dir("fixtures/test_sentencepiece.model")
-
+
+if is_torch_available():
+    from transformers.models.prism.modeling_prism import shift_tokens_right

 EN_CODE = 37
 FR_CODE = 85
@@ -81,7 +83,7 @@ def test_get_vocab(self):

         self.assertEqual(vocab_keys[0], "</s>")
         self.assertEqual(vocab_keys[1], "<unk>")
-        self.assertEqual(vocab_keys[-1], "<s>")
+        self.assertEqual(vocab_keys[10], "<s>")

     def test_full_tokenizer(self):
         tokenizer = self.get_tokenizer()
@@ -107,7 +109,9 @@ def test_full_tokenizer(self):
 class PrismTokenizerIntegrationTest(unittest.TestCase):
     checkpoint_name = CHECKPOINT_NAME
     src_text = ["Hi world.", "This is a Test.", "Some of my Best Friends are Linguists."]
-
+    tgt_text = ["Hé, monde!",
+                "C'est un test.",
+                "Certains de mes meilleurs amis sont linguistes."]
     expected_src_tokens = [EN_CODE, 5050, 21, 1951, 13934, 33789, 7, 269, 11348, 983, 9393, 6, 2]

     @classmethod
@@ -177,7 +181,27 @@ def test_special_tokens_unaffacted_by_save_load(self):
         self.tokenizer.save_pretrained(tmpdirname)
         new_tok = PrismTokenizer.from_pretrained(tmpdirname)
         self.assertDictEqual(new_tok.lang_token_to_id, original_special_tokens)
+
+    @require_torch
+    def test_batch_fairseq_parity(self):
+        self.tokenizer.src_lang = "en"
+        self.tokenizer.tgt_lang = "fr"
+
+        batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
+
+        batch["decoder_input_ids"] = shift_tokens_right(
+            batch["labels"], self.tokenizer.pad_token_id, self.tokenizer.eos_token_id
+        )

+        for k in batch:
+            batch[k] = batch[k].tolist()
+
+        assert batch.input_ids[1][0] == EN_CODE
+        assert batch.input_ids[1][-1] == 1
+        assert batch.labels[1][0] == FR_CODE
+        assert batch.labels[1][-1] == 1
+        assert batch.decoder_input_ids[1][:2] == [2, FR_CODE]
+
     def test_decoding(self):
         text = "Hello, world!"
         encoded = self.tokenizer.encode(text)
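For context on the helper the new test imports: it builds `decoder_input_ids` by right-shifting the labels. Below is a minimal sketch of what `shift_tokens_right` is expected to do, assuming `modeling_prism` follows the same convention as other transformers seq2seq models (e.g. M2M100); the actual Prism implementation may differ in details.

import torch


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
    """Shift label ids one position to the right to produce decoder inputs."""
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    # The first decoder position holds the start token; the test passes eos_token_id here.
    shifted[:, 0] = decoder_start_token_id
    # Replace any -100 label padding with a real pad token id.
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted

Given a labels row that starts with FR_CODE, the shifted row then starts [eos, FR_CODE, ...], which is exactly what the final assertion pins down: the [2, FR_CODE] check implies eos_token_id is 2, and the == 1 checks on the padded row imply pad_token_id is 1 for this checkpoint.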