1
1
import typing
2
+ from dataclasses import dataclass , field
2
3
from enum import Enum
3
4
from tempfile import SpooledTemporaryFile
4
5
from urllib .parse import unquote_plus
@@ -21,15 +22,13 @@ class FormMessage(Enum):
21
22
END = 5
22
23
23
24
24
- class MultiPartMessage (Enum ):
25
- PART_BEGIN = 1
26
- PART_DATA = 2
27
- PART_END = 3
28
- HEADER_FIELD = 4
29
- HEADER_VALUE = 5
30
- HEADER_END = 6
31
- HEADERS_FINISHED = 7
32
- END = 8
25
+ @dataclass
26
+ class MultipartPart :
27
+ content_disposition : typing .Optional [bytes ] = None
28
+ field_name : str = ""
29
+ data : bytes = b""
30
+ file : typing .Optional [UploadFile ] = None
31
+ item_headers : typing .List [typing .Tuple [bytes , bytes ]] = field (default_factory = list )
33
32
34
33
35
34
def _user_safe_decode (src : bytes , codec : str ) -> str :
@@ -120,53 +119,115 @@ class MultiPartParser:
120
119
max_file_size = 1024 * 1024
121
120
122
121
def __init__ (
123
- self , headers : Headers , stream : typing .AsyncGenerator [bytes , None ]
122
+ self ,
123
+ headers : Headers ,
124
+ stream : typing .AsyncGenerator [bytes , None ],
125
+ * ,
126
+ max_files : typing .Union [int , float ] = 1000 ,
127
+ max_fields : typing .Union [int , float ] = 1000 ,
124
128
) -> None :
125
129
assert (
126
130
multipart is not None
127
131
), "The `python-multipart` library must be installed to use form parsing."
128
132
self .headers = headers
129
133
self .stream = stream
130
- self .messages : typing .List [typing .Tuple [MultiPartMessage , bytes ]] = []
134
+ self .max_files = max_files
135
+ self .max_fields = max_fields
136
+ self .items : typing .List [typing .Tuple [str , typing .Union [str , UploadFile ]]] = []
137
+ self ._current_files = 0
138
+ self ._current_fields = 0
139
+ self ._current_partial_header_name : bytes = b""
140
+ self ._current_partial_header_value : bytes = b""
141
+ self ._current_part = MultipartPart ()
142
+ self ._charset = ""
143
+ self ._file_parts_to_write : typing .List [typing .Tuple [MultipartPart , bytes ]] = []
144
+ self ._file_parts_to_finish : typing .List [MultipartPart ] = []
131
145
132
146
def on_part_begin (self ) -> None :
133
- message = (MultiPartMessage .PART_BEGIN , b"" )
134
- self .messages .append (message )
147
+ self ._current_part = MultipartPart ()
135
148
136
149
def on_part_data (self , data : bytes , start : int , end : int ) -> None :
137
- message = (MultiPartMessage .PART_DATA , data [start :end ])
138
- self .messages .append (message )
150
+ message_bytes = data [start :end ]
151
+ if self ._current_part .file is None :
152
+ self ._current_part .data += message_bytes
153
+ else :
154
+ self ._file_parts_to_write .append ((self ._current_part , message_bytes ))
139
155
140
156
def on_part_end (self ) -> None :
141
- message = (MultiPartMessage .PART_END , b"" )
142
- self .messages .append (message )
157
+ if self ._current_part .file is None :
158
+ self .items .append (
159
+ (
160
+ self ._current_part .field_name ,
161
+ _user_safe_decode (self ._current_part .data , self ._charset ),
162
+ )
163
+ )
164
+ else :
165
+ self ._file_parts_to_finish .append (self ._current_part )
166
+ # The file can be added to the items right now even though it's not
167
+ # finished yet, because it will be finished in the `parse()` method, before
168
+ # self.items is used in the return value.
169
+ self .items .append ((self ._current_part .field_name , self ._current_part .file ))
143
170
144
171
def on_header_field (self , data : bytes , start : int , end : int ) -> None :
145
- message = (MultiPartMessage .HEADER_FIELD , data [start :end ])
146
- self .messages .append (message )
172
+ self ._current_partial_header_name += data [start :end ]
147
173
148
174
def on_header_value (self , data : bytes , start : int , end : int ) -> None :
149
- message = (MultiPartMessage .HEADER_VALUE , data [start :end ])
150
- self .messages .append (message )
175
+ self ._current_partial_header_value += data [start :end ]
151
176
152
177
def on_header_end (self ) -> None :
153
- message = (MultiPartMessage .HEADER_END , b"" )
154
- self .messages .append (message )
178
+ field = self ._current_partial_header_name .lower ()
179
+ if field == b"content-disposition" :
180
+ self ._current_part .content_disposition = self ._current_partial_header_value
181
+ self ._current_part .item_headers .append (
182
+ (field , self ._current_partial_header_value )
183
+ )
184
+ self ._current_partial_header_name = b""
185
+ self ._current_partial_header_value = b""
155
186
156
187
def on_headers_finished (self ) -> None :
157
- message = (MultiPartMessage .HEADERS_FINISHED , b"" )
158
- self .messages .append (message )
188
+ disposition , options = parse_options_header (
189
+ self ._current_part .content_disposition
190
+ )
191
+ try :
192
+ self ._current_part .field_name = _user_safe_decode (
193
+ options [b"name" ], self ._charset
194
+ )
195
+ except KeyError :
196
+ raise MultiPartException (
197
+ 'The Content-Disposition header field "name" must be ' "provided."
198
+ )
199
+ if b"filename" in options :
200
+ self ._current_files += 1
201
+ if self ._current_files > self .max_files :
202
+ raise MultiPartException (
203
+ f"Too many files. Maximum number of files is { self .max_files } ."
204
+ )
205
+ filename = _user_safe_decode (options [b"filename" ], self ._charset )
206
+ tempfile = SpooledTemporaryFile (max_size = self .max_file_size )
207
+ self ._current_part .file = UploadFile (
208
+ file = tempfile , # type: ignore[arg-type]
209
+ size = 0 ,
210
+ filename = filename ,
211
+ headers = Headers (raw = self ._current_part .item_headers ),
212
+ )
213
+ else :
214
+ self ._current_fields += 1
215
+ if self ._current_fields > self .max_fields :
216
+ raise MultiPartException (
217
+ f"Too many fields. Maximum number of fields is { self .max_fields } ."
218
+ )
219
+ self ._current_part .file = None
159
220
160
221
def on_end (self ) -> None :
161
- message = (MultiPartMessage .END , b"" )
162
- self .messages .append (message )
222
+ pass
163
223
164
224
async def parse (self ) -> FormData :
165
225
# Parse the Content-Type header to get the multipart boundary.
166
226
_ , params = parse_options_header (self .headers ["Content-Type" ])
167
227
charset = params .get (b"charset" , "utf-8" )
168
228
if type (charset ) == bytes :
169
229
charset = charset .decode ("latin-1" )
230
+ self ._charset = charset
170
231
try :
171
232
boundary = params [b"boundary" ]
172
233
except KeyError :
@@ -186,68 +247,21 @@ async def parse(self) -> FormData:
186
247
187
248
# Create the parser.
188
249
parser = multipart .MultipartParser (boundary , callbacks )
189
- header_field = b""
190
- header_value = b""
191
- content_disposition = None
192
- field_name = ""
193
- data = b""
194
- file : typing .Optional [UploadFile ] = None
195
-
196
- items : typing .List [typing .Tuple [str , typing .Union [str , UploadFile ]]] = []
197
- item_headers : typing .List [typing .Tuple [bytes , bytes ]] = []
198
-
199
250
# Feed the parser with data from the request.
200
251
async for chunk in self .stream :
201
252
parser .write (chunk )
202
- messages = list (self .messages )
203
- self .messages .clear ()
204
- for message_type , message_bytes in messages :
205
- if message_type == MultiPartMessage .PART_BEGIN :
206
- content_disposition = None
207
- data = b""
208
- item_headers = []
209
- elif message_type == MultiPartMessage .HEADER_FIELD :
210
- header_field += message_bytes
211
- elif message_type == MultiPartMessage .HEADER_VALUE :
212
- header_value += message_bytes
213
- elif message_type == MultiPartMessage .HEADER_END :
214
- field = header_field .lower ()
215
- if field == b"content-disposition" :
216
- content_disposition = header_value
217
- item_headers .append ((field , header_value ))
218
- header_field = b""
219
- header_value = b""
220
- elif message_type == MultiPartMessage .HEADERS_FINISHED :
221
- disposition , options = parse_options_header (content_disposition )
222
- try :
223
- field_name = _user_safe_decode (options [b"name" ], charset )
224
- except KeyError :
225
- raise MultiPartException (
226
- 'The Content-Disposition header field "name" must be '
227
- "provided."
228
- )
229
- if b"filename" in options :
230
- filename = _user_safe_decode (options [b"filename" ], charset )
231
- tempfile = SpooledTemporaryFile (max_size = self .max_file_size )
232
- file = UploadFile (
233
- file = tempfile , # type: ignore[arg-type]
234
- size = 0 ,
235
- filename = filename ,
236
- headers = Headers (raw = item_headers ),
237
- )
238
- else :
239
- file = None
240
- elif message_type == MultiPartMessage .PART_DATA :
241
- if file is None :
242
- data += message_bytes
243
- else :
244
- await file .write (message_bytes )
245
- elif message_type == MultiPartMessage .PART_END :
246
- if file is None :
247
- items .append ((field_name , _user_safe_decode (data , charset )))
248
- else :
249
- await file .seek (0 )
250
- items .append ((field_name , file ))
253
+ # Write file data, it needs to use await with the UploadFile methods that
254
+ # call the corresponding file methods *in a threadpool*, otherwise, if
255
+ # they were called directly in the callback methods above (regular,
256
+ # non-async functions), that would block the event loop in the main thread.
257
+ for part , data in self ._file_parts_to_write :
258
+ assert part .file # for type checkers
259
+ await part .file .write (data )
260
+ for part in self ._file_parts_to_finish :
261
+ assert part .file # for type checkers
262
+ await part .file .seek (0 )
263
+ self ._file_parts_to_write .clear ()
264
+ self ._file_parts_to_finish .clear ()
251
265
252
266
parser .finalize ()
253
- return FormData (items )
267
+ return FormData (self . items )
0 commit comments