diff --git a/config.yaml b/config.yaml index f770aa6..c1ba0a1 100644 --- a/config.yaml +++ b/config.yaml @@ -52,4 +52,7 @@ train: save: path_to_folder: 'models/test_main/' + model_checkpoint: + save_frequency: 2 + save_best_weights: true export_onnx: true diff --git a/main.py b/main.py index 289d551..4f35865 100644 --- a/main.py +++ b/main.py @@ -8,6 +8,7 @@ from torch.utils.data import DataLoader from pytorch_ner.dataset import NERCollator, NERDataset +from pytorch_ner.model_checkpoint import model_checkpoint from pytorch_ner.nn_modules.architecture import BiLSTM from pytorch_ner.nn_modules.embedding import Embedding from pytorch_ner.nn_modules.linear import LinearHead @@ -201,6 +202,10 @@ def main(path_to_config: str): optimizer=optimizer, device=device, n_epoch=config["train"]["n_epoch"], + export_onnx=config["save"]["export_onnx"], + path_to_folder=config["save"]["path_to_folder"], + save_frequency=config["save"]["model_checkpoint"]["save_frequency"], + save_best_weights=config["save"]["model_checkpoint"]["save_best_weights"], verbose=config["train"]["verbose"], ) diff --git a/pytorch_ner/model_checkpoint.py b/pytorch_ner/model_checkpoint.py new file mode 100644 index 0000000..1c07cbb --- /dev/null +++ b/pytorch_ner/model_checkpoint.py @@ -0,0 +1,53 @@ +import json +import os +from typing import Dict + +import numpy as np +import torch +import torch.nn as nn +import yaml + +from pytorch_ner.onnx import onnx_export_and_check +from pytorch_ner.utils import mkdir, rmdir + + +def model_checkpoint( + model: nn.Module, + epoch: int, + save_best_weights: bool, + val_metrics, + val_losses, + path_to_folder: str, + export_onnx: bool, + save_frequency: int, +): + + """ + This function creates check point based on either one of the two scenarios: + 1. Save best weights regarding the val_loss + 2. Save weights frequently with save_frequency int + + """ + if save_best_weights: + if np.mean(val_metrics["loss"]) < min(val_losses): + # This iteration has lower val_loss, let's save it + val_losses.append(np.mean(val_metrics["loss"])) + pth_file_name = "best_model.pth" + onnx_file_name = "best_model.onnx" + else: + # No need to save weights + return + else: + if epoch % save_frequency == 0: + # We're at multiple of save_frequency, let's save weights + pth_file_name = "model_epoch_" + str(epoch) + ".pth" + onnx_file_name = "model_epoch_" + str(epoch) + ".onnx" + else: + # No need to save weights + return + + torch.save(model.state_dict(), os.path.join(path_to_folder, pth_file_name)) + if export_onnx: + onnx_export_and_check( + model=model, path_to_save=os.path.join(path_to_folder, onnx_file_name) + ) diff --git a/pytorch_ner/save.py b/pytorch_ner/save.py index fa0fd7f..50a5b61 100644 --- a/pytorch_ner/save.py +++ b/pytorch_ner/save.py @@ -18,9 +18,10 @@ def save_model( config: Dict, export_onnx: bool = False, ): - # make empty dir - rmdir(path_to_folder) - mkdir(path_to_folder) + + if not os.path.exists(path_to_folder): + # make empty dir + mkdir(path_to_folder) model.cpu() model.eval() diff --git a/pytorch_ner/train.py b/pytorch_ner/train.py index 60e1d52..a0a9432 100644 --- a/pytorch_ner/train.py +++ b/pytorch_ner/train.py @@ -1,3 +1,4 @@ +import os from collections import defaultdict from typing import Callable, DefaultDict, List, Optional @@ -9,7 +10,9 @@ from tqdm import tqdm from pytorch_ner.metrics import calculate_metrics -from pytorch_ner.utils import to_numpy +from pytorch_ner.model_checkpoint import model_checkpoint +from pytorch_ner.onnx import onnx_export_and_check +from pytorch_ner.utils import mkdir, rmdir, to_numpy def masking(lengths: torch.Tensor) -> torch.Tensor: @@ -144,12 +147,23 @@ def train( optimizer: optim.Optimizer, device: torch.device, n_epoch: int, + export_onnx: bool, + path_to_folder: str, + save_frequency: int, + save_best_weights: bool, testloader: Optional[DataLoader] = None, verbose: bool = True, ): """ Training / validation loop for n_epoch with final testing. """ + if os.path.exists(path_to_folder): + # delete any previous versions of models + rmdir(path_to_folder) + mkdir(path_to_folder) + + # List that tracks val_loss over training to save best weights + val_losses = [np.inf] for epoch in range(n_epoch): @@ -183,6 +197,18 @@ def train( print(f"val {metric_name}: {np.mean(metric_list)}") print() + # Model Checkpoint + model_checkpoint( + model=model, + epoch=epoch, + save_best_weights=save_best_weights, + val_metrics=val_metrics, + val_losses=val_losses, + path_to_folder=path_to_folder, + export_onnx=export_onnx, + save_frequency=save_frequency, + ) + if testloader is not None: test_metrics = validate_loop( diff --git a/tests/test_train.py b/tests/test_train.py index 4158166..3efab0e 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -75,6 +75,10 @@ optimizer=optimizer, device=device, n_epoch=5, + export_onnx=True, + path_to_folder="models/test_main/", + save_frequency=1, + save_best_weights=True, verbose=False, )