Skip to content

A real fix for issue #250 (failure with mock) #295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 21, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions python2/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,62 @@ class MM1(MutableMapping[str, str], collections_abc.MutableMapping):
class MM2(collections_abc.MutableMapping, MutableMapping[str, str]):
pass

def test_orig_bases(self):
T = TypeVar('T')
class C(typing.Dict[str, T]): pass
self.assertEqual(C.__orig_bases__, (typing.Dict[str, T],))

def test_naive_runtime_checks(self):
def naive_dict_check(obj, tp):
# Check if a dictionary conforms to Dict type
if len(tp.__parameters__) > 0:
raise NotImplementedError
if tp.__args__:
KT, VT = tp.__args__
return all(isinstance(k, KT) and isinstance(v, VT)
for k, v in obj.items())
self.assertTrue(naive_dict_check({'x': 1}, typing.Dict[typing.Text, int]))
self.assertFalse(naive_dict_check({1: 'x'}, typing.Dict[typing.Text, int]))
with self.assertRaises(NotImplementedError):
naive_dict_check({1: 'x'}, typing.Dict[typing.Text, T])

def naive_generic_check(obj, tp):
# Check if an instance conforms to the generic class
if not hasattr(obj, '__orig_class__'):
raise NotImplementedError
return obj.__orig_class__ == tp
class Node(Generic[T]): pass
self.assertTrue(naive_generic_check(Node[int](), Node[int]))
self.assertFalse(naive_generic_check(Node[str](), Node[int]))
self.assertFalse(naive_generic_check(Node[str](), List))
with self.assertRaises(NotImplementedError):
naive_generic_check([1,2,3], Node[int])

def naive_list_base_check(obj, tp):
# Check if list conforms to a List subclass
return all(isinstance(x, tp.__orig_bases__[0].__args__[0])
for x in obj)
class C(List[int]): pass
self.assertTrue(naive_list_base_check([1, 2, 3], C))
self.assertFalse(naive_list_base_check(['a', 'b'], C))

def test_multi_subscr_base(self):
T = TypeVar('T')
U = TypeVar('U')
V = TypeVar('V')
class C(List[T][U][V]): pass
class D(C, List[T][U][V]): pass
self.assertEqual(C.__parameters__, (V,))
self.assertEqual(D.__parameters__, (V,))
self.assertEqual(C[int].__parameters__, ())
self.assertEqual(D[int].__parameters__, ())
self.assertEqual(C[int].__args__, (int,))
self.assertEqual(D[int].__args__, (int,))
self.assertEqual(C.__bases__, (List,))
self.assertEqual(D.__bases__, (C, List))
self.assertEqual(C.__orig_bases__, (List[T][U][V],))
self.assertEqual(D.__orig_bases__, (C, List[T][U][V]))

def test_pickle(self):
global C # pickle wants to reference the class by name
T = TypeVar('T')
Expand Down
33 changes: 24 additions & 9 deletions python2/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,13 +1045,7 @@ class GenericMeta(TypingMeta, abc.ABCMeta):
"""Metaclass for generic types."""

def __new__(cls, name, bases, namespace,
tvars=None, args=None, origin=None, extra=None):
if extra is None:
extra = namespace.get('__extra__')
if extra is not None and type(extra) is abc.ABCMeta and extra not in bases:
bases = (extra,) + bases
self = super(GenericMeta, cls).__new__(cls, name, bases, namespace)

tvars=None, args=None, origin=None, extra=None, orig_bases=None):
if tvars is not None:
# Called from __getitem__() below.
assert origin is not None
Expand Down Expand Up @@ -1092,12 +1086,27 @@ def __new__(cls, name, bases, namespace,
", ".join(str(g) for g in gvars)))
tvars = gvars

initial_bases = bases
if extra is None:
extra = namespace.get('__extra__')
if extra is not None and type(extra) is abc.ABCMeta and extra not in bases:
bases = (extra,) + bases
bases = tuple(_gorg(b) if isinstance(b, GenericMeta) else b for b in bases)

# remove bare Generic from bases if there are other generic bases
if any(isinstance(b, GenericMeta) and b is not Generic for b in bases):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could use a comment explaining what's happening -- it looks like the idea is to remove bare Generic from the bases if there are other generic-derived bases, but to leave generic if if not.

