diff --git a/gql/client.py b/gql/client.py index 30399eb8..e750c63c 100644 --- a/gql/client.py +++ b/gql/client.py @@ -356,13 +356,11 @@ async def _subscribe( # before a break if python version is too old (pypy3 py 3.6.1) self._generator = inner_generator - async for result in inner_generator: - if result.errors: - # Note: we need to run generator.aclose() here or the finally block in - # transport.subscribe will not be reached in pypy3 (py 3.6.1) - await inner_generator.aclose() - - yield result + try: + async for result in inner_generator: + yield result + finally: + await inner_generator.aclose() async def subscribe( self, document: DocumentNode, *args, **kwargs @@ -372,17 +370,24 @@ async def subscribe( The extra arguments are passed to the transport subscribe method.""" - # Validate and subscribe on the transport - async for result in self._subscribe(document, *args, **kwargs): - - # Raise an error if an error is returned in the ExecutionResult object - if result.errors: - raise TransportQueryError( - str(result.errors[0]), errors=result.errors, data=result.data - ) + inner_generator: AsyncGenerator[ExecutionResult, None] = self._subscribe( + document, *args, **kwargs + ) - elif result.data is not None: - yield result.data + try: + # Validate and subscribe on the transport + async for result in inner_generator: + + # Raise an error if an error is returned in the ExecutionResult object + if result.errors: + raise TransportQueryError( + str(result.errors[0]), errors=result.errors, data=result.data + ) + + elif result.data is not None: + yield result.data + finally: + await inner_generator.aclose() async def _execute( self, document: DocumentNode, *args, **kwargs diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index fcd176b5..7d87ee81 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -163,7 +163,8 @@ async def test_websocket_subscription_break( if count <= 5: # Note: the following line is only necessary for pypy3 v3.6.1 - await session._generator.aclose() + if sys.version_info < (3, 7): + await session._generator.aclose() break count -= 1