Skip to content

Commit a2f1186

Browse files
authored
Fix #318 cannot pickle metrics (#319)
* Change executor test Pickling certain metrics while specifying input type as generic type parameter would result in a failure in Python 3.6. Previous tests didn't cover this case. * Fix #318: cannot pickle metrics in Python 3.6 Worked around this limitation by using a dummy `Generic` in Python 3.6 and below. * Fix stupid linting issues
1 parent fcd3aa0 commit a2f1186

File tree

3 files changed

+34
-3
lines changed

3 files changed

+34
-3
lines changed

tests/run/executor_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def test_train_loop(self):
111111
save_every=[cond.time(seconds=10), cond.validation(better=True)],
112112
train_metrics=[("loss", metric.RunningAverage(20)),
113113
metric.F1(pred_name="preds", mode="macro"),
114-
metric.Accuracy(pred_name="preds"),
114+
metric.Accuracy[float](pred_name="preds"),
115115
metric.LR(optimizer)],
116116
optimizer=optimizer,
117117
stop_training_on=cond.epoch(10),

tests/run/metric/classification_test.py

+16
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
Unit tests for classification related operations.
1616
"""
1717
import functools
18+
import pickle
1819
import unittest
1920

2021
import numpy as np
@@ -74,3 +75,18 @@ def test_f1(self):
7475
self._test_metric(
7576
metric, functools.partial(f1_score, average=mode),
7677
binary=(mode == 'binary'))
78+
79+
80+
class MetricPicklingTest(unittest.TestCase):
81+
def test_pickle_unpickle(self):
82+
metric = Accuracy(pred_name="123")
83+
metric.add([1, 2, 3], [1, 2, 4])
84+
metric_new = pickle.loads(pickle.dumps(metric))
85+
self.assertEqual(metric.count, metric_new.count)
86+
self.assertEqual(metric.correct, metric_new.correct)
87+
88+
metric = Accuracy[float](pred_name="123")
89+
metric.add([1, 2, 3], [1, 2, 4])
90+
metric_new = pickle.loads(pickle.dumps(metric))
91+
self.assertEqual(metric.count, metric_new.count)
92+
self.assertEqual(metric.correct, metric_new.correct)

texar/torch/run/metric/base_metric.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
"""
1515
Base classes for Executor metrics.
1616
"""
17-
17+
import sys
1818
from abc import ABC, abstractmethod
19-
from typing import Generic, List, Optional, Sequence, TypeVar
19+
from typing import Generic, List, Optional, Sequence, TYPE_CHECKING, TypeVar
2020

2121
__all__ = [
2222
"Metric",
@@ -27,6 +27,21 @@
2727
Input = TypeVar('Input')
2828
Value = TypeVar('Value')
2929

30+
if not TYPE_CHECKING and sys.version_info[:2] <= (3, 6):
31+
# In Python 3.6 and below, pickling a `Generic` subclass that is specialized
32+
# would cause an exception. To prevent troubles with `Executor` save & load,
33+
# we use a dummy implementation of `Generic` through our home-brew
34+
# `GenericMeta`.
35+
from abc import ABCMeta # pylint: disable=ungrouped-imports
36+
37+
class GenericMeta(ABCMeta):
38+
def __getitem__(cls, params):
39+
# Whatever the parameters, just return the same class.
40+
return cls
41+
42+
class Generic(metaclass=GenericMeta): # pylint: disable=function-redefined
43+
pass
44+
3045

3146
class Metric(Generic[Input, Value], ABC):
3247
r"""Base class of all metrics. You should not directly inherit this class,

0 commit comments

Comments
 (0)