diff --git a/_static/img/seq-seq-images/attention-decoder-network.png b/_static/img/seq-seq-images/attention-decoder-network.png
index 243f87c6e97..d31d42a5af1 100755
Binary files a/_static/img/seq-seq-images/attention-decoder-network.png and b/_static/img/seq-seq-images/attention-decoder-network.png differ
diff --git a/intermediate_source/seq2seq_translation_tutorial.py b/intermediate_source/seq2seq_translation_tutorial.py
index ea583821f85..c2b0b722e5b 100644
--- a/intermediate_source/seq2seq_translation_tutorial.py
+++ b/intermediate_source/seq2seq_translation_tutorial.py
@@ -440,25 +440,27 @@ def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGT
         self.max_length = max_length
 
         self.embedding = nn.Embedding(self.output_size, self.hidden_size)
-        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
-        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
+        self.fc_hidden = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.fc_encoder = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.alignment_vector = nn.Parameter(torch.Tensor(1, hidden_size))
+        torch.nn.init.xavier_uniform_(self.alignment_vector)
         self.dropout = nn.Dropout(self.dropout_p)
-        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
+        self.gru = nn.GRU(self.hidden_size * 2, self.hidden_size)
         self.out = nn.Linear(self.hidden_size, self.output_size)
 
     def forward(self, input, hidden, encoder_outputs):
-        embedded = self.embedding(input).view(1, 1, -1)
+        embedded = self.embedding(input).view(1, -1)
         embedded = self.dropout(embedded)
 
-        attn_weights = F.softmax(
-            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
-        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
-                                 encoder_outputs.unsqueeze(0))
-
-        output = torch.cat((embedded[0], attn_applied[0]), 1)
-        output = self.attn_combine(output).unsqueeze(0)
+        transformed_hidden = self.fc_hidden(hidden[0])
+        expanded_hidden_state = transformed_hidden.expand(self.max_length, -1)
+        alignment_scores = torch.tanh(expanded_hidden_state +
+                                      self.fc_encoder(encoder_outputs))
+        alignment_scores = self.alignment_vector.mm(alignment_scores.T)
+        attn_weights = F.softmax(alignment_scores, dim=1)
+        context_vector = attn_weights.mm(encoder_outputs)
 
-        output = F.relu(output)
+        output = torch.cat((embedded, context_vector), 1).unsqueeze(0)
         output, hidden = self.gru(output, hidden)
 
         output = F.log_softmax(self.out(output[0]), dim=1)
@@ -761,15 +763,15 @@ def evaluateRandomly(encoder, decoder, n=10):
 #
 
 hidden_size = 256
-encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
-attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)
+encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
+attn_decoder = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)
 
-trainIters(encoder1, attn_decoder1, 75000, print_every=5000)
+trainIters(encoder, attn_decoder, 75000, print_every=5000)
 
 ######################################################################
 #
 
-evaluateRandomly(encoder1, attn_decoder1)
+evaluateRandomly(encoder, attn_decoder)
 
 
 ######################################################################
@@ -787,7 +789,7 @@ def evaluateRandomly(encoder, decoder, n=10):
 #
 
 output_words, attentions = evaluate(
-    encoder1, attn_decoder1, "je suis trop froid .")
+    encoder, attn_decoder, "je suis trop froid .")
 plt.matshow(attentions.numpy())
 
 
@@ -817,7 +819,7 @@ def showAttention(input_sentence, output_words, attentions):
 
 def evaluateAndShowAttention(input_sentence):
     output_words, attentions = evaluate(
-        encoder1, attn_decoder1, input_sentence)
+        encoder, attn_decoder, input_sentence)
     print('input =', input_sentence)
     print('output =', ' '.join(output_words))
     showAttention(input_sentence, output_words, attentions)