Skip to content

Commit 8c74c2c

Browse files
tiangolo, Kludex, adriangb
authored
Merge pull request from GHSA-74m5-2c7w-9w3x
* ♻️ Refactor multipart parser logic to support limiting max fields and files
* ✨ Add support for new request.form() parameters max_files and max_fields
* ✅ Add tests for limiting max fields and files in form data
* 📝 Add docs about request.form() with new parameters max_files and max_fields
* 📝 Update `docs/requests.md`

  Co-authored-by: Marcelo Trylesinski <[email protected]>
* 📝 Tweak docs for request.form()
* ✏ Fix typo in `starlette/formparsers.py`

  Co-authored-by: Adrian Garcia Badaracco <[email protected]>

---------

Co-authored-by: Marcelo Trylesinski <[email protected]>
Co-authored-by: Adrian Garcia Badaracco <[email protected]>
1 parent 5771a78 commit 8c74c2c

File tree

4 files changed

+356
-91
lines changed

4 files changed

+356
-91
lines changed

docs/requests.md

+12
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,18 @@ state with `disconnected = await request.is_disconnected()`.
114114

115115
Request files are normally sent as multipart form data (`multipart/form-data`).
116116

117+
Signature: `request.form(max_files=1000, max_fields=1000)`
118+
119+
You can configure the maximum number of files and fields with the parameters `max_files` and `max_fields`:
120+
121+
```python
122+
async with request.form(max_files=1000, max_fields=1000):
123+
...
124+
```
125+
126+
!!! info
127+
These limits exist for security reasons: allowing an unlimited number of fields or files could enable a denial-of-service attack, because parsing a very large number of fields (even empty ones) consumes a lot of CPU and memory.
128+
117129
When you call `async with request.form() as form` you receive a `starlette.datastructures.FormData` which is an immutable
118130
multidict, containing both file uploads and text input. File upload items are represented as instances of `starlette.datastructures.UploadFile`.
119131

starlette/formparsers.py

+101-87
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import typing
2+
from dataclasses import dataclass, field
23
from enum import Enum
34
from tempfile import SpooledTemporaryFile
45
from urllib.parse import unquote_plus
@@ -21,15 +22,13 @@ class FormMessage(Enum):
2122
END = 5
2223

2324

24-
class MultiPartMessage(Enum):
25-
PART_BEGIN = 1
26-
PART_DATA = 2
27-
PART_END = 3
28-
HEADER_FIELD = 4
29-
HEADER_VALUE = 5
30-
HEADER_END = 6
31-
HEADERS_FINISHED = 7
32-
END = 8
25+
@dataclass
26+
class MultipartPart:
27+
content_disposition: typing.Optional[bytes] = None
28+
field_name: str = ""
29+
data: bytes = b""
30+
file: typing.Optional[UploadFile] = None
31+
item_headers: typing.List[typing.Tuple[bytes, bytes]] = field(default_factory=list)
3332

3433

3534
def _user_safe_decode(src: bytes, codec: str) -> str:
@@ -120,53 +119,115 @@ class MultiPartParser:
120119
max_file_size = 1024 * 1024
121120

122121
def __init__(
123-
self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
122+
self,
123+
headers: Headers,
124+
stream: typing.AsyncGenerator[bytes, None],
125+
*,
126+
max_files: typing.Union[int, float] = 1000,
127+
max_fields: typing.Union[int, float] = 1000,
124128
) -> None:
125129
assert (
126130
multipart is not None
127131
), "The `python-multipart` library must be installed to use form parsing."
128132
self.headers = headers
129133
self.stream = stream
130-
self.messages: typing.List[typing.Tuple[MultiPartMessage, bytes]] = []
134+
self.max_files = max_files
135+
self.max_fields = max_fields
136+
self.items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []
137+
self._current_files = 0
138+
self._current_fields = 0
139+
self._current_partial_header_name: bytes = b""
140+
self._current_partial_header_value: bytes = b""
141+
self._current_part = MultipartPart()
142+
self._charset = ""
143+
self._file_parts_to_write: typing.List[typing.Tuple[MultipartPart, bytes]] = []
144+
self._file_parts_to_finish: typing.List[MultipartPart] = []
131145

132146
def on_part_begin(self) -> None:
133-
message = (MultiPartMessage.PART_BEGIN, b"")
134-
self.messages.append(message)
147+
self._current_part = MultipartPart()
135148

