Skip to content

Commit 0344c55

Browse files
authored
Fix transformer loss (tensorflow#4270)
1 parent 461fc09 commit 0344c55

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

official/transformer/transformer_main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,12 @@ def model_fn(features, labels, mode, params):
8181
logits = output
8282

8383
# Calculate model loss.
84+
# xentropy contains the cross entropy loss of every nonpadding token in the
85+
# targets.
8486
xentropy, weights = metrics.padded_cross_entropy_loss(
8587
logits, targets, params.label_smoothing, params.vocab_size)
86-
loss = tf.reduce_sum(xentropy * weights) / tf.reduce_sum(weights)
88+
# Compute the weighted mean of the cross entropy losses
89+
loss = tf.reduce_sum(xentropy) / tf.reduce_sum(weights)
8790

8891
# Save loss as named tensor that will be logged with the logging hook.
8992
tf.identity(loss, "cross_entropy")

official/transformer/utils/metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
5858
smoothing: Label smoothing constant, used to determine the on and off values
5959
vocab_size: int size of the vocabulary
6060
Returns:
61-
Returns a float32 tensor with shape
62-
[batch_size, max(length_logits, length_labels)]
61+
Returns the cross entropy loss and weight tensors: float32 tensors with
62+
shape [batch_size, max(length_logits, length_labels)]
6363
"""
6464
with tf.name_scope("loss", [logits, labels]):
6565
logits, labels = _pad_tensors_to_same_length(logits, labels)

0 commit comments

Comments
 (0)