This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 92267e8

twairball authored and rsepassi committed
Enhance WMT17 En-Zh task with full dataset. (#461)
* Enhance WMT17 En-Zh task with full dataset. Fix #446. Added `_file_byte_budget` as an argument to `get_or_generate_vocab`.
* Made requested fixes:
  - Added TranslateEnzhWmt8k problem.
  - Renamed to TranslateEnzhWmt32k, to reflect the target vocab size in the problem name.
  - Added instructions for manually downloading the full dataset.
1 parent 0fcdf8e commit 92267e8
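The change to `get_or_generate_vocab` threads a configurable byte budget (`_file_byte_budget`, defaulting to 1e6) into vocabulary generation. Below is a minimal sketch of a call site that uses the enlarged budget; the paths and the `sources` list are placeholders, not part of this commit:

# Hedged sketch, not from the commit: paths and the sources list are placeholders;
# only the _file_byte_budget keyword argument is introduced by this change.
from tensor2tensor.data_generators import generator_utils

sources = [[
    "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz",
    ["training/news-commentary-v12.zh-en.en",
     "training/news-commentary-v12.zh-en.zh"]]]

vocab = generator_utils.get_or_generate_vocab(
    "/tmp/t2t_data",        # data_dir: where the vocab file is cached (example path)
    "/tmp/t2t_datagen",     # tmp_dir: where downloaded corpora live (example path)
    "vocab.enzh-en.32768",  # vocab_filename
    2**15,                  # vocab_size
    sources,
    _file_byte_budget=1e8)  # sample ~100 MB per source file instead of the 1 MB default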

File tree

2 files changed: +173 -21 lines changed


tensor2tensor/data_generators/generator_utils.py

+3 -2
@@ -316,7 +316,8 @@ def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
 
 
 def get_or_generate_vocab(data_dir, tmp_dir, vocab_filename, vocab_size,
-                          sources):
+                          sources,
+                          _file_byte_budget=1e6):
   """Generate a vocabulary from the datasets in sources."""
 
   def generate():
@@ -349,7 +350,7 @@ def generate():
 
         # Use Tokenizer to count the word occurrences.
         with tf.gfile.GFile(filepath, mode="r") as source_file:
-          file_byte_budget = 1e6
+          file_byte_budget = _file_byte_budget
           counter = 0
           countermax = int(source_file.size() / file_byte_budget / 2)
           for line in source_file:
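Within the vocab generator, the budget drives line subsampling: `countermax` is derived from the file size and the budget, and roughly one line out of every `countermax` is fed to the tokenizer until the budget is spent. The following is a standalone, simplified sketch of that sampling pattern, not the library function itself; plain `open` and `os.path.getsize` stand in for the `tf.gfile` calls shown above:

# Simplified sketch of budget-based line sampling; not the library implementation.
import os

def sample_lines(filepath, file_byte_budget=1e6):
  """Yield a subsample of lines totalling roughly file_byte_budget bytes."""
  countermax = int(os.path.getsize(filepath) / file_byte_budget / 2)
  counter = 0
  with open(filepath, encoding="utf-8") as source_file:
    for line in source_file:
      if counter < countermax:
        counter += 1            # skip this line
      else:
        if file_byte_budget <= 0:
          break                 # budget spent, stop reading
        line = line.strip()
        file_byte_budget -= len(line)
        counter = 0
        yield line

With the 1e6 default, a very large corpus such as the UN data would contribute only a tiny slice of its text to the vocabulary, which is why the En-Zh problem below passes `_file_byte_budget=1e8`.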

tensor2tensor/data_generators/translate_enzh.py

+170 -19
@@ -42,28 +42,145 @@
 # This is far from being the real WMT17 task - only toyset here
 # you need to register to get UN data and CWT data. Also, by convention,
 # this is EN to ZH - use translate_enzh_wmt8k_rev for ZH to EN task
-_ENZH_TRAIN_DATASETS = [[("http://data.statmt.org/wmt17/translation-task/"
-                          "training-parallel-nc-v12.tgz"),
-                         ("training/news-commentary-v12.zh-en.en",
-                          "training/news-commentary-v12.zh-en.zh")]]
+#
+# News Commentary, around 220k lines
+# This dataset is only a small fraction of full WMT17 task
+_NC_TRAIN_DATASETS = [[
+    "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz",
+    ["training/news-commentary-v12.zh-en.en",
+     "training/news-commentary-v12.zh-en.zh"]]]
 
-_ENZH_TEST_DATASETS = [[
+# Test set from News Commentary. 2000 lines
+_NC_TEST_DATASETS = [[
     "http://data.statmt.org/wmt17/translation-task/dev.tgz",
     ("dev/newsdev2017-enzh-src.en.sgm", "dev/newsdev2017-enzh-ref.zh.sgm")
 ]]
 
+# UN parallel corpus. 15,886,041 lines
+# Visit source website to download manually:
+# https://conferences.unite.un.org/UNCorpus
+#
+# NOTE: You need to register to download dataset from official source
+# place into tmp directory e.g. /tmp/t2t_datagen/dataset.tgz
+_UN_TRAIN_DATASETS = [[
+    "https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/UNv1.0.en-zh.tar.gz",
+    ["en-zh/UNv1.0.en-zh.en",
+     "en-zh/UNv1.0.en-zh.zh"]]]
+
+# CWMT corpus
+# Visit source website to download manually:
+# http://nlp.nju.edu.cn/cwmt-wmt/
+#
+# casia2015: 1,050,000 lines
+# casict2015: 2,036,833 lines
+# datum2015: 1,000,003 lines
+# datum2017: 1,999,968 lines
+# NEU2017: 2,000,000 lines
+#
+# NOTE: You need to register to download dataset from official source
+# place into tmp directory e.g. /tmp/t2t_datagen/dataset.tgz
+
+_CWMT_TRAIN_DATASETS = [
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/casia2015/casia2015_en.txt",
+      "cwmt/casia2015/casia2015_ch.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/casict2015/casict2015_en.txt",
+      "cwmt/casict2015/casict2015_ch.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/neu2017/NEU_en.txt",
+      "cwmt/neu2017/NEU_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2015/datum_en.txt",
+      "cwmt/datum2015/datum_ch.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book1_en.txt",
+      "cwmt/datum2017/Book1_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book2_en.txt",
+      "cwmt/datum2017/Book2_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book3_en.txt",
+      "cwmt/datum2017/Book3_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book4_en.txt",
+      "cwmt/datum2017/Book4_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book5_en.txt",
+      "cwmt/datum2017/Book5_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book6_en.txt",
+      "cwmt/datum2017/Book6_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book7_en.txt",
+      "cwmt/datum2017/Book7_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book8_en.txt",
+      "cwmt/datum2017/Book8_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book9_en.txt",
+      "cwmt/datum2017/Book9_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book10_en.txt",
+      "cwmt/datum2017/Book10_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book11_en.txt",
+      "cwmt/datum2017/Book11_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book12_en.txt",
+      "cwmt/datum2017/Book12_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book13_en.txt",
+      "cwmt/datum2017/Book13_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book14_en.txt",
+      "cwmt/datum2017/Book14_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book15_en.txt",
+      "cwmt/datum2017/Book15_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book16_en.txt",
+      "cwmt/datum2017/Book16_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book17_en.txt",
+      "cwmt/datum2017/Book17_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book18_en.txt",
+      "cwmt/datum2017/Book18_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book19_en.txt",
+      "cwmt/datum2017/Book19_cn.txt"]],
+    ["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
+     ["cwmt/datum2017/Book20_en.txt",
+      "cwmt/datum2017/Book20_cn.txt"]]
+]
+
+
+def get_filename(dataset):
+  return dataset[0][0].split('/')[-1]
 
 @registry.register_problem
-class TranslateEnzhWmt8k(translate.TranslateProblem):
-  """Problem spec for WMT En-Zh translation."""
+class TranslateEnzhWmt32k(translate.TranslateProblem):
+  """Problem spec for WMT En-Zh translation.
+  Attempts to use full training dataset, which needs website
+  registration and manual downloads from official sources:
 
-  @property
-  def targeted_vocab_size(self):
-    return 2**13  # 8192
+  CWMT:
+  - http://nlp.nju.edu.cn/cwmt-wmt/
+  - Website contains instructions for FTP server access.
+  - You'll need to download CASIA, CASICT, DATUM2015, DATUM2017,
+    NEU datasets
+
+  UN Parallel Corpus:
+  - https://conferences.unite.un.org/UNCorpus
+  - You'll need to register to download the dataset.
+
+  NOTE: place into tmp directory e.g. /tmp/t2t_datagen/dataset.tgz
+  """
 
   @property
-  def num_shards(self):
-    return 10  # This is a small dataset.
+  def targeted_vocab_size(self):
+    return 2**15  # 32k
 
   @property
   def source_vocab_name(self):
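Each entry in the training-set constants above has the shape `[url, [source_path, target_path]]`, and `get_filename` simply takes the basename of the first entry's URL; that basename is what the problem later looks for in `tmp_dir`. A small self-contained illustration, with values copied from the constants above:

# Illustration only: mirrors the dataset-entry shape and get_filename defined above.
_UN_TRAIN_DATASETS = [[
    "https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/UNv1.0.en-zh.tar.gz",
    ["en-zh/UNv1.0.en-zh.en",
     "en-zh/UNv1.0.en-zh.zh"]]]

def get_filename(dataset):
  return dataset[0][0].split('/')[-1]

print(get_filename(_UN_TRAIN_DATASETS))  # prints: UNv1.0.en-zh.tar.gz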
@@ -72,20 +189,35 @@ def source_vocab_name(self):
   @property
   def target_vocab_name(self):
     return "vocab.enzh-zh.%d" % self.targeted_vocab_size
+
+  def get_training_dataset(self, tmp_dir):
+    """UN Parallel Corpus and CWMT Corpus need to be downloaded manually.
+    Append to training dataset if available
+    """
+    full_dataset = _NC_TRAIN_DATASETS
+    for dataset in [_CWMT_TRAIN_DATASETS, _UN_TRAIN_DATASETS]:
+      filename = get_filename(dataset)
+      tmp_filepath = os.path.join(tmp_dir, filename)
+      if tf.gfile.Exists(tmp_filepath):
+        full_dataset = full_dataset + dataset
+      else:
+        tf.logging.info("[TranslateEzhWmt] dataset incomplete, you need to manually download %s" % filename)
+    return full_dataset
 
   def generator(self, data_dir, tmp_dir, train):
-    datasets = _ENZH_TRAIN_DATASETS if train else _ENZH_TEST_DATASETS
-    source_datasets = [[item[0], [item[1][0]]] for item in _ENZH_TRAIN_DATASETS]
-    target_datasets = [[item[0], [item[1][1]]] for item in _ENZH_TRAIN_DATASETS]
+    TRAIN_DATASET = self.get_training_dataset(tmp_dir)
+    datasets = TRAIN_DATASET if train else _NC_TEST_DATASETS
+    source_datasets = [[item[0], [item[1][0]]] for item in TRAIN_DATASET]
+    target_datasets = [[item[0], [item[1][1]]] for item in TRAIN_DATASET]
     source_vocab = generator_utils.get_or_generate_vocab(
         data_dir, tmp_dir, self.source_vocab_name, self.targeted_vocab_size,
-        source_datasets)
+        source_datasets, _file_byte_budget=1e8)
     target_vocab = generator_utils.get_or_generate_vocab(
         data_dir, tmp_dir, self.target_vocab_name, self.targeted_vocab_size,
-        target_datasets)
+        target_datasets, _file_byte_budget=1e8)
     tag = "train" if train else "dev"
-    data_path = translate.compile_data(tmp_dir, datasets,
-                                       "wmt_enzh_tok_%s" % tag)
+    filename_base = "wmt_enzh_%sk_tok_%s" % (self.targeted_vocab_size, tag)
+    data_path = translate.compile_data(tmp_dir, datasets, filename_base)
     return translate.bi_vocabs_token_generator(data_path + ".lang1",
                                                data_path + ".lang2",
                                                source_vocab, target_vocab, EOS)
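`get_training_dataset` always starts from News Commentary and appends the CWMT and UN corpora only when their archives are already present in `tmp_dir`; otherwise it logs a reminder and training proceeds on the smaller set. A hedged pre-flight check one could run before datagen; the `tmp_dir` value is only an example, while the archive basenames come from the dataset URLs above:

# Hypothetical helper, not part of the commit: verify the manually downloaded
# archives are where get_training_dataset will look for them.
import os

tmp_dir = "/tmp/t2t_datagen"                      # example location
for name in ["cwmt.tgz", "UNv1.0.en-zh.tar.gz"]:  # basenames of the dataset URLs
  path = os.path.join(tmp_dir, name)
  if os.path.exists(path):
    print("%s: found, will be appended to the training set" % path)
  else:
    print("%s: missing, training falls back to News Commentary only" % path)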
@@ -107,3 +239,22 @@ def feature_encoders(self, data_dir):
         "inputs": source_token,
         "targets": target_token,
     }
+
+
+@registry.register_problem
+class TranslateEnzhWmt8k(TranslateEnzhWmt32k):
+  """Problem spec for WMT En-Zh translation.
+  This is far from being the real WMT17 task - only toyset here
+  """
+
+  @property
+  def targeted_vocab_size(self):
+    return 2**13  # 8192
+
+  @property
+  def num_shards(self):
+    return 10  # This is a small dataset.
+
+  def get_training_dataset(self, tmp_dir):
+    """Uses only News Commentary Dataset for training"""
+    return _NC_TRAIN_DATASETS
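After this commit two problems are registered: TranslateEnzhWmt32k (full dataset when the archives are available, 32k target vocab) and TranslateEnzhWmt8k (News Commentary only, 8k vocab). Assuming the registry's usual CamelCase-to-snake_case name derivation, they can be looked up roughly as follows; this is a sketch, not code from the commit:

# Sketch: fetch the registered problems by name. Importing the module runs the
# @registry.register_problem decorators; the snake_case names are assumed to
# follow the registry's standard derivation from the class names.
from tensor2tensor.utils import registry
from tensor2tensor.data_generators import translate_enzh  # noqa: F401 (registers problems)

full_problem = registry.problem("translate_enzh_wmt32k")  # full WMT17 data, 32k vocab
toy_problem = registry.problem("translate_enzh_wmt8k")    # News Commentary only, 8k vocab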
