Skip to content

Commit 64f0eb4

Browse files
authored
fix LimitedStream.read method to work with raw IO streams (#2559)
2 parents d554cb7 + 591b115 commit 64f0eb4

File tree

3 files changed

+110
-23
lines changed

3 files changed

+110
-23
lines changed

CHANGES.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ Unreleased
1717
client. :issue:`2549`
1818
- Fix handling of header extended parameters such that they are no longer quoted.
1919
:issue:`2529`
20+
- ``LimitedStream.read`` works correctly when wrapping a stream that may not return
21+
the requested size in one ``read`` call. :issue:`2558`
2022

2123

2224
Version 2.2.2

src/werkzeug/wsgi.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -928,37 +928,77 @@ def on_disconnect(self) -> bytes:
928928

929929
raise ClientDisconnected()
930930

931-
def exhaust(self, chunk_size: int = 1024 * 64) -> None:
932-
"""Exhaust the stream. This consumes all the data left until the
933-
limit is reached.
931+
def _exhaust_chunks(self, chunk_size: int = 1024 * 64) -> t.Iterator[bytes]:
932+
"""Exhaust the stream by reading until the limit is reached or the client
933+
disconnects, yielding each chunk.
934+
935+
:param chunk_size: How many bytes to read at a time.
936+
937+
:meta private:
934938
935-
:param chunk_size: the size for a chunk. It will read the chunk
936-
until the stream is exhausted and throw away
937-
the results.
939+
.. versionadded:: 2.2.3
938940
"""
939941
to_read = self.limit - self._pos
940-
chunk = chunk_size
942+
941943
while to_read > 0:
942-
chunk = min(to_read, chunk)
943-
self.read(chunk)
944-
to_read -= chunk
944+
chunk = self.read(min(to_read, chunk_size))
945+
yield chunk
946+
to_read -= len(chunk)
947+
948+
def exhaust(self, chunk_size: int = 1024 * 64) -> None:
949+
"""Exhaust the stream by reading until the limit is reached or the client
950+
disconnects, discarding the data.
951+
952+
:param chunk_size: How many bytes to read at a time.
953+
954+
.. versionchanged:: 2.2.3
955+
Handle case where wrapped stream returns fewer bytes than requested.
956+
"""
957+
for _ in self._exhaust_chunks(chunk_size):
958+
pass
945959

946960
def read(self, size: t.Optional[int] = None) -> bytes:
947-
"""Read `size` bytes or if size is not provided everything is read.
961+
"""Read up to ``size`` bytes from the underlying stream. If size is not
962+
provided, read until the limit.
963+
964+
If the limit is reached, :meth:`on_exhausted` is called, which returns empty
965+
bytes.
948966
949-
:param size: the number of bytes read.
967+
If no bytes are read and the limit is not reached, or if an error occurs during
968+
the read, :meth:`on_disconnect` is called, which raises
969+
:exc:`.ClientDisconnected`.
970+
971+
:param size: The number of bytes to read. ``None``, default, reads until the
972+
limit is reached.
973+
974+
.. versionchanged:: 2.2.3
975+
Handle case where wrapped stream returns fewer bytes than requested.
950976
"""
951977
if self._pos >= self.limit:
952978
return self.on_exhausted()
953-
if size is None or size == -1: # -1 is for consistence with file
954-
size = self.limit
979+
980+
if size is None or size == -1: # -1 is for consistency with file
981+
# Keep reading from the wrapped stream until the limit is reached. Can't
982+
# rely on stream.read(size) because it's not guaranteed to return size.
983+
buf = bytearray()
984+
985+
for chunk in self._exhaust_chunks():
986+
buf.extend(chunk)
987+
988+
return bytes(buf)
989+
955990
to_read = min(self.limit - self._pos, size)
991+
956992
try:
957993
read = self._read(to_read)
958994
except (OSError, ValueError):
959995
return self.on_disconnect()
960-
if to_read and len(read) != to_read:
996+
997+
if to_read and not len(read):
998+
# If no data was read, treat it as a disconnect. As long as some data was
999+
# read, a subsequent call can still return more before reaching the limit.
9611000
return self.on_disconnect()
1001+
9621002
self._pos += len(read)
9631003
return read
9641004

tests/test_wsgi.py

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from __future__ import annotations
2+
13
import io
24
import json
35
import os
6+
import typing as t
47

58
import pytest
69

@@ -165,21 +168,63 @@ def test_limited_stream_json_load():
165168

166169

167170
def test_limited_stream_disconnection():
168-
io_ = io.BytesIO(b"A bit of content")
169-
170-
# disconnect detection on out of bytes
171-
stream = wsgi.LimitedStream(io_, 255)
171+
# disconnect because stream returns zero bytes
172+
stream = wsgi.LimitedStream(io.BytesIO(), 255)
172173
with pytest.raises(ClientDisconnected):
173174
stream.read()
174175

175-
# disconnect detection because file close
176-
io_ = io.BytesIO(b"x" * 255)
177-
io_.close()
178-
stream = wsgi.LimitedStream(io_, 255)
176+
# disconnect because stream is closed
177+
data = io.BytesIO(b"x" * 255)
178+
data.close()
179+
stream = wsgi.LimitedStream(data, 255)
180+
179181
with pytest.raises(ClientDisconnected):
180182
stream.read()
181183

182184

185+
def test_limited_stream_read_with_raw_io():
186+
class OneByteStream(t.BinaryIO):
187+
def __init__(self, buf: bytes) -> None:
188+
self.buf = buf
189+
self.pos = 0
190+
191+
def read(self, size: int | None = None) -> bytes:
192+
"""Return one byte at a time regardless of requested size."""
193+
194+
if size is None or size == -1:
195+
raise ValueError("expected read to be called with specific limit")
196+
197+
if size == 0 or len(self.buf) < self.pos:
198+
return b""
199+
200+
b = self.buf[self.pos : self.pos + 1]
201+
self.pos += 1
202+
return b
203+
204+
stream = wsgi.LimitedStream(OneByteStream(b"foo"), 4)
205+
assert stream.read(5) == b"f"
206+
assert stream.read(5) == b"o"
207+
assert stream.read(5) == b"o"
208+
209+
# The stream has fewer bytes (3) than the limit (4), therefore the read returns 0
210+
# bytes before the limit is reached.
211+
with pytest.raises(ClientDisconnected):
212+
stream.read(5)
213+
214+
stream = wsgi.LimitedStream(OneByteStream(b"foo123"), 3)
215+
assert stream.read(5) == b"f"
216+
assert stream.read(5) == b"o"
217+
assert stream.read(5) == b"o"
218+
# The limit was reached, therefore the wrapper is exhausted, not disconnected.
219+
assert stream.read(5) == b""
220+
221+
stream = wsgi.LimitedStream(OneByteStream(b"foo"), 3)
222+
assert stream.read() == b"foo"
223+
224+
stream = wsgi.LimitedStream(OneByteStream(b"foo"), 2)
225+
assert stream.read() == b"fo"
226+
227+
183228
def test_get_host_fallback():
184229
assert (
185230
wsgi.get_host(

0 commit comments

Comments
 (0)