Skip to content

Commit ed63356

Browse files
committed
Fix asyml#318: cannot pickle metrics in Python 3.6
Worked around this limitation by using a dummy `Generic` in Python 3.6 and below.
1 parent d383604 commit ed63356

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

tests/run/metric/classification_test.py

+16
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
Unit tests for classification related operations.
1616
"""
1717
import functools
18+
import pickle
1819
import unittest
1920

2021
import numpy as np
@@ -74,3 +75,18 @@ def test_f1(self):
7475
self._test_metric(
7576
metric, functools.partial(f1_score, average=mode),
7677
binary=(mode == 'binary'))
78+
79+
80+
class MetricPicklingTest(unittest.TestCase):
81+
def test_pickle_unpickle(self):
82+
metric = Accuracy(pred_name="123")
83+
metric.add([1, 2, 3], [1, 2, 4])
84+
metric_new = pickle.loads(pickle.dumps(metric))
85+
self.assertEqual(metric.count, metric_new.count)
86+
self.assertEqual(metric.correct, metric_new.correct)
87+
88+
metric = Accuracy[float](pred_name="123")
89+
metric.add([1, 2, 3], [1, 2, 4])
90+
metric_new = pickle.loads(pickle.dumps(metric))
91+
self.assertEqual(metric.count, metric_new.count)
92+
self.assertEqual(metric.correct, metric_new.correct)

texar/torch/run/metric/base_metric.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
"""
1515
Base classes for Executor metrics.
1616
"""
17-
17+
import sys
1818
from abc import ABC, abstractmethod
19-
from typing import Generic, List, Optional, Sequence, TypeVar
19+
from typing import Generic, List, Optional, Sequence, TYPE_CHECKING, TypeVar
2020

2121
__all__ = [
2222
"Metric",
@@ -27,6 +27,23 @@
2727
Input = TypeVar('Input')
2828
Value = TypeVar('Value')
2929

30+
if not TYPE_CHECKING and sys.version_info[:2] <= (3, 6):
31+
# In Python 3.6 and below, pickling a `Generic` subclass that is specialized
32+
# would cause an exception. To prevent troubles with `Executor` save & load,
33+
# we use a dummy implementation of `Generic` through our home-brew
34+
# `GenericMeta`.
35+
from abc import ABCMeta
36+
37+
38+
class GenericMeta(ABCMeta):
39+
def __getitem__(cls, params):
40+
# Whatever the parameters, just return the same class.
41+
return cls
42+
43+
44+
class Generic(metaclass=GenericMeta):
45+
pass
46+
3047

3148
class Metric(Generic[Input, Value], ABC):
3249
r"""Base class of all metrics. You should not directly inherit this class,

0 commit comments

Comments
 (0)