Skip to content

Commit a38bdac

Browse files
ilevkivskyigvanrossum
authored andcommitted
A real fix for issue #250 (failure with mock) (#295)
Fixes #250 The main idea here is optimizing generics for cases where a type information is added to existing code. Now: * ``Node[int]`` and ``Node`` have identical ``__bases__`` and identical ``__mro__[1:]`` (except for the first item, since it is the class itself). * After addition of typing information (i.e. making some classes generic), ``__mro__`` is changed very little, at most one bare ``Generic`` appears in ``__mro__``. * Consequently, only non-parameterized generics appear in ``__bases__`` and ``__mro__[1:]``. Interestingly, this could be achieved in few lines of code and no existing test break. On the positive side of this approach, there is very little chance that existing code (even with sophisticated "magic") will break after addition of typing information. On the negative side, it will be more difficult for _runtime_ type-checkers to perform decorator-based type checks (e.g. enforce method overriding only by consistent methods). Essentially, now type erasure happens partially at the class creation time (all bases are reduced to origin). (We have __orig_class__ and __orig_bases__ to help runtime checkers.)
1 parent 8988cd9 commit a38bdac

File tree

4 files changed

+159
-16
lines changed

4 files changed

+159
-16
lines changed

python2/test_typing.py

+56
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,62 @@ class MM1(MutableMapping[str, str], collections_abc.MutableMapping):
627627
class MM2(collections_abc.MutableMapping, MutableMapping[str, str]):
628628
pass
629629

630+
def test_orig_bases(self):
631+
T = TypeVar('T')
632+
class C(typing.Dict[str, T]): pass
633+
self.assertEqual(C.__orig_bases__, (typing.Dict[str, T],))
634+
635+
def test_naive_runtime_checks(self):
636+
def naive_dict_check(obj, tp):
637+
# Check if a dictionary conforms to Dict type
638+
if len(tp.__parameters__) > 0:
639+
raise NotImplementedError
640+
if tp.__args__:
641+
KT, VT = tp.__args__
642+
return all(isinstance(k, KT) and isinstance(v, VT)
643+
for k, v in obj.items())
644+
self.assertTrue(naive_dict_check({'x': 1}, typing.Dict[typing.Text, int]))
645+
self.assertFalse(naive_dict_check({1: 'x'}, typing.Dict[typing.Text, int]))
646+
with self.assertRaises(NotImplementedError):
647+
naive_dict_check({1: 'x'}, typing.Dict[typing.Text, T])
648+
649+
def naive_generic_check(obj, tp):
650+
# Check if an instance conforms to the generic class
651+
if not hasattr(obj, '__orig_class__'):
652+
raise NotImplementedError
653+
return obj.__orig_class__ == tp
654+
class Node(Generic[T]): pass
655+
self.assertTrue(naive_generic_check(Node[int](), Node[int]))
656+
self.assertFalse(naive_generic_check(Node[str](), Node[int]))
657+
self.assertFalse(naive_generic_check(Node[str](), List))
658+
with self.assertRaises(NotImplementedError):
659+
naive_generic_check([1,2,3], Node[int])
660+
661+
def naive_list_base_check(obj, tp):
662+
# Check if list conforms to a List subclass
663+
return all(isinstance(x, tp.__orig_bases__[0].__args__[0])
664+
for x in obj)
665+
class C(List[int]): pass
666+
self.assertTrue(naive_list_base_check([1, 2, 3], C))
667+
self.assertFalse(naive_list_base_check(['a', 'b'], C))
668+
669+
def test_multi_subscr_base(self):
670+
T = TypeVar('T')
671+
U = TypeVar('U')
672+
V = TypeVar('V')
673+
class C(List[T][U][V]): pass
674+
class D(C, List[T][U][V]): pass
675+
self.assertEqual(C.__parameters__, (V,))
676+
self.assertEqual(D.__parameters__, (V,))
677+
self.assertEqual(C[int].__parameters__, ())
678+
self.assertEqual(D[int].__parameters__, ())
679+
self.assertEqual(C[int].__args__, (int,))
680+
self.assertEqual(D[int].__args__, (int,))
681+
self.assertEqual(C.__bases__, (List,))
682+
self.assertEqual(D.__bases__, (C, List))
683+
self.assertEqual(C.__orig_bases__, (List[T][U][V],))
684+
self.assertEqual(D.__orig_bases__, (C, List[T][U][V]))
685+
630686
def test_pickle(self):
631687
global C # pickle wants to reference the class by name
632688
T = TypeVar('T')

python2/typing.py

+24-9
Original file line numberDiff line numberDiff line change
@@ -1069,13 +1069,7 @@ class GenericMeta(TypingMeta, abc.ABCMeta):
10691069
"""Metaclass for generic types."""
10701070

