diff --git a/beginner_source/transformer_tutorial.py b/beginner_source/transformer_tutorial.py
index 95372cabceb..1917b56ee8a 100644
--- a/beginner_source/transformer_tutorial.py
+++ b/beginner_source/transformer_tutorial.py
@@ -45,6 +45,8 @@
 #
 
 import math
+import os
+from tempfile import TemporaryDirectory
 from typing import Tuple
 
 import torch
@@ -346,24 +348,27 @@ def evaluate(model: nn.Module, eval_data: Tensor) -> float:
 
 best_val_loss = float('inf')
 epochs = 3
-best_model = None
 
-for epoch in range(1, epochs + 1):
-    epoch_start_time = time.time()
-    train(model)
-    val_loss = evaluate(model, val_data)
-    val_ppl = math.exp(val_loss)
-    elapsed = time.time() - epoch_start_time
-    print('-' * 89)
-    print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
-          f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
-    print('-' * 89)
+with TemporaryDirectory() as tempdir:
+    best_model_params_path = os.path.join(tempdir, "best_model_params.pt")
 
-    if val_loss < best_val_loss:
-        best_val_loss = val_loss
-        best_model = copy.deepcopy(model)
+    for epoch in range(1, epochs + 1):
+        epoch_start_time = time.time()
+        train(model)
+        val_loss = evaluate(model, val_data)
+        val_ppl = math.exp(val_loss)
+        elapsed = time.time() - epoch_start_time
+        print('-' * 89)
+        print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
+            f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
+        print('-' * 89)
 
-    scheduler.step()
+        if val_loss < best_val_loss:
+            best_val_loss = val_loss
+            torch.save(model.state_dict(), best_model_params_path)
+
+        scheduler.step()
+    model.load_state_dict(torch.load(best_model_params_path)) # load best model states
 
 
 ######################################################################
@@ -371,7 +376,7 @@ def evaluate(model: nn.Module, eval_data: Tensor) -> float:
 # -------------------------------------------
 #
 
-test_loss = evaluate(best_model, test_data)
+test_loss = evaluate(model, test_data)
 test_ppl = math.exp(test_loss)
 print('=' * 89)
 print(f'| End of training | test loss {test_loss:5.2f} | '