This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Enhance WMT17 En-Zh task with full dataset. #461

Merged
merged 2 commits on Jan 9, 2018
5 changes: 3 additions & 2 deletions tensor2tensor/data_generators/generator_utils.py
@@ -316,7 +316,8 @@ def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,


def get_or_generate_vocab(data_dir, tmp_dir, vocab_filename, vocab_size,
sources):
sources,
_file_byte_budget=1e6):
"""Generate a vocabulary from the datasets in sources."""

def generate():
@@ -349,7 +350,7 @@ def generate():

# Use Tokenizer to count the word occurrences.
with tf.gfile.GFile(filepath, mode="r") as source_file:
file_byte_budget = 1e6
file_byte_budget = _file_byte_budget
counter = 0
countermax = int(source_file.size() / file_byte_budget / 2)
for line in source_file:
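For context on this change: get_or_generate_vocab reads only a budgeted sample of each training file when building the subword vocabulary, and this PR exposes that budget as a parameter so that the much larger En-Zh corpora can contribute more text (1e8 bytes per file instead of the default 1e6). Below is a minimal, self-contained sketch of this kind of budgeted subsampling; it is an illustration with assumed names, not the file's exact code.

def sample_lines_within_budget(source_file, file_byte_budget=1e6):
  """Yield roughly file_byte_budget bytes of lines, spread across the file.

  source_file: an open tf.gfile.GFile (anything with .size() and line iteration).
  """
  counter = 0
  # Skip this many lines between kept lines so the sample spans the whole file.
  countermax = int(source_file.size() / file_byte_budget / 2)
  for line in source_file:
    if counter < countermax:
      counter += 1  # still skipping
    else:
      if file_byte_budget <= 0:
        break
      line = line.strip()
      file_byte_budget -= len(line)
      counter = 0
      yield line

With the default 1e6 budget, a multi-gigabyte corpus such as the UN data would contribute only about 1 MB of text to vocabulary construction; the 1e8 budget used by the new En-Zh problem raises that to roughly 100 MB per file.
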
191 changes: 171 additions & 20 deletions tensor2tensor/data_generators/translate_enzh.py
@@ -42,28 +42,145 @@
# This is far from being the real WMT17 task - only toyset here
# you need to register to get UN data and CWT data. Also, by convention,
# this is EN to ZH - use translate_enzh_wmt8k_rev for ZH to EN task
_ENZH_TRAIN_DATASETS = [[("http://data.statmt.org/wmt17/translation-task/"
"training-parallel-nc-v12.tgz"),
("training/news-commentary-v12.zh-en.en",
"training/news-commentary-v12.zh-en.zh")]]
#
# News Commentary, around 220k lines
# This dataset is only a small fraction of full WMT17 task
_NC_TRAIN_DATASETS = [[
"http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz",
["training/news-commentary-v12.zh-en.en",
"training/news-commentary-v12.zh-en.zh"]]]

_ENZH_TEST_DATASETS = [[
# Test set from News Commentary. 2000 lines
_NC_TEST_DATASETS = [[
"http://data.statmt.org/wmt17/translation-task/dev.tgz",
("dev/newsdev2017-zhen-src.en.sgm", "dev/newsdev2017-zhen-ref.zh.sgm")
("dev/newsdev2017-enzh-src.en.sgm", "dev/newsdev2017-enzh-ref.zh.sgm")
]]

# UN parallel corpus. 15,886,041 lines
# Visit source website to download manually:
# https://conferences.unite.un.org/UNCorpus
#
# NOTE: You need to register to download the dataset from the official source,
# then place it into the tmp directory, e.g. /tmp/t2t_datagen/dataset.tgz
_UN_TRAIN_DATASETS = [[
"https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/UNv1.0.en-zh.tar.gz",
["en-zh/UNv1.0.en-zh.en",
"en-zh/UNv1.0.en-zh.zh"]]]

# CWMT corpus
# Visit source website to download manually:
# http://nlp.nju.edu.cn/cwmt-wmt/
#
# casia2015: 1,050,000 lines
# casict2015: 2,036,833 lines
# datum2015: 1,000,003 lines
# datum2017: 1,999,968 lines
# NEU2017: 2,000,000 lines
#
# NOTE: You need to register to download the dataset from the official source,
# then place it into the tmp directory, e.g. /tmp/t2t_datagen/dataset.tgz

_CWMT_TRAIN_DATASETS = [
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/casia2015/casia2015_en.txt",
"cwmt/casia2015/casia2015_ch.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/casict2015/casict2015_en.txt",
"cwmt/casict2015/casict2015_ch.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/neu2017/NEU_en.txt",
"cwmt/neu2017/NEU_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2015/datum_en.txt",
"cwmt/datum2015/datum_ch.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book1_en.txt",
"cwmt/datum2017/Book1_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book2_en.txt",
"cwmt/datum2017/Book2_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book3_en.txt",
"cwmt/datum2017/Book3_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book4_en.txt",
"cwmt/datum2017/Book4_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book5_en.txt",
"cwmt/datum2017/Book5_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book6_en.txt",
"cwmt/datum2017/Book6_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book7_en.txt",
"cwmt/datum2017/Book7_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book8_en.txt",
"cwmt/datum2017/Book8_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book9_en.txt",
"cwmt/datum2017/Book9_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book10_en.txt",
"cwmt/datum2017/Book10_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book11_en.txt",
"cwmt/datum2017/Book11_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book12_en.txt",
"cwmt/datum2017/Book12_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book13_en.txt",
"cwmt/datum2017/Book13_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book14_en.txt",
"cwmt/datum2017/Book14_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book15_en.txt",
"cwmt/datum2017/Book15_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book16_en.txt",
"cwmt/datum2017/Book16_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book17_en.txt",
"cwmt/datum2017/Book17_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book18_en.txt",
"cwmt/datum2017/Book18_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book19_en.txt",
"cwmt/datum2017/Book19_cn.txt"]],
["https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
["cwmt/datum2017/Book20_en.txt",
"cwmt/datum2017/Book20_cn.txt"]]
]
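
The twenty datum2017 Book entries above follow a regular naming pattern. Purely as an illustration (the PR lists them out explicitly; _CWMT_URL and _DATUM2017_BOOKS are hypothetical names), they could equivalently be generated as:

_CWMT_URL = "https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz"
_DATUM2017_BOOKS = [
    [_CWMT_URL,
     ["cwmt/datum2017/Book%d_en.txt" % i,
      "cwmt/datum2017/Book%d_cn.txt" % i]]
    for i in range(1, 21)  # Book1 .. Book20
]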


def get_filename(dataset):
return dataset[0][0].split('/')[-1]
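
Note that get_filename takes a whole dataset list but reads only the first entry's URL; this works because every entry within a given corpus shares a single archive. For example (illustrative only):

get_filename(_CWMT_TRAIN_DATASETS)  # -> "cwmt.tgz"
get_filename(_UN_TRAIN_DATASETS)    # -> "UNv1.0.en-zh.tar.gz"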

@registry.register_problem
class TranslateEnzhWmt8k(translate.TranslateProblem):
Contributor (review comment): let's keep an 8k version and add a 32k version. Target vocab size should be part of the problem name.

"""Problem spec for WMT En-Zh translation."""
class TranslateEnzhWmt32k(translate.TranslateProblem):
"""Problem spec for WMT En-Zh translation.
Attempts to use the full training dataset, which requires website
registration and must be downloaded manually from official sources:

@property
def targeted_vocab_size(self):
return 2**13 # 8192
CWMT:
- http://nlp.nju.edu.cn/cwmt-wmt/
- The website contains instructions for FTP server access.
- You'll need to download the CASIA, CASICT, DATUM2015, DATUM2017,
  and NEU datasets.

UN Parallel Corpus:
- https://conferences.unite.un.org/UNCorpus
- You'll need to register to download the dataset.

NOTE: Place the downloaded archives into the tmp directory, e.g. /tmp/t2t_datagen/dataset.tgz
"""

@property
def num_shards(self):
return 10 # This is a small dataset.
def targeted_vocab_size(self):
return 2**15 # 32k

@property
def source_vocab_name(self):
@@ -72,20 +189,35 @@ def source_vocab_name(self):
@property
def target_vocab_name(self):
return "vocab.enzh-zh.%d" % self.targeted_vocab_size

def get_training_dataset(self, tmp_dir):
"""UN Parallel Corpus and CWMT Corpus need to be downloaded manually.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you provide instructions somewhere for how to download these manually?

Append to training dataset if available
"""
full_dataset = _NC_TRAIN_DATASETS
for dataset in [_CWMT_TRAIN_DATASETS, _UN_TRAIN_DATASETS]:
filename = get_filename(dataset)
tmp_filepath = os.path.join(tmp_dir, filename)
if tf.gfile.Exists(tmp_filepath):
full_dataset = full_dataset + dataset
else:
tf.logging.info("[TranslateEzhWmt] dataset incomplete, you need to manually download %s" % filename)
return full_dataset
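
A hypothetical way to check, outside of datagen, which corpora would actually be used given what has been downloaded (the tmp path is an example; this snippet is not part of the PR):

from tensor2tensor.data_generators import translate_enzh

problem = translate_enzh.TranslateEnzhWmt32k()
datasets = problem.get_training_dataset("/tmp/t2t_datagen")
print(sorted({url.split("/")[-1] for url, _ in datasets}))
# Prints only ['training-parallel-nc-v12.tgz'] until cwmt.tgz and
# UNv1.0.en-zh.tar.gz have been placed in /tmp/t2t_datagen.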

def generator(self, data_dir, tmp_dir, train):
datasets = _ENZH_TRAIN_DATASETS if train else _ENZH_TEST_DATASETS
source_datasets = [[item[0], [item[1][0]]] for item in _ENZH_TRAIN_DATASETS]
target_datasets = [[item[0], [item[1][1]]] for item in _ENZH_TRAIN_DATASETS]
TRAIN_DATASET = self.get_training_dataset(tmp_dir)
datasets = TRAIN_DATASET if train else _NC_TEST_DATASETS
source_datasets = [[item[0], [item[1][0]]] for item in TRAIN_DATASET]
target_datasets = [[item[0], [item[1][1]]] for item in TRAIN_DATASET]
source_vocab = generator_utils.get_or_generate_vocab(
data_dir, tmp_dir, self.source_vocab_name, self.targeted_vocab_size,
source_datasets)
source_datasets, _file_byte_budget=1e8)
target_vocab = generator_utils.get_or_generate_vocab(
data_dir, tmp_dir, self.target_vocab_name, self.targeted_vocab_size,
target_datasets)
target_datasets, _file_byte_budget=1e8)
tag = "train" if train else "dev"
data_path = translate.compile_data(tmp_dir, datasets,
"wmt_enzh_tok_%s" % tag)
filename_base = "wmt_enzh_%sk_tok_%s" % (self.targeted_vocab_size, tag)
data_path = translate.compile_data(tmp_dir, datasets, filename_base)
return translate.bi_vocabs_token_generator(data_path + ".lang1",
data_path + ".lang2",
source_vocab, target_vocab, EOS)
@@ -107,3 +239,22 @@ def feature_encoders(self, data_dir):
"inputs": source_token,
"targets": target_token,
}
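
After data generation, the problem exposes separate encoders for English and Chinese. A hypothetical post-generation sketch (the data_dir path and sample sentences are made up, and it assumes the vocab files are ordinary SubwordTextEncoder vocabularies, as elsewhere in tensor2tensor):

from tensor2tensor.data_generators import text_encoder

source_vocab = text_encoder.SubwordTextEncoder("/tmp/t2t_data/vocab.enzh-en.32768")
target_vocab = text_encoder.SubwordTextEncoder("/tmp/t2t_data/vocab.enzh-zh.32768")
print(source_vocab.encode("Machine translation is improving quickly."))
print(target_vocab.decode(target_vocab.encode("机器翻译进步很快。")))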


@registry.register_problem
class TranslateEnzhWmt8k(TranslateEnzhWmt32k):
"""Problem spec for WMT En-Zh translation.
This is far from being the real WMT17 task - only a toy dataset is used here.
"""

@property
def targeted_vocab_size(self):
return 2**13 # 8192

@property
def num_shards(self):
return 10 # This is a small dataset.

def get_training_dataset(self, tmp_dir):
"""Uses only News Commentary Dataset for training"""
return _NC_TRAIN_DATASETS
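
For reference, the registry converts class names to snake_case, so the two problems above should be addressable as translate_enzh_wmt32k and translate_enzh_wmt8k (for example, as the --problem flag to t2t-datagen). A small sketch of looking them up programmatically, under that naming assumption:

from tensor2tensor.utils import registry

toy = registry.problem("translate_enzh_wmt8k")    # News Commentary only
full = registry.problem("translate_enzh_wmt32k")  # adds CWMT/UN corpora when present
print(toy.targeted_vocab_size, full.targeted_vocab_size)  # 8192 32768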