10711071
def __new__(cls, name, bases, namespace,
1072-
tvars=None, args=None, origin=None, extra=None):
1073-
if extra is None:
1074-
extra = namespace.get('__extra__')
1075-
if extra is not None and type(extra) is abc.ABCMeta and extra not in bases:
1076-
bases = (extra,) + bases
1077-
self = super(GenericMeta, cls).__new__(cls, name, bases, namespace)
1078-
1072+
tvars=None, args=None, origin=None, extra=None, orig_bases=None):
10791073
if tvars is not None:
10801074
# Called from __getitem__() below.
10811075
assert origin is not None
@@ -1116,12 +1110,27 @@ def __new__(cls, name, bases, namespace,
11161110
", ".join(str(g) for g in gvars)))
11171111
tvars = gvars
11181112

1113+
initial_bases = bases
1114+
if extra is None:
1115+
extra = namespace.get('__extra__')
1116+
if extra is not None and type(extra) is abc.ABCMeta and extra not in bases:
1117+
bases = (extra,) + bases
1118+
bases = tuple(_gorg(b) if isinstance(b, GenericMeta) else b for b in bases)
1119+
1120+
# remove bare Generic from bases if there are other generic bases
1121+
if any(isinstance(b, GenericMeta) and b is not Generic for b in bases):
1122+
bases = tuple(b for b in bases if b is not Generic)
1123+
self = super(GenericMeta, cls).__new__(cls, name, bases, namespace)
1124+
11191125
self.__parameters__ = tvars
11201126
self.__args__ = args
11211127
self.__origin__ = origin
11221128
self.__extra__ = extra
11231129
# Speed hack (https://github.com/python/typing/issues/196).
11241130
self.__next_in_mro__ = _next_in_mro(self)
1131+
# Preserve base classes on subclassing (__bases__ are type erased now).
1132+
if orig_bases is None:
1133+
self.__orig_bases__ = initial_bases
11251134

11261135
# This allows unparameterized generic collections to be used
11271136
# with issubclass() and isinstance() in the same way as their
@@ -1216,12 +1225,13 @@ def __getitem__(self, params):
12161225
tvars = _type_vars(params)
12171226
args = params
12181227
return self.__class__(self.__name__,
1219-
(self,) + self.__bases__,
1228+
self.__bases__,
12201229
dict(self.__dict__),
12211230
tvars=tvars,
12221231
args=args,
12231232
origin=self,
1224-
extra=self.__extra__)
1233+
extra=self.__extra__,
1234+
orig_bases=self.__orig_bases__)
12251235

12261236
def __instancecheck__(self, instance):
12271237
# Since we extend ABC.__subclasscheck__ and
@@ -1268,6 +1278,10 @@ def __new__(cls, *args, **kwds):
12681278
else:
12691279
origin = _gorg(cls)
12701280
obj = cls.__next_in_mro__.__new__(origin)
1281+
try:
1282+
obj.__orig_class__ = cls
1283+
except AttributeError:
1284+
pass
12711285
obj.__init__(*args, **kwds)
12721286
return obj
12731287

@@ -1438,6 +1452,7 @@ def _get_protocol_attrs(self):
14381452
attr != '__next_in_mro__' and
14391453
attr != '__parameters__' and
14401454
attr != '__origin__' and
1455+
attr != '__orig_bases__' and
14411456
attr != '__extra__' and
14421457
attr != '__module__'):
14431458
attrs.add(attr)

src/test_typing.py

+57
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,63 @@ class MM1(MutableMapping[str, str], collections_abc.MutableMapping):
654654
class MM2(collections_abc.MutableMapping, MutableMapping[str, str]):
655655
pass
656656

657+
def test_orig_bases(self):
658+
T = TypeVar('T')
659+
class C(typing.Dict[str, T]): ...
660+
self.assertEqual(C.__orig_bases__, (typing.Dict[str, T],))
661+
662+
def test_naive_runtime_checks(self):
663+
def naive_dict_check(obj, tp):
664+
# Check if a dictionary conforms to Dict type
665+
if len(tp.__parameters__) > 0:
666+
raise NotImplementedError
667+
if tp.__args__:
668+
KT, VT = tp.__args__
669+
return all(isinstance(k, KT) and isinstance(v, VT)
670+
for k, v in obj.items())
671+
self.assertTrue(naive_dict_check({'x': 1}, typing.Dict[str, int]))
672+
self.assertFalse(naive_dict_check({1: 'x'}, typing.Dict[str, int]))
673+
with self.assertRaises(NotImplementedError):
674+
naive_dict_check({1: 'x'}, typing.Dict[str, T])
675+
676+
def naive_generic_check(obj, tp):
677+
# Check if an instance conforms to the generic class
678+
if not hasattr(obj, '__orig_class__'):
679+
raise NotImplementedError
680+
return obj.__orig_class__ == tp
681+
class Node(Generic[T]): ...
682+
self.assertTrue(naive_generic_check(Node[int](), Node[int]))
683+
self.assertFalse(naive_generic_check(Node[str](), Node[int]))
684+
self.assertFalse(naive_generic_check(Node[str](), List))
685+
with self.assertRaises(NotImplementedError):
686+
naive_generic_check([1,2,3], Node[int])
687+
688+
def naive_list_base_check(obj, tp):
689+
# Check if list conforms to a List subclass
690+
return all(isinstance(x, tp.__orig_bases__[0].__args__[0])
691+
for x in obj)
692+
class C(List[int]): ...
693+
self.assertTrue(naive_list_base_check([1, 2, 3], C))
694+
self.assertFalse(naive_list_base_check(['a', 'b'], C))
695+
696+
def test_multi_subscr_base(self):
697+
T = TypeVar('T')
698+
U = TypeVar('U')
699+
V = TypeVar('V')
700+
class C(List[T][U][V]): ...
701+
class D(C, List[T][U][V]): ...
702+
self.assertEqual(C.__parameters__, (V,))
703+
self.assertEqual(D.__parameters__, (V,))
704+
self.assertEqual(C[int].__parameters__, ())
705+
self.assertEqual(D[int].__parameters__, ())
706+
self.assertEqual(C[int].__args__, (int,))
707+
self.assertEqual(D[int].__args__, (int,))
708+
self.assertEqual(C.__bases__, (List,))
709+
self.assertEqual(D.__bases__, (C, List))
710+
self.assertEqual(C.__orig_bases__, (List[T][U][V],))
711+
self.assertEqual(D.__orig_bases__, (C, List[T][U][V]))
712+
713+
657714
def test_pickle(self):
658715
global C # pickle wants to reference the class by name
659716
T = TypeVar('T')

