diff --git a/beginner_source/transformer_tutorial.py b/beginner_source/transformer_tutorial.py index 95372cabceb..1917b56ee8a 100644 --- a/beginner_source/transformer_tutorial.py +++ b/beginner_source/transformer_tutorial.py @@ -45,6 +45,8 @@ # import math +import os +from tempfile import TemporaryDirectory from typing import Tuple import torch @@ -346,24 +348,27 @@ def evaluate(model: nn.Module, eval_data: Tensor) -> float: best_val_loss = float('inf') epochs = 3 -best_model = None -for epoch in range(1, epochs + 1): - epoch_start_time = time.time() - train(model) - val_loss = evaluate(model, val_data) - val_ppl = math.exp(val_loss) - elapsed = time.time() - epoch_start_time - print('-' * 89) - print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | ' - f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}') - print('-' * 89) +with TemporaryDirectory() as tempdir: + best_model_params_path = os.path.join(tempdir, "best_model_params.pt") - if val_loss < best_val_loss: - best_val_loss = val_loss - best_model = copy.deepcopy(model) + for epoch in range(1, epochs + 1): + epoch_start_time = time.time() + train(model) + val_loss = evaluate(model, val_data) + val_ppl = math.exp(val_loss) + elapsed = time.time() - epoch_start_time + print('-' * 89) + print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | ' + f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}') + print('-' * 89) - scheduler.step() + if val_loss < best_val_loss: + best_val_loss = val_loss + torch.save(model.state_dict(), best_model_params_path) + + scheduler.step() + model.load_state_dict(torch.load(best_model_params_path)) # load best model states ###################################################################### @@ -371,7 +376,7 @@ def evaluate(model: nn.Module, eval_data: Tensor) -> float: # ------------------------------------------- # -test_loss = evaluate(best_model, test_data) +test_loss = evaluate(model, test_data) test_ppl = math.exp(test_loss) print('=' * 89) print(f'| End of training | test loss {test_loss:5.2f} | '