1
1
# Adapted with permission from the EdgeDB project;
2
2
# license: PSFL.
3
3
4
+ import sys
4
5
import gc
5
6
import asyncio
6
7
import contextvars
@@ -28,6 +29,15 @@ def get_error_types(eg):
28
29
return {type (exc ) for exc in eg .exceptions }
29
30
30
31
32
+ def no_other_refs ():
33
+ # due to gh-124392 coroutines now refer to their locals
34
+ coro = asyncio .current_task ().get_coro ()
35
+ frame = sys ._getframe (1 )
36
+ while coro .cr_frame != frame :
37
+ coro = coro .cr_await
38
+ return [coro ]
39
+
40
+
31
41
class TestTaskGroup (unittest .IsolatedAsyncioTestCase ):
32
42
33
43
async def test_taskgroup_01 (self ):
@@ -913,7 +923,7 @@ class _Done(Exception):
913
923
exc = e
914
924
915
925
self .assertIsNotNone (exc )
916
- self .assertListEqual (gc .get_referrers (exc ), [] )
926
+ self .assertListEqual (gc .get_referrers (exc ), no_other_refs () )
917
927
918
928
919
929
async def test_exception_refcycles_errors (self ):
@@ -931,7 +941,7 @@ class _Done(Exception):
931
941
exc = excs .exceptions [0 ]
932
942
933
943
self .assertIsInstance (exc , _Done )
934
- self .assertListEqual (gc .get_referrers (exc ), [] )
944
+ self .assertListEqual (gc .get_referrers (exc ), no_other_refs () )
935
945
936
946
937
947
async def test_exception_refcycles_parent_task (self ):
@@ -953,7 +963,7 @@ async def coro_fn():
953
963
exc = excs .exceptions [0 ].exceptions [0 ]
954
964
955
965
self .assertIsInstance (exc , _Done )
956
- self .assertListEqual (gc .get_referrers (exc ), [] )
966
+ self .assertListEqual (gc .get_referrers (exc ), no_other_refs () )
957
967
958
968
async def test_exception_refcycles_propagate_cancellation_error (self ):
959
969
"""Test that TaskGroup deletes propagate_cancellation_error"""
@@ -968,7 +978,7 @@ async def test_exception_refcycles_propagate_cancellation_error(self):
968
978
exc = e .__cause__
969
979
970
980
self .assertIsInstance (exc , asyncio .CancelledError )
971
- self .assertListEqual (gc .get_referrers (exc ), [] )
981
+ self .assertListEqual (gc .get_referrers (exc ), no_other_refs () )
972
982
973
983
async def test_exception_refcycles_base_error (self ):
974
984
"""Test that TaskGroup deletes self._base_error"""
@@ -985,7 +995,7 @@ class MyKeyboardInterrupt(KeyboardInterrupt):
985
995
exc = e
986
996
987
997
self .assertIsNotNone (exc )
988
- self .assertListEqual (gc .get_referrers (exc ), [] )
998
+ self .assertListEqual (gc .get_referrers (exc ), no_other_refs () )
989
999
990
1000
991
1001
if __name__ == "__main__" :
0 commit comments