src/typing.py

+22-7
Original file line numberDiff line numberDiff line change
@@ -959,11 +959,7 @@ class GenericMeta(TypingMeta, abc.ABCMeta):
959959
"""Metaclass for generic types."""
960960

961961
def __new__(cls, name, bases, namespace,
962-
tvars=None, args=None, origin=None, extra=None):
963-
if extra is not None and type(extra) is abc.ABCMeta and extra not in bases:
964-
bases = (extra,) + bases
965-
self = super().__new__(cls, name, bases, namespace, _root=True)
966-
962+
tvars=None, args=None, origin=None, extra=None, orig_bases=None):
967963
if tvars is not None:
968964
# Called from __getitem__() below.
969965
assert origin is not None
@@ -1004,12 +1000,25 @@ def __new__(cls, name, bases, namespace,
10041000
", ".join(str(g) for g in gvars)))
10051001
tvars = gvars
10061002

1003+
initial_bases = bases
1004+
if extra is not None and type(extra) is abc.ABCMeta and extra not in bases:
1005+
bases = (extra,) + bases
1006+
bases = tuple(_gorg(b) if isinstance(b, GenericMeta) else b for b in bases)
1007+
1008+
# remove bare Generic from bases if there are other generic bases
1009+
if any(isinstance(b, GenericMeta) and b is not Generic for b in bases):
1010+
bases = tuple(b for b in bases if b is not Generic)
1011+
self = super().__new__(cls, name, bases, namespace, _root=True)
1012+
10071013
self.__parameters__ = tvars
10081014
self.__args__ = args
10091015
self.__origin__ = origin
10101016
self.__extra__ = extra
10111017
# Speed hack (https://github.com/python/typing/issues/196).
10121018
self.__next_in_mro__ = _next_in_mro(self)
1019+
# Preserve base classes on subclassing (__bases__ are type erased now).
1020+
if orig_bases is None:
1021+
self.__orig_bases__ = initial_bases
10131022

10141023
# This allows unparameterized generic collections to be used
10151024
# with issubclass() and isinstance() in the same way as their
@@ -1104,12 +1113,13 @@ def __getitem__(self, params):
11041113
tvars = _type_vars(params)
11051114
args = params
11061115
return self.__class__(self.__name__,
1107-
(self,) + self.__bases__,
1116+
self.__bases__,
11081117
dict(self.__dict__),
11091118
tvars=tvars,
11101119
args=args,
11111120
origin=self,
1112-
extra=self.__extra__)
1121+
extra=self.__extra__,
1122+
orig_bases=self.__orig_bases__)
11131123

11141124
def __instancecheck__(self, instance):
11151125
# Since we extend ABC.__subclasscheck__ and
@@ -1153,6 +1163,10 @@ def __new__(cls, *args, **kwds):
11531163
else:
11541164
origin = _gorg(cls)
11551165
obj = cls.__next_in_mro__.__new__(origin)
1166+
try:
1167+
obj.__orig_class__ = cls
1168+
except AttributeError:
1169+
pass
11561170
obj.__init__(*args, **kwds)
11571171
return obj
11581172

@@ -1521,6 +1535,7 @@ def _get_protocol_attrs(self):
15211535
attr != '__next_in_mro__' and
15221536
attr != '__parameters__' and
15231537
attr != '__origin__' and
1538+
attr != '__orig_bases__' and
15241539
attr != '__extra__' and
15251540
attr != '__module__'):
15261541
attrs.add(attr)

0 commit comments

Comments
 (0)