From e440461add999db3980aee0e27494ec07a1cdde7 Mon Sep 17 00:00:00 2001 From: Nayef Ahmed Date: Tue, 24 Jan 2023 16:14:11 -0800 Subject: [PATCH 1/3] Remove deepcopies to store best model states --- beginner_source/transformer_tutorial.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/beginner_source/transformer_tutorial.py b/beginner_source/transformer_tutorial.py index 95372cabceb..386e1348545 100644 --- a/beginner_source/transformer_tutorial.py +++ b/beginner_source/transformer_tutorial.py @@ -45,6 +45,8 @@ # import math +import os +from tempfile import TemporaryDirectory from typing import Tuple import torch @@ -346,8 +348,9 @@ def evaluate(model: nn.Module, eval_data: Tensor) -> float: best_val_loss = float('inf') epochs = 3 -best_model = None +tempdir = TemporaryDirectory() +best_model_params_path = os.path.join(tempdir.path, "best_model_params.pt") for epoch in range(1, epochs + 1): epoch_start_time = time.time() train(model) @@ -356,12 +359,12 @@ def evaluate(model: nn.Module, eval_data: Tensor) -> float: elapsed = time.time() - epoch_start_time print('-' * 89) print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | ' - f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}') + f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}') print('-' * 89) if val_loss < best_val_loss: best_val_loss = val_loss - best_model = copy.deepcopy(model) + torch.save(model.state_dict(), best_model_params_path) scheduler.step() @@ -371,9 +374,12 @@ def evaluate(model: nn.Module, eval_data: Tensor) -> float: # ------------------------------------------- # -test_loss = evaluate(best_model, test_data) +model = torch.load(best_model_params_path) # load best model states +tempdir.cleanup() + +test_loss = evaluate(model, test_data) test_ppl = math.exp(test_loss) print('=' * 89) print(f'| End of training | test loss {test_loss:5.2f} | ' f'test ppl {test_ppl:8.2f}') -print('=' * 89) +print('=' * 89) \ No newline at end of file From 41cfe092fcfc5c0a17dc18c7fe49cf91d157e103 Mon Sep 17 00:00:00 2001 From: Nayef Ahmed Date: Tue, 24 Jan 2023 16:20:25 -0800 Subject: [PATCH 2/3] Formatting --- beginner_source/transformer_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/beginner_source/transformer_tutorial.py b/beginner_source/transformer_tutorial.py index 386e1348545..8d717fc26bc 100644 --- a/beginner_source/transformer_tutorial.py +++ b/beginner_source/transformer_tutorial.py @@ -382,4 +382,4 @@ def evaluate(model: nn.Module, eval_data: Tensor) -> float: print('=' * 89) print(f'| End of training | test loss {test_loss:5.2f} | ' f'test ppl {test_ppl:8.2f}') -print('=' * 89) \ No newline at end of file +print('=' * 89) From 18e5bfe471796ff3f9c847161004b67c65328e1b Mon Sep 17 00:00:00 2001 From: Nayef Ahmed Date: Wed, 25 Jan 2023 09:26:57 -0800 Subject: [PATCH 3/3] Fix CI errors --- beginner_source/transformer_tutorial.py | 41 ++++++++++++------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/beginner_source/transformer_tutorial.py b/beginner_source/transformer_tutorial.py index 8d717fc26bc..1917b56ee8a 100644 --- a/beginner_source/transformer_tutorial.py +++ b/beginner_source/transformer_tutorial.py @@ -349,24 +349,26 @@ def evaluate(model: nn.Module, eval_data: Tensor) -> float: best_val_loss = float('inf') epochs = 3 -tempdir = TemporaryDirectory() -best_model_params_path = os.path.join(tempdir.path, "best_model_params.pt") -for epoch in range(1, epochs + 1): - epoch_start_time = time.time() - train(model) - val_loss = evaluate(model, val_data) - val_ppl = math.exp(val_loss) - elapsed = time.time() - epoch_start_time - print('-' * 89) - print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | ' - f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}') - print('-' * 89) - - if val_loss < best_val_loss: - best_val_loss = val_loss - torch.save(model.state_dict(), best_model_params_path) - - scheduler.step() +with TemporaryDirectory() as tempdir: + best_model_params_path = os.path.join(tempdir, "best_model_params.pt") + + for epoch in range(1, epochs + 1): + epoch_start_time = time.time() + train(model) + val_loss = evaluate(model, val_data) + val_ppl = math.exp(val_loss) + elapsed = time.time() - epoch_start_time + print('-' * 89) + print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | ' + f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}') + print('-' * 89) + + if val_loss < best_val_loss: + best_val_loss = val_loss + torch.save(model.state_dict(), best_model_params_path) + + scheduler.step() + model.load_state_dict(torch.load(best_model_params_path)) # load best model states ###################################################################### @@ -374,9 +376,6 @@ def evaluate(model: nn.Module, eval_data: Tensor) -> float: # ------------------------------------------- # -model = torch.load(best_model_params_path) # load best model states -tempdir.cleanup() - test_loss = evaluate(model, test_data) test_ppl = math.exp(test_loss) print('=' * 89)