Skip to content

Commit 46a25b7

Browse files
authored
Merge pull request #3053 from Matiiss/matiiss-allow-sprite-group-subscripts
Add runtime support for `pygame.sprite.AbstractGroup` subscripts
2 parents 71d8b23 + e3816c5 commit 46a25b7

File tree

3 files changed

+18
-0
lines changed

3 files changed

+18
-0
lines changed

buildconfig/stubs/pygame/sprite.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import types
12
from collections.abc import Callable, Iterable, Iterator
23
from typing import (
34
Any,
@@ -144,6 +145,7 @@ _TDirtySprite = TypeVar("_TDirtySprite", bound=_DirtySpriteSupportsGroup)
144145
class AbstractGroup(Generic[_TSprite]):
145146
spritedict: dict[_TSprite, Optional[Union[FRect, Rect]]]
146147
lostsprites: list[Union[FRect, Rect]]
148+
def __class_getitem__(cls, generic: Any) -> types.GenericAlias: ...
147149
def __init__(self) -> None: ...
148150
def __len__(self) -> int: ...
149151
def __iter__(self) -> Iterator[_TSprite]: ...

src_py/sprite.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
# specific ones that aren't quite so general but fit into common
8585
# specialized cases.
8686

87+
import types
8788
from warnings import warn
8889
from typing import Optional
8990

@@ -371,6 +372,9 @@ class AbstractGroup:
371372
372373
"""
373374

375+
def __class_getitem__(cls, generic):
376+
return types.GenericAlias(cls, generic)
377+
374378
# protected identifier value to identify sprite groups, and avoid infinite recursion
375379
_spritegroup = True
376380

test/sprite_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#################################### IMPORTS ###################################
22

33

4+
import types
5+
import typing
46
import unittest
57

68
import pygame
@@ -660,6 +662,16 @@ def update(self, *args, **kwargs):
660662
self.assertEqual(test_sprite.sink, [1, 2, 3])
661663
self.assertEqual(test_sprite.sink_kwargs, {"foo": 4, "bar": 5})
662664

665+
def test_type_subscript(self):
666+
try:
667+
group_generic_alias = sprite.Group[sprite.Sprite]
668+
except TypeError as e:
669+
self.fail(e)
670+
671+
self.assertIsInstance(group_generic_alias, types.GenericAlias)
672+
self.assertIs(typing.get_origin(group_generic_alias), sprite.Group)
673+
self.assertEqual(typing.get_args(group_generic_alias), (sprite.Sprite,))
674+
663675

664676
################################################################################
665677

0 commit comments

Comments
 (0)