bases = tuple(b for b in bases if b is not Generic)
self = super(GenericMeta, cls).__new__(cls, name, bases, namespace)

self.__parameters__ = tvars
self.__args__ = args
self.__origin__ = origin
self.__extra__ = extra
# Speed hack (https://github.com/python/typing/issues/196).
self.__next_in_mro__ = _next_in_mro(self)
# Preserve base classes on subclassing (__bases__ are type erased now).
if orig_bases is None:
self.__orig_bases__ = initial_bases

# This allows unparameterized generic collections to be used
# with issubclass() and isinstance() in the same way as their
Expand Down Expand Up @@ -1180,12 +1189,13 @@ def __getitem__(self, params):
tvars = _type_vars(params)
args = params
return self.__class__(self.__name__,
(self,) + self.__bases__,
self.__bases__,
dict(self.__dict__),
tvars=tvars,
args=args,
origin=self,
extra=self.__extra__)
extra=self.__extra__,
orig_bases=self.__orig_bases__)

def __instancecheck__(self, instance):
# Since we extend ABC.__subclasscheck__ and
Expand Down Expand Up @@ -1232,6 +1242,10 @@ def __new__(cls, *args, **kwds):
else:
origin = _gorg(cls)
obj = cls.__next_in_mro__.__new__(origin)
try:
obj.__orig_class__ = cls
except AttributeError:
pass
obj.__init__(*args, **kwds)
return obj

Expand Down Expand Up @@ -1402,6 +1416,7 @@ def _get_protocol_attrs(self):
attr != '__next_in_mro__' and
attr != '__parameters__' and
attr != '__origin__' and
attr != '__orig_bases__' and
attr != '__extra__' and
attr != '__module__'):
attrs.add(attr)
Expand Down
57 changes: 57 additions & 0 deletions src/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,63 @@ class MM1(MutableMapping[str, str], collections_abc.MutableMapping):
class MM2(collections_abc.MutableMapping, MutableMapping[str, str]):
pass

def test_orig_bases(self):
T = TypeVar('T')
class C(typing.Dict[str, T]): ...
self.assertEqual(C.__orig_bases__, (typing.Dict[str, T],))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know there are no docs for __parameters__ etc. either, but I would still like to see some kind of example code that takes a class like this and recovers what it means given the various new dunder methods (__origin__, __orig_bases__, __parameters__). E.g. if I had a dict {x: y} how would I check that it's a subclass of Dict[str, T]? (The answer should reveal that the type of x must be str and T must be the type of y.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is quite difficult to write a reasonable runtime type check functions. I just added three naive functions just to illustrate how to use the dunder attributes for typical type checking tasks.

In process of doing this I realized that it is difficult to perform runtime type checks since type erasure happens not only at subclassing, but also at generic class instantiation. In the new commit I added __orig_class__ to instances (if subclass allows this, i.e., it does not use __slots__) that saves a reference to original class before type erasure.
So that now Node[int]().__class__ is Node, while Node[int]().__orig_class__ is Node[int].


def test_naive_runtime_checks(self):
def naive_dict_check(obj, tp):
# Check if a dictionary conforms to Dict type
if len(tp.__parameters__) > 0:
raise NotImplementedError
if tp.__args__:
KT, VT = tp.__args__
return all(isinstance(k, KT) and isinstance(v, VT)
for k, v in obj.items())
self.assertTrue(naive_dict_check({'x': 1}, typing.Dict[str, int]))
self.assertFalse(naive_dict_check({1: 'x'}, typing.Dict[str, int]))
with self.assertRaises(NotImplementedError):
naive_dict_check({1: 'x'}, typing.Dict[str, T])

def naive_generic_check(obj, tp):
# Check if an instance conforms to the generic class
if not hasattr(obj, '__orig_class__'):
raise NotImplementedError
return obj.__orig_class__ == tp
class Node(Generic[T]): ...
self.assertTrue(naive_generic_check(Node[int](), Node[int]))
self.assertFalse(naive_generic_check(Node[str](), Node[int]))
self.assertFalse(naive_generic_check(Node[str](), List))
with self.assertRaises(NotImplementedError):
naive_generic_check([1,2,3], Node[int])