136149
def on_part_data(self, data: bytes, start: int, end: int) -> None:
137-
message = (MultiPartMessage.PART_DATA, data[start:end])
138-
self.messages.append(message)
150+
message_bytes = data[start:end]
151+
if self._current_part.file is None:
152+
self._current_part.data += message_bytes
153+
else:
154+
self._file_parts_to_write.append((self._current_part, message_bytes))
139155

140156
def on_part_end(self) -> None:
141-
message = (MultiPartMessage.PART_END, b"")
142-
self.messages.append(message)
157+
if self._current_part.file is None:
158+
self.items.append(
159+
(
160+
self._current_part.field_name,
161+
_user_safe_decode(self._current_part.data, self._charset),
162+
)
163+
)
164+
else:
165+
self._file_parts_to_finish.append(self._current_part)
166+
# The file can be added to the items right now even though it's not
167+
# finished yet, because it will be finished in the `parse()` method, before
168+
# self.items is used in the return value.
169+
self.items.append((self._current_part.field_name, self._current_part.file))
143170

144171
def on_header_field(self, data: bytes, start: int, end: int) -> None:
145-
message = (MultiPartMessage.HEADER_FIELD, data[start:end])
146-
self.messages.append(message)
172+
self._current_partial_header_name += data[start:end]
147173

148174
def on_header_value(self, data: bytes, start: int, end: int) -> None:
149-
message = (MultiPartMessage.HEADER_VALUE, data[start:end])
150-
self.messages.append(message)
175+
self._current_partial_header_value += data[start:end]
151176

152177
def on_header_end(self) -> None:
153-
message = (MultiPartMessage.HEADER_END, b"")
154-
self.messages.append(message)
178+
field = self._current_partial_header_name.lower()
179+
if field == b"content-disposition":
180+
self._current_part.content_disposition = self._current_partial_header_value
181+
self._current_part.item_headers.append(
182+
(field, self._current_partial_header_value)
183+
)
184+
self._current_partial_header_name = b""
185+
self._current_partial_header_value = b""
155186

156187
def on_headers_finished(self) -> None:
157-
message = (MultiPartMessage.HEADERS_FINISHED, b"")
158-
self.messages.append(message)
188+
disposition, options = parse_options_header(
189+
self._current_part.content_disposition
190+
)
191+
try:
192+
self._current_part.field_name = _user_safe_decode(
193+
options[b"name"], self._charset
194+
)
195+
except KeyError:
196+
raise MultiPartException(
197+
'The Content-Disposition header field "name" must be ' "provided."
198+
)
199+
if b"filename" in options:
200+
self._current_files += 1
201+
if self._current_files > self.max_files:
202+
raise MultiPartException(
203+
f"Too many files. Maximum number of files is {self.max_files}."
204+
)
205+
filename = _user_safe_decode(options[b"filename"], self._charset)
206+
tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
207+
self._current_part.file = UploadFile(
208+
file=tempfile, # type: ignore[arg-type]
209+
size=0,
210+
filename=filename,
211+
headers=Headers(raw=self._current_part.item_headers),
212+
)
213+
else:
214+
self._current_fields += 1
215+
if self._current_fields > self.max_fields:
216+
raise MultiPartException(
217+
f"Too many fields. Maximum number of fields is {self.max_fields}."
218+
)
219+
self._current_part.file = None
159220

160221
def on_end(self) -> None:
161-
message = (MultiPartMessage.END, b"")
162-
self.messages.append(message)
222+
pass
163223

164224
async def parse(self) -> FormData:
165225
# Parse the Content-Type header to get the multipart boundary.
166226
_, params = parse_options_header(self.headers["Content-Type"])
167227
charset = params.get(b"charset", "utf-8")
168228
if type(charset) == bytes:
169229
charset = charset.decode("latin-1")
230+
self._charset = charset
170231
try:
171232
boundary = params[b"boundary"]
172233
except KeyError:
@@ -186,68 +247,21 @@ async def parse(self) -> FormData:
186247

