diff --git a/src/test_typing.py b/src/test_typing.py index fd2d93c3..9bd150f7 100644 --- a/src/test_typing.py +++ b/src/test_typing.py @@ -1612,6 +1612,25 @@ def __str__(self): def __add__(self, other): return 0 +class Base1(NamedTuple): + x: int + y: int + +class Base2(NamedTuple): + z: int = 0 + +class Derived1(Base1, NamedTuple): + '''Named tuples can have doctrings.''' + label: str + +class Derived2(Base2, Base1, NamedTuple): + def method(self): + return self.x + self.y + +class TwoDefaults(NamedTuple): + x: int = 0 + other: str = '' + async def g_with(am: AsyncContextManager[int]): x: int async with am as x: @@ -1629,7 +1648,7 @@ async def g_with(am: AsyncContextManager[int]): # fake names for the sake of static analysis ann_module = ann_module2 = ann_module3 = None A = B = CSub = G = CoolEmployee = CoolEmployeeWithDefault = object - XMeth = XRepr = NoneAndForward = object + XMeth = XRepr = NoneAndForward = Derived1 = Derived2 = TwoDefaults = object gth = get_type_hints @@ -2313,6 +2332,14 @@ def test_annotation_usage_with_default(self): self.assertEqual(CoolEmployeeWithDefault._fields, ('name', 'cool')) self.assertEqual(CoolEmployeeWithDefault._field_types, dict(name=str, cool=int)) self.assertEqual(CoolEmployeeWithDefault._field_defaults, dict(cool=0)) + self.assertEqual(TwoDefaults.__new__.__defaults__, (0, '')) + self.assertEqual(TwoDefaults().x, 0) + self.assertEqual(TwoDefaults().other, '') + self.assertEqual(TwoDefaults()[0], 0) + self.assertEqual(TwoDefaults()[1], '') + self.assertEqual(Derived2._fields, ('x', 'y', 'z')) + self.assertEqual(Derived2._field_types, dict(x=int, y=int, z=int)) + self.assertEqual(Derived2._field_defaults, dict(z=0)) with self.assertRaises(TypeError): exec(""" @@ -2320,6 +2347,16 @@ class NonDefaultAfterDefault(NamedTuple): x: int = 3 y: int """) + with self.assertRaises(TypeError): + exec(""" +class BadMerged(Base1, Base2, NamedTuple): + pass +""") + with self.assertRaises(TypeError): + exec(""" +class BadExtended(Base2, NamedTuple): + label: str +""") @skipUnless(PY36, 'Python 3.6 required') def test_annotation_usage_with_methods(self): @@ -2327,6 +2364,8 @@ def test_annotation_usage_with_methods(self): self.assertEqual(XMeth(42).x, XMeth(42)[0]) self.assertEqual(str(XRepr(42)), '42 -> 1') self.assertEqual(XRepr(1, 2) + XRepr(3), 0) + self.assertEqual(Derived2(1, 2).method(), 3) + self.assertEqual(Derived2(3, 4, 5).method(), 7) with self.assertRaises(AttributeError): exec(""" @@ -2344,6 +2383,42 @@ def _source(self): return 'no chance for this as well' """) + @skipUnless(PY36, 'Python 3.6 required') + def test_namedtuple_extending(self): + example = Derived1(1, 2, 'test') + with self.assertRaises(TypeError): + Derived1('test') + self.assertIsInstance(example, Derived1) + self.assertIsInstance(example, tuple) + self.assertEqual(example.x, 1) + self.assertEqual(example.y, 2) + self.assertEqual(example.label, 'test') + self.assertEqual(Derived1.__name__, 'Derived1') + self.assertEqual(Derived1._fields, ('x', 'y', 'label')) + self.assertEqual(Derived1.__annotations__, + collections.OrderedDict(x=int, y=int, label=str)) + self.assertIs(Derived1._field_types, Derived1.__annotations__) + + @skipUnless(PY36, 'Python 3.6 required') + def test_namedtuple_merging(self): + example = Derived2(1, 2) + example2 = Derived2(1, 2, 3) + with self.assertRaises(TypeError): + Derived2(1) + with self.assertRaises(TypeError): + Derived2(1, 2, 3, 4) + self.assertIsInstance(example, Derived2) + self.assertIsInstance(example, tuple) + self.assertEqual(example.x, 1) + self.assertEqual(example.y, 2) + self.assertEqual(example.z, 0) + self.assertEqual(example2.z, 3) + self.assertEqual(Derived2.__name__, 'Derived2') + self.assertEqual(Derived2._fields, ('x', 'y', 'z')) + self.assertEqual(Derived2.__annotations__, + collections.OrderedDict(x=int, y=int, z=int)) + self.assertIs(Derived2._field_types, Derived2.__annotations__) + @skipUnless(PY36, 'Python 3.6 required') def test_namedtuple_keyword_usage(self): LocalEmployee = NamedTuple("LocalEmployee", name=str, age=int) diff --git a/src/typing.py b/src/typing.py index c487afcb..60898ae1 100644 --- a/src/typing.py +++ b/src/typing.py @@ -2148,23 +2148,28 @@ def __new__(cls, typename, bases, ns): if not _PY36: raise TypeError("Class syntax for NamedTuple is only supported" " in Python 3.6+") - types = ns.get('__annotations__', {}) + types = collections.OrderedDict() + defaults_ns = {} + # Merge field types and defaults in reverse order (similar to how MRO works). + for base in reversed(bases): + if hasattr(base, '_field_types'): # new style named tuple + types.update(base._field_types) + defaults_ns.update(base._field_defaults) + types.update(ns.get('__annotations__', {})) + defaults_ns.update(ns) nm_tpl = _make_nmtuple(typename, types.items()) - defaults = [] - defaults_dict = {} + defaults_dict = collections.OrderedDict() for field_name in types: - if field_name in ns: - default_value = ns[field_name] - defaults.append(default_value) - defaults_dict[field_name] = default_value - elif defaults: + if field_name in defaults_ns: + defaults_dict[field_name] = defaults_ns[field_name] + elif defaults_dict: raise TypeError("Non-default namedtuple field {field_name} cannot " "follow default field(s) {default_names}" .format(field_name=field_name, default_names=', '.join(defaults_dict.keys()))) - nm_tpl.__new__.__defaults__ = tuple(defaults) + nm_tpl.__new__.__defaults__ = tuple(defaults_dict.values()) nm_tpl._field_defaults = defaults_dict - # update from user namespace without overriding special namedtuple attributes + # Update from user namespace without overriding special namedtuple attributes. for key in ns: if key in _prohibited: raise AttributeError("Cannot overwrite NamedTuple attribute " + key)