Skip to content

bpo-46252: Raise TypeError if SSLSocket is passed to asyncio transport-based methods #31442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ def _set_nodelay(sock):
pass


def _check_ssl_socket(sock):
if ssl is not None and isinstance(sock, ssl.SSLSocket):
raise TypeError("Socket cannot be of type SSLSocket")


class _SendfileFallbackProtocol(protocols.Protocol):
def __init__(self, transp):
if not isinstance(transp, transports._FlowControlMixin):
Expand Down Expand Up @@ -862,6 +867,7 @@ async def sock_sendfile(self, sock, file, offset=0, count=None,
*, fallback=True):
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
_check_ssl_socket(sock)
self._check_sendfile_params(sock, file, offset, count)
try:
return await self._sock_sendfile_native(sock, file,
Expand Down Expand Up @@ -1008,6 +1014,9 @@ async def create_connection(
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')

if sock is not None:
_check_ssl_socket(sock)

if happy_eyeballs_delay is not None and interleave is None:
# If using happy eyeballs, default to interleave addresses by family
interleave = 1
Expand Down Expand Up @@ -1438,6 +1447,9 @@ async def create_server(
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')

if sock is not None:
_check_ssl_socket(sock)

if host is not None or port is not None:
if sock is not None:
raise ValueError(
Expand Down Expand Up @@ -1538,6 +1550,9 @@ async def connect_accepted_socket(
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')

if sock is not None:
_check_ssl_socket(sock)

transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, '', server_side=True,
ssl_handshake_timeout=ssl_handshake_timeout,
Expand Down
15 changes: 5 additions & 10 deletions Lib/asyncio/selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,6 @@ def _test_selector_event(selector, fd, event):
return bool(key.events & event)


def _check_ssl_socket(sock):
if ssl is not None and isinstance(sock, ssl.SSLSocket):
raise TypeError("Socket cannot be of type SSLSocket")


class BaseSelectorEventLoop(base_events.BaseEventLoop):
"""Selector event loop.

Expand Down Expand Up @@ -366,7 +361,7 @@ async def sock_recv(self, sock, n):
The maximum amount of data to be received at once is specified by
nbytes.
"""
_check_ssl_socket(sock)
base_events._check_ssl_socket(sock)
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
try:
Expand Down Expand Up @@ -407,7 +402,7 @@ async def sock_recv_into(self, sock, buf):
The received data is written into *buf* (a writable buffer).
The return value is the number of bytes written.
"""
_check_ssl_socket(sock)
base_events._check_ssl_socket(sock)
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
try:
Expand Down Expand Up @@ -448,7 +443,7 @@ async def sock_sendall(self, sock, data):
raised, and there is no way to determine how much data, if any, was
successfully processed by the receiving end of the connection.
"""
_check_ssl_socket(sock)
base_events._check_ssl_socket(sock)
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
try:
Expand Down Expand Up @@ -497,7 +492,7 @@ async def sock_connect(self, sock, address):

This method is a coroutine.
"""
_check_ssl_socket(sock)
base_events._check_ssl_socket(sock)
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")

Expand Down Expand Up @@ -562,7 +557,7 @@ async def sock_accept(self, sock):
object usable to send and receive data on the connection, and address
is the address bound to the socket on the other end of the connection.
"""
_check_ssl_socket(sock)
base_events._check_ssl_socket(sock)
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
fut = self.create_future()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Raise :exc:`TypeError` if :class:`ssl.SSLSocket` is passed to
transport-based APIs.