Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit f77da80

Browse files
T2T TeamRyan Sepassi
T2T Team
authored and
Ryan Sepassi
committed
Update distance computation for k-nearest neighbours to be more efficient, by computing the norms separately.
PiperOrigin-RevId: 176585527
1 parent 2145729 commit f77da80

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

tensor2tensor/models/transformer_vae.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,11 @@ def vae(x, z_size, name):
140140
def nearest(x, means, hparams):
141141
"""Find the nearest means to elements in x."""
142142
x, means = tf.stop_gradient(x), tf.stop_gradient(means)
143-
means = tf.nn.l2_normalize(means, dim=1)
144143
x_flat = tf.reshape(x, [-1, hparams.hidden_size])
145-
# dist = tf.reduce_sum(tf.square(x_flat - tf.expand_dims(means, 0)), axis=2)
146-
dist = - tf.matmul(x_flat, means, transpose_b=True)
144+
x_norm = tf.norm(x_flat, axis=-1, keep_dims=True)
145+
means_norm = tf.norm(means, axis=-1, keep_dims=True)
146+
dist = x_norm + tf.transpose(means_norm) - 2 * tf.matmul(x_flat, means,
147+
transpose_b=True)
147148
_, nearest_idx = tf.nn.top_k(- dist, k=1)
148149
nearest_hot = tf.one_hot(tf.squeeze(nearest_idx, axis=1), hparams.v_size)
149150
nearest_hot = tf.reshape(nearest_hot, [tf.shape(x)[0], tf.shape(x)[1],

0 commit comments

Comments
 (0)