187248
# Create the parser.
188249
parser = multipart.MultipartParser(boundary, callbacks)
189-
header_field = b""
190-
header_value = b""
191-
content_disposition = None
192-
field_name = ""
193-
data = b""
194-
file: typing.Optional[UploadFile] = None
195-
196-
items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []
197-
item_headers: typing.List[typing.Tuple[bytes, bytes]] = []
198-
199250
# Feed the parser with data from the request.
200251
async for chunk in self.stream:
201252
parser.write(chunk)
202-
messages = list(self.messages)
203-
self.messages.clear()
204-
for message_type, message_bytes in messages:
205-
if message_type == MultiPartMessage.PART_BEGIN:
206-
content_disposition = None
207-
data = b""
208-
item_headers = []
209-
elif message_type == MultiPartMessage.HEADER_FIELD:
210-
header_field += message_bytes
211-
elif message_type == MultiPartMessage.HEADER_VALUE:
212-
header_value += message_bytes
213-
elif message_type == MultiPartMessage.HEADER_END:
214-
field = header_field.lower()
215-
if field == b"content-disposition":
216-
content_disposition = header_value
217-
item_headers.append((field, header_value))
218-
header_field = b""
219-
header_value = b""
220-
elif message_type == MultiPartMessage.HEADERS_FINISHED:
221-
disposition, options = parse_options_header(content_disposition)
222-
try:
223-
field_name = _user_safe_decode(options[b"name"], charset)
224-
except KeyError:
225-
raise MultiPartException(
226-
'The Content-Disposition header field "name" must be '
227-
"provided."
228-
)
229-
if b"filename" in options:
230-
filename = _user_safe_decode(options[b"filename"], charset)
231-
tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
232-
file = UploadFile(
233-
file=tempfile, # type: ignore[arg-type]
234-
size=0,
235-
filename=filename,
236-
headers=Headers(raw=item_headers),
237-
)
238-
else:
239-
file = None
240-
elif message_type == MultiPartMessage.PART_DATA:
241-
if file is None:
242-
data += message_bytes
243-
else:
244-
await file.write(message_bytes)
245-
elif message_type == MultiPartMessage.PART_END:
246-
if file is None:
247-
items.append((field_name, _user_safe_decode(data, charset)))
248-
else:
249-
await file.seek(0)
250-
items.append((field_name, file))
253+
# Write file data, it needs to use await with the UploadFile methods that
254+
# call the corresponding file methods *in a threadpool*, otherwise, if
255+
# they were called directly in the callback methods above (regular,
256+
# non-async functions), that would block the event loop in the main thread.
257+
for part, data in self._file_parts_to_write:
258+
assert part.file # for type checkers
259+
await part.file.write(data)
260+
for part in self._file_parts_to_finish:
261+
assert part.file # for type checkers
262+
await part.file.seek(0)
263+
self._file_parts_to_write.clear()
264+
self._file_parts_to_finish.clear()
251265

252266
parser.finalize()
253-
return FormData(items)
267+
return FormData(self.items)

starlette/requests.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,12 @@ async def json(self) -> typing.Any:
244244
self._json = json.loads(body)
245245
return self._json
246246

247-
async def _get_form(self) -> FormData:
247+
async def _get_form(
248+
self,
249+
*,
250+
max_files: typing.Union[int, float] = 1000,
251+
max_fields: typing.Union[int, float] = 1000,
252+
) -> FormData:
248253
if self._form is None:
249254
assert (
250255
parse_options_header is not None
@@ -254,7 +259,12 @@ async def _get_form(self) -> FormData:
254259
content_type, _ = parse_options_header(content_type_header)
255260
if content_type == b"multipart/form-data":
256261
try:
257-
multipart_parser = MultiPartParser(self.headers, self.stream())
262+
multipart_parser = MultiPartParser(
263+
self.headers,
264+
self.stream(),
265+
max_files=max_files,
266+
max_fields=max_fields,
267+
)
258268
self._form = await multipart_parser.parse()
259269
except MultiPartException as exc:
260270
if "app" in self.scope:
@@ -267,8 +277,15 @@ async def _get_form(self) -> FormData:
267277
self._form = FormData()
268278
return self._form
269279

270-
def form(self) -> AwaitableOrContextManager[FormData]:
271-
return AwaitableOrContextManagerWrapper(self._get_form())
280+
def form(
281+
self,
282+
*,
283+
max_files: typing.Union[int, float] = 1000,
284+
max_fields: typing.Union[int, float] = 1000,
285+
) -> AwaitableOrContextManager[FormData]:
286+
return AwaitableOrContextManagerWrapper(
287+
self._get_form(max_files=max_files, max_fields=max_fields)
288+
)
272289

273290
async def close(self) -> None:
274291
if self._form is not None:

0 commit comments

Comments
 (0)