Skip to content

Saving http response headers reference in transports #293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
2 changes: 2 additions & 0 deletions docs/usage/headers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ If you want to add additional http headers for your connection, you can specify
.. code-block:: python

transport = AIOHTTPTransport(url='YOUR_URL', headers={'Authorization': 'token'})

After the connection, the latest response headers can be found in :code:`transport.response_headers`
5 changes: 5 additions & 0 deletions gql/transport/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from aiohttp.helpers import BasicAuth
from aiohttp.typedefs import LooseCookies, LooseHeaders
from graphql import DocumentNode, ExecutionResult, print_ast
from multidict import CIMultiDictProxy

from ..utils import extract_files
from .appsync_auth import AppSyncAuthentication
Expand Down Expand Up @@ -75,6 +76,7 @@ def __init__(
self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout
self.client_session_args = client_session_args
self.session: Optional[aiohttp.ClientSession] = None
self.response_headers: Optional[CIMultiDictProxy[str]]

async def connect(self) -> None:
"""Coroutine which will create an aiohttp ClientSession() as self.session.
Expand Down Expand Up @@ -311,6 +313,9 @@ async def raise_response_error(resp: aiohttp.ClientResponse, reason: str):
if "errors" not in result and "data" not in result:
await raise_response_error(resp, 'No "data" or "errors" keys in answer')

# Saving latest response headers in the transport
self.response_headers = resp.headers

return ExecutionResult(
errors=result.get("errors"),
data=result.get("data"),
Expand Down
3 changes: 3 additions & 0 deletions gql/transport/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def __init__(

self.session = None

self.response_headers = None

def connect(self):

if self.session is None:
Expand Down Expand Up @@ -217,6 +219,7 @@ def execute( # type: ignore
response = self.session.request(
self.method, self.url, **post_args # type: ignore
)
self.response_headers = response.headers

def raise_response_error(resp: requests.Response, reason: str):
# We raise a TransportServerError if the status code is 400 or higher
Expand Down
6 changes: 5 additions & 1 deletion gql/transport/websockets_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import websockets
from graphql import DocumentNode, ExecutionResult
from websockets.client import WebSocketClientProtocol
from websockets.datastructures import HeadersLike
from websockets.datastructures import Headers, HeadersLike
from websockets.exceptions import ConnectionClosed
from websockets.typing import Data, Subprotocol

Expand Down Expand Up @@ -169,6 +169,8 @@ def __init__(
# The list of supported subprotocols should be defined in the subclass
self.supported_subprotocols: List[Subprotocol] = []

self.response_headers: Optional[Headers] = None

async def _initialize(self):
"""Hook to send the initialization messages after the connection
and potentially wait for the backend ack.
Expand Down Expand Up @@ -495,6 +497,8 @@ async def connect(self) -> None:

self.websocket = cast(WebSocketClientProtocol, self.websocket)

self.response_headers = self.websocket.response_headers

# Run the after_connect hook of the subclass
await self._after_connect()

Expand Down
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ async def start(self, handler, extra_serve_args=None):
self.testcert, ssl_context = get_localhost_ssl_context()
extra_serve_args["ssl"] = ssl_context

# Adding dummy response headers
extra_serve_args["extra_headers"] = {"dummy": "test1234"}

# Start a server with a random open port
self.start_server = websockets.server.serve(
handler, "127.0.0.1", 0, **extra_serve_args
Expand Down
Loading