diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index bdd53852..7d71f3a9 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -1,3 +1,4 @@ +import json from ssl import SSLContext from typing import Any, AsyncGenerator, Dict, Optional, Union @@ -8,6 +9,7 @@ from aiohttp.typedefs import LooseCookies, LooseHeaders from graphql import DocumentNode, ExecutionResult, print_ast +from ..utils import extract_files from .async_transport import AsyncTransport from .exceptions import ( TransportAlreadyConnected, @@ -33,7 +35,7 @@ def __init__( auth: Optional[BasicAuth] = None, ssl: Union[SSLContext, bool, Fingerprint] = False, timeout: Optional[int] = None, - client_session_args: Dict[str, Any] = {}, + client_session_args: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the transport with the given aiohttp parameters. @@ -51,7 +53,6 @@ def __init__( self.ssl: Union[SSLContext, bool, Fingerprint] = ssl self.timeout: Optional[int] = timeout self.client_session_args = client_session_args - self.session: Optional[aiohttp.ClientSession] = None async def connect(self) -> None: @@ -76,7 +77,8 @@ async def connect(self) -> None: ) # Adding custom parameters passed from init - client_session_args.update(self.client_session_args) + if self.client_session_args: + client_session_args.update(self.client_session_args) # type: ignore self.session = aiohttp.ClientSession(**client_session_args) @@ -93,7 +95,7 @@ async def execute( document: DocumentNode, variable_values: Optional[Dict[str, str]] = None, operation_name: Optional[str] = None, - extra_args: Dict[str, Any] = {}, + extra_args: Dict[str, Any] = None, ) -> ExecutionResult: """Execute the provided document AST against the configured remote server. This uses the aiohttp library to perform a HTTP POST request asynchronously @@ -103,21 +105,46 @@ async def execute( """ query_str = print_ast(document) + + nulled_variable_values = None + files = None + if variable_values: + nulled_variable_values, files = extract_files(variable_values) + payload: Dict[str, Any] = { "query": query_str, } - if variable_values: - payload["variables"] = variable_values + if nulled_variable_values: + payload["variables"] = nulled_variable_values if operation_name: payload["operationName"] = operation_name - post_args = { - "json": payload, - } + if files: + data = aiohttp.FormData() + + # header + file_map = {str(i): [path] for i, path in enumerate(files)} + # path is nested in a list because the spec allows multiple pointers + # to the same file. But we don't use that. + file_streams = { + str(i): files[path] for i, path in enumerate(files) + } # payload + + data.add_field( + "operations", json.dumps(payload), content_type="application/json" + ) + data.add_field("map", json.dumps(file_map), content_type="application/json") + data.add_fields(*file_streams.items()) + + post_args = {"data": data} + + else: + post_args = {"json": payload} # type: ignore # Pass post_args to aiohttp post method - post_args.update(extra_args) + if extra_args: + post_args.update(extra_args) # type: ignore if self.session is None: raise TransportClosed("Transport is not connected") diff --git a/gql/utils.py b/gql/utils.py index 8f47d97d..ce0318b0 100644 --- a/gql/utils.py +++ b/gql/utils.py @@ -1,5 +1,8 @@ """Utilities to manipulate several python objects.""" +import io +from typing import Any, Dict, Tuple + # From this response in Stackoverflow # http://stackoverflow.com/a/19053800/1072990 @@ -8,3 +11,43 @@ def to_camel_case(snake_str): # We capitalize the first letter of each component except the first one # with the 'title' method and join them together. return components[0] + "".join(x.title() if x else "_" for x in components[1:]) + + +def is_file_like(value: Any) -> bool: + """Check if a value represents a file like object""" + return isinstance(value, io.IOBase) + + +def extract_files(variables: Dict) -> Tuple[Dict, Dict]: + files = {} + + def recurse_extract(path, obj): + """ + recursively traverse obj, doing a deepcopy, but + replacing any file-like objects with nulls and + shunting the originals off to the side. + """ + nonlocal files + if isinstance(obj, list): + nulled_obj = [] + for key, value in enumerate(obj): + value = recurse_extract(f"{path}.{key}", value) + nulled_obj.append(value) + return nulled_obj + elif isinstance(obj, dict): + nulled_obj = {} + for key, value in obj.items(): + value = recurse_extract(f"{path}.{key}", value) + nulled_obj[key] = value + return nulled_obj + elif is_file_like(obj): + # extract obj from its parent and put it into files instead. + files[path] = obj + return None + else: + # base case: pass through unchanged + return obj + + nulled_variables = recurse_extract("variables", variables) + + return nulled_variables, files