Skip to content

Commit 64705e6

Browse files
authored
[3.9] Raise TypeError if SSLSocket is passed to asyncio transport-based methods (GH-31442) (GH-31444)
(cherry picked from commit 1f9d4c9) Co-authored-by: Andrew Svetlov <[email protected]>
1 parent a6116a9 commit 64705e6

File tree

3 files changed

+22
-10
lines changed

3 files changed

+22
-10
lines changed

Lib/asyncio/base_events.py

+15
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,11 @@ def _set_nodelay(sock):
202202
pass
203203

204204

205+
def _check_ssl_socket(sock):
206+
if ssl is not None and isinstance(sock, ssl.SSLSocket):
207+
raise TypeError("Socket cannot be of type SSLSocket")
208+
209+
205210
class _SendfileFallbackProtocol(protocols.Protocol):
206211
def __init__(self, transp):
207212
if not isinstance(transp, transports._FlowControlMixin):
@@ -864,6 +869,7 @@ async def sock_sendfile(self, sock, file, offset=0, count=None,
864869
*, fallback=True):
865870
if self._debug and sock.gettimeout() != 0:
866871
raise ValueError("the socket must be non-blocking")
872+
_check_ssl_socket(sock)
867873
self._check_sendfile_params(sock, file, offset, count)
868874
try:
869875
return await self._sock_sendfile_native(sock, file,
@@ -1005,6 +1011,9 @@ async def create_connection(
10051011
raise ValueError(
10061012
'ssl_handshake_timeout is only meaningful with ssl')
10071013

1014+
if sock is not None:
1015+
_check_ssl_socket(sock)
1016+
10081017
if happy_eyeballs_delay is not None and interleave is None:
10091018
# If using happy eyeballs, default to interleave addresses by family
10101019
interleave = 1
@@ -1438,6 +1447,9 @@ async def create_server(
14381447
raise ValueError(
14391448
'ssl_handshake_timeout is only meaningful with ssl')
14401449

1450+
if sock is not None:
1451+
_check_ssl_socket(sock)
1452+
14411453
if host is not None or port is not None:
14421454
if sock is not None:
14431455
raise ValueError(
@@ -1540,6 +1552,9 @@ async def connect_accepted_socket(
15401552
raise ValueError(
15411553
'ssl_handshake_timeout is only meaningful with ssl')
15421554

1555+
if sock is not None:
1556+
_check_ssl_socket(sock)
1557+
15431558
transport, protocol = await self._create_connection_transport(
15441559
sock, protocol_factory, ssl, '', server_side=True,
15451560
ssl_handshake_timeout=ssl_handshake_timeout)

Lib/asyncio/selector_events.py

+5-10
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,6 @@ def _test_selector_event(selector, fd, event):
4040
return bool(key.events & event)
4141

4242

43-
def _check_ssl_socket(sock):
44-
if ssl is not None and isinstance(sock, ssl.SSLSocket):
45-
raise TypeError("Socket cannot be of type SSLSocket")
46-
47-
4843
class BaseSelectorEventLoop(base_events.BaseEventLoop):
4944
"""Selector event loop.
5045
@@ -357,7 +352,7 @@ async def sock_recv(self, sock, n):
357352
The maximum amount of data to be received at once is specified by
358353
nbytes.
359354
"""
360-
_check_ssl_socket(sock)
355+
base_events._check_ssl_socket(sock)
361356
if self._debug and sock.gettimeout() != 0:
362357
raise ValueError("the socket must be non-blocking")
363358
try:
@@ -398,7 +393,7 @@ async def sock_recv_into(self, sock, buf):
398393
The received data is written into *buf* (a writable buffer).
399394
The return value is the number of bytes written.
400395
"""
401-
_check_ssl_socket(sock)
396+
base_events._check_ssl_socket(sock)
402397
if self._debug and sock.gettimeout() != 0:
403398
raise ValueError("the socket must be non-blocking")
404399
try:
@@ -439,7 +434,7 @@ async def sock_sendall(self, sock, data):
439434
raised, and there is no way to determine how much data, if any, was
440435
successfully processed by the receiving end of the connection.
441436
"""
442-
_check_ssl_socket(sock)
437+
base_events._check_ssl_socket(sock)
443438
if self._debug and sock.gettimeout() != 0:
444439
raise ValueError("the socket must be non-blocking")
445440
try:
@@ -488,7 +483,7 @@ async def sock_connect(self, sock, address):
488483
489484
This method is a coroutine.
490485
"""
491-
_check_ssl_socket(sock)
486+
base_events._check_ssl_socket(sock)
492487
if self._debug and sock.gettimeout() != 0:
493488
raise ValueError("the socket must be non-blocking")
494489

@@ -553,7 +548,7 @@ async def sock_accept(self, sock):
553548
object usable to send and receive data on the connection, and address
554549
is the address bound to the socket on the other end of the connection.
555550
"""
556-
_check_ssl_socket(sock)
551+
base_events._check_ssl_socket(sock)
557552
if self._debug and sock.gettimeout() != 0:
558553
raise ValueError("the socket must be non-blocking")
559554
fut = self.create_future()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Raise :exc:`TypeError` if :class:`ssl.SSLSocket` is passed to
2+
transport-based APIs.

0 commit comments

Comments
 (0)