diff --git a/beginner_source/transformer_tutorial.py b/beginner_source/transformer_tutorial.py
index d7ebee959e5..437f6345241 100644
--- a/beginner_source/transformer_tutorial.py
+++ b/beginner_source/transformer_tutorial.py
@@ -29,7 +29,7 @@
 
 ######################################################################
 # In this tutorial, we train a ``nn.TransformerEncoder`` model on a
-# language modeling task. Please note that this tutorial does not cover
+# causal language modeling task. Please note that this tutorial does not cover
 # the training of `nn.TransformerDecoder `__, as depicted in
 # the right half of the diagram above. The language modeling task is to assign a
 # probability for the likelihood of a given word (or a sequence of words)
@@ -41,8 +41,10 @@
 # Along with the input sequence, a square attention mask is required because the
 # self-attention layers in ``nn.TransformerDecoder`` are only allowed to attend
 # the earlier positions in the sequence. For the language modeling task, any
-# tokens on the future positions should be masked. To produce a probability
-# distribution over output words, the output of the ``nn.TransformerEncoder``
+# tokens on the future positions should be masked. This masking, combined with the fact
+# that the output embeddings are offset by one position, ensures that the
+# predictions for position i can depend only on the known outputs at positions less than i.
+# To produce a probability distribution over output words, the output of the ``nn.TransformerEncoder``
 # model is passed through a linear layer to output unnormalized logits.
 # The log-softmax function isn't applied here due to the later use of
 # `CrossEntropyLoss `__,
@@ -91,6 +93,11 @@ def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
         """
         src = self.embedding(src) * math.sqrt(self.d_model)
         src = self.pos_encoder(src)
+        if src_mask is None:
+            # Generate a square causal mask for the sequence.
+            # The masked positions are filled with float('-inf');
+            # unmasked positions are filled with float(0.0).
+            src_mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(device)
         output = self.transformer_encoder(src, src_mask)
         output = self.linear(output)
         return output
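
For reference, a minimal standalone sketch (not part of the patch) of what the default mask built in the patched ``forward`` looks like; ``seq_len = 5`` is an arbitrary example length, standing in for ``len(src)``:

import torch.nn as nn

# generate_square_subsequent_mask(n) returns an (n, n) float tensor with 0.0 on
# and below the diagonal and -inf above it, so position i can only attend to
# positions <= i.
seq_len = 5  # example length; the tutorial code uses len(src)
mask = nn.Transformer.generate_square_subsequent_mask(seq_len)
print(mask)
# tensor([[0., -inf, -inf, -inf, -inf],
#         [0., 0., -inf, -inf, -inf],
#         [0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]])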