This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 20c7e41

T2T Team authored and Ryan Sepassi committed
Discrete autoencoder with VQ-VAE as in https://arxiv.org/abs/1711.00937.
PiperOrigin-RevId: 177371794
1 parent bb1173a commit 20c7e41

11 files changed (+27, -289 lines)

README.md

+2 -5

@@ -126,12 +126,9 @@ t2t-decoder \
  --output_dir=$TRAIN_DIR \
  --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \
  --decode_from_file=$DECODE_FILE
-  --decode_to_file=translation.en
-```
-
-# Eval BLEU

-t2t-bleu --translation=translation.en --reference=ref-translation.de
+cat $DECODE_FILE.$MODEL.$HPARAMS.beam$BEAM_SIZE.alpha$ALPHA.decodes
+```

---

setup.py

-1

@@ -24,7 +24,6 @@
        'tensor2tensor/bin/t2t-datagen',
        'tensor2tensor/bin/t2t-decoder',
        'tensor2tensor/bin/t2t-make-tf-configs',
-        'tensor2tensor/bin/t2t-bleu',
    ],
    install_requires=[
        'bz2file',

tensor2tensor/bin/t2t-bleu

-200
This file was deleted.

tensor2tensor/bin/t2t-datagen

File mode changed from 100755 to 100644.

tensor2tensor/bin/t2t-decoder

File mode changed from 100755 to 100644.
+2 -5

@@ -46,10 +46,7 @@ import tensorflow as tf
flags = tf.flags
FLAGS = flags.FLAGS

-flags.DEFINE_string("output_dir", "",
-                    "Training directory where the latest checkpoint is used.")
-flags.DEFINE_string("checkpoint_path", None,
-                    "Path to the model checkpoint. Overrides output_dir.")
+flags.DEFINE_string("output_dir", "", "Training directory to load from.")
flags.DEFINE_string("decode_from_file", None,
                    "Path to the source file for decoding")
flags.DEFINE_string("decode_to_file", None,

@@ -93,7 +90,7 @@ def main(_):
    decoding.decode_interactively(estimator, decode_hp)
  elif FLAGS.decode_from_file:
    decoding.decode_from_file(estimator, FLAGS.decode_from_file, decode_hp,
-                              FLAGS.decode_to_file, checkpoint_path=FLAGS.checkpoint_path)
+                              FLAGS.decode_to_file)
  else:
    decoding.decode_from_dataset(
        estimator,

tensor2tensor/bin/t2t-make-tf-configs

File mode changed from 100755 to 100644.

tensor2tensor/bin/t2t-trainer

File mode changed from 100755 to 100644.

tensor2tensor/models/transformer_vae.py

+17 -4

@@ -147,17 +147,22 @@ def nearest(x, means, hparams):
                  transpose_b=True)
  _, nearest_idx = tf.nn.top_k(- dist, k=1)
  nearest_hot = tf.one_hot(tf.squeeze(nearest_idx, axis=1), hparams.v_size)
-  nearest_hot = tf.reshape(nearest_hot, [tf.shape(x)[0], tf.shape(x)[1],
-                                         tf.shape(x)[2], hparams.v_size])
+  shape = common_layers.shape_list(x)
+  shape[-1] = hparams.v_size
+  nearest_hot = tf.reshape(nearest_hot, shape=shape)
  return tf.stop_gradient(nearest_hot)


def kmeans(x, means, hparams, name):
  with tf.variable_scope(name):
    x_means_hot = nearest(x, means, hparams)
    x_means = tf.gather(means, tf.argmax(x_means_hot, axis=-1))
-    kl = tf.reduce_sum(tf.square(x - x_means), axis=-1)
-    return x_means_hot, tf.reduce_mean(kl)  # * 10.0
+    x_flat = tf.reshape(x, [-1, hparams.hidden_size])
+    kl = tf.reduce_mean(tf.reduce_sum(tf.square(x_flat - x_means), axis=-1))
+    reg_loss1 = tf.nn.l2_loss((tf.stop_gradient(x) - x_means))
+    reg_loss2 = hparams.beta * tf.nn.l2_loss((x - tf.stop_gradient(x_means)))
+    l = kl + reg_loss1 + reg_loss2
+    return x_means_hot, x_means, l


def bit_to_int(x_bit, nbits):

@@ -233,6 +238,12 @@ def embed(x):
  _, hot, l = dae(x, hparams, name)
  c = tf.argmax(hot, axis=-1)
  h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
+  if hparams.bottleneck_kind == "vq-vae":
+    means = tf.get_variable(name="means", shape=[hparams.v_size,
+                                                 hparams.hidden_size])
+    x_means_hot, x_means, l = kmeans(x, means, hparams, name="vq-vae-kmeans")
+    h1 = x_means
+    c = tf.argmax(x_means_hot, axis=-1)
  h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
  res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
  return res, c, l, embed

@@ -500,6 +511,8 @@ def transformer_ae_small():
  hparams.add_hparam("decode_autoregressive", True)
  hparams.add_hparam("do_vae", True)
  hparams.add_hparam("bit_vae", True)
+  hparams.add_hparam("beta", 0.25)
+  hparams.kl_warmup_steps = 150000
  return hparams

tensor2tensor/utils/bleu_hook.py

+1 -67

@@ -20,17 +20,13 @@

import collections
import math
-import re
-import sys
-import unicodedata

# Dependency imports

import numpy as np
# pylint: disable=redefined-builtin
from six.moves import xrange
from six.moves import zip
-import six
# pylint: enable=redefined-builtin

import tensorflow as tf

@@ -96,17 +92,10 @@ def compute_bleu(reference_corpus,
      matches_by_order[len(ngram) - 1] += overlap[ngram]
    for ngram in translation_ngram_counts:
      possible_matches_by_order[len(ngram)-1] += translation_ngram_counts[ngram]
-  assert reference_length, "no reference provided"
-  assert translation_length, "no translation provided"
  precisions = [0] * max_order
-  smooth = 1.0
  for i in xrange(0, max_order):
    if possible_matches_by_order[i] > 0:
-      if matches_by_order[i] > 0:
-        precisions[i] = matches_by_order[i] / possible_matches_by_order[i]
-      else:
-        smooth *= 2
-        precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
+      precisions[i] = matches_by_order[i] / possible_matches_by_order[i]
    else:
      precisions[i] = 0.0

@@ -142,58 +131,3 @@ def bleu_score(predictions, labels, **unused_kwargs):

  bleu = tf.py_func(compute_bleu, (labels, outputs), tf.float32)
  return bleu, tf.constant(1.0)
-
-
-class UnicodeRegex:
-  """Ad-hoc hack to recognize all punctuation and symbols.
-
-  without dependening on https://pypi.python.org/pypi/regex/."""
-  def _property_chars(prefix):
-    return ''.join(six.unichr(x) for x in range(sys.maxunicode)
-                   if unicodedata.category(six.unichr(x)).startswith(prefix))
-  punctuation = _property_chars('P')
-  nondigit_punct_re = re.compile(r'([^\d])([' + punctuation + r'])')
-  punct_nondigit_re = re.compile(r'([' + punctuation + r'])([^\d])')
-  symbol_re = re.compile('([' + _property_chars('S') + '])')
-
-
-def bleu_tokenize(string):
-  r"""Tokenize a string following the official BLEU implementation.
-
-  See https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983
-  In our case, the input string is expected to be just one line
-  and no HTML entities de-escaping is needed.
-  So we just tokenize on punctuation and symbols,
-  except when a punctuation is preceded and followed by a digit
-  (e.g. a comma/dot as a thousand/decimal separator).
-
-  Note that a numer (e.g. a year) followed by a dot at the end of sentence is NOT tokenized,
-  i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g`
-  does not match this case (unless we add a space after each sentence).
-  However, this error is already in the original mteval-v14.pl
-  and we want to be consistent with it.
-
-  Args:
-    string: the input string
-
-  Returns:
-    a list of tokens
-  """
-  string = UnicodeRegex.nondigit_punct_re.sub(r'\1 \2 ', string)
-  string = UnicodeRegex.punct_nondigit_re.sub(r' \1 \2', string)
-  string = UnicodeRegex.symbol_re.sub(r' \1 ', string)
-  return string.split()
-
-
-def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
-  """Compute BLEU for two files (reference and hypothesis translation)."""
-  # TODO: Does anyone care about Python2 compatibility?
-  ref_lines = open(ref_filename, 'rt', encoding='utf-8').read().splitlines()
-  hyp_lines = open(hyp_filename, 'rt', encoding='utf-8').read().splitlines()
-  assert len(ref_lines) == len(hyp_lines)
-  if not case_sensitive:
-    ref_lines = [x.lower() for x in ref_lines]
-    hyp_lines = [x.lower() for x in hyp_lines]
-  ref_tokens = [bleu_tokenize(x) for x in ref_lines]
-  hyp_tokens = [bleu_tokenize(x) for x in hyp_lines]
-  return compute_bleu(ref_tokens, hyp_tokens)
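As a side note on the compute_bleu() change above: with the smoothing branch removed, each n-gram order contributes its raw precision when it has possible matches and 0.0 otherwise. A minimal standalone Python illustration with made-up counts (not taken from any real evaluation):

# Hypothetical counts for max_order = 4.
matches_by_order = [5.0, 2.0, 0.0, 0.0]
possible_matches_by_order = [6.0, 5.0, 4.0, 3.0]
max_order = 4

precisions = [0] * max_order
for i in range(0, max_order):
  if possible_matches_by_order[i] > 0:
    precisions[i] = matches_by_order[i] / possible_matches_by_order[i]
  else:
    precisions[i] = 0.0

print(precisions)  # approximately [0.833, 0.4, 0.0, 0.0]

The removed branch would instead have assigned the zero-match orders a small positive value, 1.0 / (smooth * possible_matches_by_order[i]), doubling smooth each time.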
