Skip to content

Commit 1f9d4c9

Browse files
authored
Raise TypeError if SSLSocket is passed to asyncio transport-based methods (pythonGH-31442)
1 parent 4ab8167 commit 1f9d4c9

File tree

3 files changed

+22
-10
lines changed

3 files changed

+22
-10
lines changed

Lib/asyncio/base_events.py

+15
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,11 @@ def _set_nodelay(sock):
198198
pass
199199

200200

201+
def _check_ssl_socket(sock):
202+
if ssl is not None and isinstance(sock, ssl.SSLSocket):
203+
raise TypeError("Socket cannot be of type SSLSocket")
204+
205+
201206
class _SendfileFallbackProtocol(protocols.Protocol):
202207
def __init__(self, transp):
203208
if not isinstance(transp, transports._FlowControlMixin):
@@ -862,6 +867,7 @@ async def sock_sendfile(self, sock, file, offset=0, count=None,
862867
*, fallback=True):
863868
if self._debug and sock.gettimeout() != 0:
864869
raise ValueError("the socket must be non-blocking")
870+
_check_ssl_socket(sock)
865871
self._check_sendfile_params(sock, file, offset, count)
866872
try:
867873
return await self._sock_sendfile_native(sock, file,
@@ -1008,6 +1014,9 @@ async def create_connection(
10081014
raise ValueError(
10091015
'ssl_shutdown_timeout is only meaningful with ssl')
10101016

1017+
if sock is not None:
1018+
_check_ssl_socket(sock)
1019+
10111020
if happy_eyeballs_delay is not None and interleave is None:
10121021
# If using happy eyeballs, default to interleave addresses by family
10131022
interleave = 1
@@ -1438,6 +1447,9 @@ async def create_server(
14381447
raise ValueError(
14391448
'ssl_shutdown_timeout is only meaningful with ssl')
14401449

1450+
if sock is not None:
1451+
_check_ssl_socket(sock)
1452+
14411453
if host is not None or port is not None:
14421454
if sock is not None:
14431455
raise ValueError(
@@ -1538,6 +1550,9 @@ async def connect_accepted_socket(
15381550
raise ValueError(
15391551
'ssl_shutdown_timeout is only meaningful with ssl')
15401552

1553+
if sock is not None:
1554+
_check_ssl_socket(sock)
1555+
15411556
transport, protocol = await self._create_connection_transport(
15421557
sock, protocol_factory, ssl, '', server_side=True,
15431558
ssl_handshake_timeout=ssl_handshake_timeout,

Lib/asyncio/selector_events.py

+5-10
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,6 @@ def _test_selector_event(selector, fd, event):
4040
return bool(key.events & event)
4141

4242

43-
def _check_ssl_socket(sock):
44-
if ssl is not None and isinstance(sock, ssl.SSLSocket):
45-
raise TypeError("Socket cannot be of type SSLSocket")
46-
47-
4843
class BaseSelectorEventLoop(base_events.BaseEventLoop):
4944
"""Selector event loop.
5045
@@ -366,7 +361,7 @@ async def sock_recv(self, sock, n):
366361
The maximum amount of data to be received at once is specified by
367362
nbytes.
368363
"""
369-
_check_ssl_socket(sock)
364+
base_events._check_ssl_socket(sock)
370365
if self._debug and sock.gettimeout() != 0:
371366
raise ValueError("the socket must be non-blocking")
372367
try:
@@ -407,7 +402,7 @@ async def sock_recv_into(self, sock, buf):
407402
The received data is written into *buf* (a writable buffer).
408403
The return value is the number of bytes written.
409404
"""
410-
_check_ssl_socket(sock)
405+
base_events._check_ssl_socket(sock)
411406
if self._debug and sock.gettimeout() != 0:
412407
raise ValueError("the socket must be non-blocking")
413408
try:
@@ -448,7 +443,7 @@ async def sock_sendall(self, sock, data):
448443
raised, and there is no way to determine how much data, if any, was
449444
successfully processed by the receiving end of the connection.
450445
"""
451-
_check_ssl_socket(sock)
446+
base_events._check_ssl_socket(sock)
452447
if self._debug and sock.gettimeout() != 0:
453448
raise ValueError("the socket must be non-blocking")
454449
try:
@@ -497,7 +492,7 @@ async def sock_connect(self, sock, address):
497492
498493
This method is a coroutine.
499494
"""
500-
_check_ssl_socket(sock)
495+
base_events._check_ssl_socket(sock)
501496
if self._debug and sock.gettimeout() != 0:
502497
raise ValueError("the socket must be non-blocking")
503498

@@ -562,7 +557,7 @@ async def sock_accept(self, sock):
562557
object usable to send and receive data on the connection, and address
563558
is the address bound to the socket on the other end of the connection.
564559
"""
565-
_check_ssl_socket(sock)
560+
base_events._check_ssl_socket(sock)
566561
if self._debug and sock.gettimeout() != 0:
567562
raise ValueError("the socket must be non-blocking")
568563
fut = self.create_future()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Raise :exc:`TypeError` if :class:`ssl.SSLSocket` is passed to
2+
transport-based APIs.

0 commit comments

Comments
 (0)