Skip to content

Commit 3db7a86

Browse files
authored
Add upload_files functionality for requests transport (#244)
1 parent af8f223 commit 3db7a86

File tree

4 files changed

+487
-5
lines changed

4 files changed

+487
-5
lines changed

docs/usage/file_upload.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ File uploads
22
============
33

44
GQL supports file uploads with the :ref:`aiohttp transport <aiohttp_transport>`
5+
and the :ref:`requests transport <requests_transport>`
56
using the `GraphQL multipart request spec`_.
67

78
.. _GraphQL multipart request spec: https://github.com/jaydenseric/graphql-multipart-request-spec
@@ -18,6 +19,7 @@ In order to upload a single file, you need to:
1819
.. code-block:: python
1920
2021
transport = AIOHTTPTransport(url='YOUR_URL')
22+
# Or transport = RequestsHTTPTransport(url='YOUR_URL')
2123
2224
client = Client(transport=transport)
2325
@@ -45,6 +47,7 @@ It is also possible to upload multiple files using a list.
4547
.. code-block:: python
4648
4749
transport = AIOHTTPTransport(url='YOUR_URL')
50+
# Or transport = RequestsHTTPTransport(url='YOUR_URL')
4851
4952
client = Client(transport=transport)
5053
@@ -84,6 +87,9 @@ We provide methods to do that for two different uses cases:
8487
* Sending local files
8588
* Streaming downloaded files from an external URL to the GraphQL API
8689

90+
.. note::
91+
Streaming is only supported with the :ref:`aiohttp transport <aiohttp_transport>`
92+
8793
Streaming local files
8894
^^^^^^^^^^^^^^^^^^^^^
8995

gql/transport/requests.py

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
1+
import io
12
import json
23
import logging
3-
from typing import Any, Dict, Optional, Union
4+
from typing import Any, Dict, Optional, Tuple, Type, Union
45

56
import requests
67
from graphql import DocumentNode, ExecutionResult, print_ast
78
from requests.adapters import HTTPAdapter, Retry
89
from requests.auth import AuthBase
910
from requests.cookies import RequestsCookieJar
11+
from requests_toolbelt.multipart.encoder import MultipartEncoder
1012

1113
from gql.transport import Transport
1214

15+
from ..utils import extract_files
1316
from .exceptions import (
1417
TransportAlreadyConnected,
1518
TransportClosed,
@@ -27,6 +30,8 @@ class RequestsHTTPTransport(Transport):
2730
The transport uses the requests library to send HTTP POST requests.
2831
"""
2932

33+
file_classes: Tuple[Type[Any], ...] = (io.IOBase,)
34+
3035
def __init__(
3136
self,
3237
url: str,
@@ -104,6 +109,7 @@ def execute( # type: ignore
104109
operation_name: Optional[str] = None,
105110
timeout: Optional[int] = None,
106111
extra_args: Dict[str, Any] = None,
112+
upload_files: bool = False,
107113
) -> ExecutionResult:
108114
"""Execute GraphQL query.
109115
@@ -116,6 +122,7 @@ def execute( # type: ignore
116122
Only required in multi-operation documents (Default: None).
117123
:param timeout: Specifies a default timeout for requests (Default: None).
118124
:param extra_args: additional arguments to send to the requests post method
125+
:param upload_files: Set to True if you want to put files in the variable values
119126
:return: The result of execution.
120127
`data` is the result of executing the query, `errors` is null
121128
if no errors occurred, and is a non-empty array if an error occurred.
@@ -126,21 +133,77 @@ def execute( # type: ignore
126133

127134
query_str = print_ast(document)
128135
payload: Dict[str, Any] = {"query": query_str}
129-
if variable_values:
130-
payload["variables"] = variable_values
136+
131137
if operation_name:
132138
payload["operationName"] = operation_name
133139

134-
data_key = "json" if self.use_json else "data"
135140
post_args = {
136141
"headers": self.headers,
137142
"auth": self.auth,
138143
"cookies": self.cookies,
139144
"timeout": timeout or self.default_timeout,
140145
"verify": self.verify,
141-
data_key: payload,
142146
}
143147

148+
if upload_files:
149+
# If the upload_files flag is set, then we need variable_values
150+
assert variable_values is not None
151+
152+
# If we upload files, we will extract the files present in the
153+
# variable_values dict and replace them by null values
154+
nulled_variable_values, files = extract_files(
155+
variables=variable_values, file_classes=self.file_classes,
156+
)
157+
158+
# Save the nulled variable values in the payload
159+
payload["variables"] = nulled_variable_values
160+
161+
# Add the payload to the operations field
162+
operations_str = json.dumps(payload)
163+
log.debug("operations %s", operations_str)
164+
165+
# Generate the file map
166+
# path is nested in a list because the spec allows multiple pointers
167+
# to the same file. But we don't support that.
168+
# Will generate something like {"0": ["variables.file"]}
169+
file_map = {str(i): [path] for i, path in enumerate(files)}
170+
171+
# Enumerate the file streams
172+
# Will generate something like {'0': <_io.BufferedReader ...>}
173+
file_streams = {str(i): files[path] for i, path in enumerate(files)}
174+
175+
# Add the file map field
176+
file_map_str = json.dumps(file_map)
177+
log.debug("file_map %s", file_map_str)
178+
179+
fields = {"operations": operations_str, "map": file_map_str}
180+
181+
# Add the extracted files as remaining fields
182+
for k, v in file_streams.items():
183+
fields[k] = (getattr(v, "name", k), v)
184+
185+
# Prepare requests http to send multipart-encoded data
186+
data = MultipartEncoder(fields=fields)
187+
188+
post_args["data"] = data
189+
190+
if post_args["headers"] is None:
191+
post_args["headers"] = {}
192+
else:
193+
post_args["headers"] = {**post_args["headers"]}
194+
195+
post_args["headers"]["Content-Type"] = data.content_type
196+
197+
else:
198+
if variable_values:
199+
payload["variables"] = variable_values
200+
201+
if log.isEnabledFor(logging.INFO):
202+
log.info(">>> %s", json.dumps(payload))
203+
204+
data_key = "json" if self.use_json else "data"
205+
post_args[data_key] = payload
206+
144207
# Log the payload
145208
if log.isEnabledFor(logging.INFO):
146209
log.info(">>> %s", json.dumps(payload))

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
install_requests_requires = [
4040
"requests>=2.23,<3",
41+
"requests_toolbelt>=0.9.1,<1",
4142
]
4243

4344
install_websockets_requires = [

0 commit comments

Comments
 (0)