diff --git a/tests/run/executor_test.py b/tests/run/executor_test.py index fd0b10106..430ed5eea 100644 --- a/tests/run/executor_test.py +++ b/tests/run/executor_test.py @@ -111,7 +111,7 @@ def test_train_loop(self): save_every=[cond.time(seconds=10), cond.validation(better=True)], train_metrics=[("loss", metric.RunningAverage(20)), metric.F1(pred_name="preds", mode="macro"), - metric.Accuracy(pred_name="preds"), + metric.Accuracy[float](pred_name="preds"), metric.LR(optimizer)], optimizer=optimizer, stop_training_on=cond.epoch(10), diff --git a/tests/run/metric/classification_test.py b/tests/run/metric/classification_test.py index 1f7a0cd07..a813836dd 100644 --- a/tests/run/metric/classification_test.py +++ b/tests/run/metric/classification_test.py @@ -15,6 +15,7 @@ Unit tests for classification related operations. """ import functools +import pickle import unittest import numpy as np @@ -74,3 +75,18 @@ def test_f1(self): self._test_metric( metric, functools.partial(f1_score, average=mode), binary=(mode == 'binary')) + + +class MetricPicklingTest(unittest.TestCase): + def test_pickle_unpickle(self): + metric = Accuracy(pred_name="123") + metric.add([1, 2, 3], [1, 2, 4]) + metric_new = pickle.loads(pickle.dumps(metric)) + self.assertEqual(metric.count, metric_new.count) + self.assertEqual(metric.correct, metric_new.correct) + + metric = Accuracy[float](pred_name="123") + metric.add([1, 2, 3], [1, 2, 4]) + metric_new = pickle.loads(pickle.dumps(metric)) + self.assertEqual(metric.count, metric_new.count) + self.assertEqual(metric.correct, metric_new.correct) diff --git a/texar/torch/run/metric/base_metric.py b/texar/torch/run/metric/base_metric.py index b4146de09..5c58ec6bf 100644 --- a/texar/torch/run/metric/base_metric.py +++ b/texar/torch/run/metric/base_metric.py @@ -14,9 +14,9 @@ """ Base classes for Executor metrics. """ - +import sys from abc import ABC, abstractmethod -from typing import Generic, List, Optional, Sequence, TypeVar +from typing import Generic, List, Optional, Sequence, TYPE_CHECKING, TypeVar __all__ = [ "Metric", @@ -27,6 +27,21 @@ Input = TypeVar('Input') Value = TypeVar('Value') +if not TYPE_CHECKING and sys.version_info[:2] <= (3, 6): + # In Python 3.6 and below, pickling a `Generic` subclass that is specialized + # would cause an exception. To prevent troubles with `Executor` save & load, + # we use a dummy implementation of `Generic` through our home-brew + # `GenericMeta`. + from abc import ABCMeta # pylint: disable=ungrouped-imports + + class GenericMeta(ABCMeta): + def __getitem__(cls, params): + # Whatever the parameters, just return the same class. + return cls + + class Generic(metaclass=GenericMeta): # pylint: disable=function-redefined + pass + class Metric(Generic[Input, Value], ABC): r"""Base class of all metrics. You should not directly inherit this class,