def naive_list_base_check(obj, tp):
# Check if list conforms to a List subclass
return all(isinstance(x, tp.__orig_bases__[0].__args__[0])
for x in obj)
class C(List[int]): ...
self.assertTrue(naive_list_base_check([1, 2, 3], C))
self.assertFalse(naive_list_base_check(['a', 'b'], C))

def test_multi_subscr_base(self):
T = TypeVar('T')
U = TypeVar('U')
V = TypeVar('V')
class C(List[T][U][V]): ...
class D(C, List[T][U][V]): ...
self.assertEqual(C.__parameters__, (V,))
self.assertEqual(D.__parameters__, (V,))
self.assertEqual(C[int].__parameters__, ())
self.assertEqual(D[int].__parameters__, ())
self.assertEqual(C[int].__args__, (int,))
self.assertEqual(D[int].__args__, (int,))
self.assertEqual(C.__bases__, (List,))
self.assertEqual(D.__bases__, (C, List))
self.assertEqual(C.__orig_bases__, (List[T][U][V],))
self.assertEqual(D.__orig_bases__, (C, List[T][U][V]))


def test_pickle(self):
global C # pickle wants to reference the class by name
T = TypeVar('T')
Expand Down
29 changes: 22 additions & 7 deletions src/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,11 +938,7 @@ class GenericMeta(TypingMeta, abc.ABCMeta):
"""Metaclass for generic types."""

def __new__(cls, name, bases, namespace,
tvars=None, args=None, origin=None, extra=None):
if extra is not None and type(extra) is abc.ABCMeta and extra not in bases:
bases = (extra,) + bases
self = super().__new__(cls, name, bases, namespace, _root=True)

tvars=None, args=None, origin=None, extra=None, orig_bases=None):
if tvars is not None:
# Called from __getitem__() below.
assert origin is not None
Expand Down Expand Up @@ -983,12 +979,25 @@ def __new__(cls, name, bases, namespace,
", ".join(str(g) for g in gvars)))
tvars = gvars

initial_bases = bases
if extra is not None and type(extra) is abc.ABCMeta and extra not in bases:
bases = (extra,) + bases
bases = tuple(_gorg(b) if isinstance(b, GenericMeta) else b for b in bases)

# remove bare Generic from bases if there are other generic bases
if any(isinstance(b, GenericMeta) and b is not Generic for b in bases):
bases = tuple(b for b in bases if b is not Generic)
self = super().__new__(cls, name, bases, namespace, _root=True)

self.__parameters__ = tvars
self.__args__ = args
self.__origin__ = origin
self.__extra__ = extra
# Speed hack (https://github.com/python/typing/issues/196).
self.__next_in_mro__ = _next_in_mro(self)
# Preserve base classes on subclassing (__bases__ are type erased now).
if orig_bases is None:
self.__orig_bases__ = initial_bases

# This allows unparameterized generic collections to be used
# with issubclass() and isinstance() in the same way as their
Expand Down Expand Up @@ -1071,12 +1080,13 @@ def __getitem__(self, params):
tvars = _type_vars(params)
args = params
return self.__class__(self.__name__,
(self,) + self.__bases__,
self.__bases__,
dict(self.__dict__),
tvars=tvars,
args=args,
origin=self,
extra=self.__extra__)
extra=self.__extra__,
orig_bases=self.__orig_bases__)

def __instancecheck__(self, instance):
# Since we extend ABC.__subclasscheck__ and
Expand Down Expand Up @@ -1120,6 +1130,10 @@ def __new__(cls, *args, **kwds):
else:
origin = _gorg(cls)
obj = cls.__next_in_mro__.__new__(origin)
try:
obj.__orig_class__ = cls
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'd just catch AttributeError here -- the "dict in dict" idiom feels obscure.

except AttributeError:
pass
obj.__init__(*args, **kwds)
return obj

Expand Down Expand Up @@ -1485,6 +1499,7 @@ def _get_protocol_attrs(self):
attr != '__next_in_mro__' and
attr != '__parameters__' and
attr != '__origin__' and
attr != '__orig_bases__' and
attr != '__extra__' and
attr != '__module__'):
attrs.add(attr)
Expand Down