
Commit 64b14f6

Merge pull request huggingface#5 from DaryaTereshchenko/special_tokens
add a fix to special tokens handling and add the test_batch_fairseq_p…
2 parents 8bedcb3 + 30e4169 commit 64b14f6

2 files changed (+32 −12)

src/transformers/models/prism/tokenization_prism.py (+4 −8)
@@ -168,10 +168,11 @@ def __init__(
         self.language_codes = language_codes
         fairseq_language_code = FAIRSEQ_LANGUAGE_CODES[language_codes]
         self.lang_code_to_token = {lang_code: f"<{lang_code}>" for lang_code in fairseq_language_code}
-
-        language_tokens = [self.get_lang_token(lang_code) for lang_code in fairseq_language_code]
+
         additional_special_tokens = kwargs.pop("additional_special_tokens", [])
-        self.additional_special_tokens = language_tokens + additional_special_tokens
+        language_tokens = [self.get_lang_token(lang_code) for lang_code in fairseq_language_code]
+
+        additional_special_tokens = language_tokens + additional_special_tokens
 
         self.vocab_file = vocab_file
         self.encoder = load_json(vocab_file)
@@ -213,8 +214,6 @@ def __init__(
             num_madeup_words=num_madeup_words,
             **kwargs,
         )
-
-        self.special_tokens_map['additional_special_tokens'] = self.additional_special_tokens
         self.set_src_lang_special_tokens(self._src_lang)
 
     @property
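
Taken together, these two hunks move the language tokens into the `additional_special_tokens` keyword that the base tokenizer consumes, instead of overwriting `self.additional_special_tokens` and `special_tokens_map` after `super().__init__()` has run. A minimal, self-contained sketch of that pattern, with a hypothetical `Base` class standing in for `PreTrainedTokenizer`:

# Hypothetical stand-in for PreTrainedTokenizer: the base class is now the
# single owner of the special-tokens list.
class Base:
    def __init__(self, additional_special_tokens=None, **kwargs):
        self.additional_special_tokens = list(additional_special_tokens or [])

class Toy(Base):
    def __init__(self, lang_codes, **kwargs):
        # Merge user-supplied extras with the language tokens, then hand the
        # combined list to the base class instead of patching it on afterwards.
        additional_special_tokens = kwargs.pop("additional_special_tokens", [])
        language_tokens = [f"<{code}>" for code in lang_codes]
        super().__init__(
            additional_special_tokens=language_tokens + additional_special_tokens,
            **kwargs,
        )

tok = Toy(["en", "fr"], additional_special_tokens=["<mask>"])
print(tok.additional_special_tokens)  # ['<en>', '<fr>', '<mask>']
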
@@ -254,9 +253,6 @@ def convert_tokens_to_string(self, tokens):
         current_sub_tokens = []
         out_string = ""
         for token in tokens:
-            # Skip language tokens during decoding
-            if token in self.lang_code_to_token.values():
-                continue
             # Ensure special tokens are not decoded with the sentencepiece model
             if token in self.all_special_tokens:
                 out_string += self.sp_model.decode(current_sub_tokens) + token
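
With the hard-coded skip removed, language tokens are decoded like any other special token. Because they are now registered through `additional_special_tokens`, callers who do not want them can rely on the standard `skip_special_tokens=True` path of `decode`, which filters special ids before `convert_tokens_to_string` is called. A usage sketch; the checkpoint path is illustrative, not a published model id:

from transformers import PrismTokenizer

# Illustrative local path; Prism is not assumed to be on the Hub.
tok = PrismTokenizer.from_pretrained("/path/to/prism-checkpoint", src_lang="en")

ids = tok("Hi world.").input_ids  # roughly: [<en>, ..., </s>]
print(tok.decode(ids))                            # keeps "<en>" and "</s>"
print(tok.decode(ids, skip_special_tokens=True))  # drops them upstream
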

tests/models/prism/test_tokenization_prism.py (+28 −4)
@@ -17,7 +17,7 @@
 from pathlib import Path
 from shutil import copyfile
 
-from transformers import PrismTokenizer
+from transformers import PrismTokenizer, is_torch_available
 from transformers.testing_utils import (
     get_tests_dir,
     nested_simplify,
@@ -36,7 +36,9 @@
 
 if is_sentencepiece_available():
     SAMPLE_SP = get_tests_dir("fixtures/test_sentencepiece.model")
-
+
+if is_torch_available():
+    from transformers.models.prism.modeling_prism import shift_tokens_right
 
 EN_CODE = 37
 FR_CODE = 85
@@ -81,7 +83,7 @@ def test_get_vocab(self):
 
         self.assertEqual(vocab_keys[0], "</s>")
         self.assertEqual(vocab_keys[1], "<unk>")
-        self.assertEqual(vocab_keys[-1], "<s>")
+        self.assertEqual(vocab_keys[10], "<s>")
 
     def test_full_tokenizer(self):
         tokenizer = self.get_tokenizer()
@@ -107,7 +109,9 @@ def test_full_tokenizer(self):
 class PrismTokenizerIntegrationTest(unittest.TestCase):
     checkpoint_name = CHECKPOINT_NAME
     src_text = ["Hi world.", "This is a Test.", "Some of my Best Friends are Linguists."]
-
+    tgt_text = ['Hé, monde!',
+                "C'est un test.",
+                'Certains de mes meilleurs amis sont linguistes.']
     expected_src_tokens = [EN_CODE, 5050, 21, 1951, 13934, 33789, 7, 269, 11348, 983, 9393, 6, 2]
 
     @classmethod
@@ -177,7 +181,27 @@ def test_special_tokens_unaffacted_by_save_load(self):
         self.tokenizer.save_pretrained(tmpdirname)
         new_tok = PrismTokenizer.from_pretrained(tmpdirname)
         self.assertDictEqual(new_tok.lang_token_to_id, original_special_tokens)
+
+    @require_torch
+    def test_batch_fairseq_parity(self):
+        self.tokenizer.src_lang = "en"
+        self.tokenizer.tgt_lang = "fr"
+
+        batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
+
+        batch["decoder_input_ids"] = shift_tokens_right(
+            batch["labels"], self.tokenizer.pad_token_id, self.tokenizer.eos_token_id
+        )
 
+        for k in batch:
+            batch[k] = batch[k].tolist()
+
+        assert batch.input_ids[1][0] == EN_CODE
+        assert batch.input_ids[1][-1] == 1
+        assert batch.labels[1][0] == FR_CODE
+        assert batch.labels[1][-1] == 1
+        assert batch.decoder_input_ids[1][:2] == [2, FR_CODE]
+
     def test_decoding(self):
         text = "Hello, world!"
         encoded = self.tokenizer.encode(text)
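
For context, the `shift_tokens_right` imported above is expected to behave like the M2M100 variant of that helper (an assumption; the Prism modeling file may differ): it prepends the decoder start token, here `eos_token_id` (id 2), and shifts the labels right by one. That is why `decoder_input_ids[1][:2] == [2, FR_CODE]` when `labels[1][0] == FR_CODE`. A sketch under that assumption:

import torch

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    # Shift the ids one position to the right and prepend the decoder start token.
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # Replace any -100 label padding with the real pad token id.
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted

labels = torch.tensor([[85, 7, 2]])      # [FR_CODE, ..., </s>], ids assumed
print(shift_tokens_right(labels, 1, 2))  # tensor([[ 2, 85,  7]])
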
