Skip to content

Commit 70e2835

Browse files
authored
Bring full tests directory to typing correctly (#1407)
1 parent 9422c36 commit 70e2835

File tree

2 files changed

+84
-59
lines changed

2 files changed

+84
-59
lines changed

tests/test_ssl.py

+83-58
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import datetime
1111
import gc
12+
import os
1213
import pathlib
1314
import select
1415
import sys
@@ -25,7 +26,6 @@
2526
)
2627
from gc import collect, get_referrers
2728
from os import makedirs
28-
from os.path import join
2929
from socket import (
3030
AF_INET,
3131
AF_INET6,
@@ -124,6 +124,7 @@
124124
WantWriteError,
125125
ZeroReturnError,
126126
_make_requires,
127+
_NoOverlappingProtocols,
127128
)
128129

129130
from .test_crypto import (
@@ -166,25 +167,10 @@ def loopback_address(socket: socket) -> str:
166167
return "::1"
167168

168169

169-
def join_bytes_or_unicode(prefix, suffix):
170-
"""
171-
Join two path components of either ``bytes`` or ``unicode``.
172-
173-
The return type is the same as the type of ``prefix``.
174-
"""
175-
# If the types are the same, nothing special is necessary.
176-
if type(prefix) is type(suffix):
177-
return join(prefix, suffix)
178-
179-
# Otherwise, coerce suffix to the type of prefix.
180-
if isinstance(prefix, str):
181-
return join(prefix, suffix.decode(getfilesystemencoding()))
182-
else:
183-
return join(prefix, suffix.encode(getfilesystemencoding()))
184-
185-
186-
def verify_cb(conn, cert, errnum, depth, ok):
187-
return ok
170+
def verify_cb(
171+
conn: Connection, cert: X509, errnum: int, depth: int, ok: int
172+
) -> bool:
173+
return bool(ok)
188174

189175

190176
def socket_pair() -> tuple[socket, socket]:
@@ -360,7 +346,7 @@ def loopback(
360346

361347
def interact_in_memory(
362348
client_conn: Connection, server_conn: Connection
363-
) -> tuple[Connection, bytes]:
349+
) -> tuple[Connection, bytes] | None:
364350
"""
365351
Try to read application bytes from each of the two `Connection` objects.
366352
Copy bytes back and forth between their send/receive buffers for as long
@@ -405,6 +391,8 @@ def interact_in_memory(
405391
wrote = True
406392
write.bio_write(dirty)
407393

394+
return None
395+
408396

409397
def handshake_in_memory(
410398
client_conn: Connection, server_conn: Connection
@@ -1021,9 +1009,9 @@ def info(conn: Connection, where: int, ret: int) -> None:
10211009
for (conn, where, ret) in called
10221010
if not isinstance(conn, Connection)
10231011
]
1024-
assert (
1025-
[] == notConnections
1026-
), "Some info callback arguments were not Connection instances."
1012+
assert [] == notConnections, (
1013+
"Some info callback arguments were not Connection instances."
1014+
)
10271015

10281016
@pytest.mark.skipif(
10291017
not getattr(_lib, "Cryptography_HAS_KEYLOG", None),
@@ -1168,7 +1156,9 @@ def test_load_verify_invalid_file(self, tmpfile: bytes) -> None:
11681156
with pytest.raises(Error):
11691157
clientContext.load_verify_locations(tmpfile)
11701158

1171-
def _load_verify_directory_locations_capath(self, capath: bytes) -> None:
1159+
def _load_verify_directory_locations_capath(
1160+
self, capath: str | bytes
1161+
) -> None:
11721162
"""
11731163
Verify that if path to a directory containing certificate files is
11741164
passed to ``Context.load_verify_locations`` for the ``capath``
@@ -1180,7 +1170,11 @@ def _load_verify_directory_locations_capath(self, capath: bytes) -> None:
11801170
# c_rehash in the test suite. One is from OpenSSL 0.9.8, the other
11811171
# from OpenSSL 1.0.0.
11821172
for name in [b"c7adac82.0", b"c3705638.0"]:
1183-
cafile = join_bytes_or_unicode(capath, name)
1173+
cafile: str | bytes
1174+
if isinstance(capath, str):
1175+
cafile = os.path.join(capath, name.decode())
1176+
else:
1177+
cafile = os.path.join(capath, name)
11841178
with open(cafile, "w") as fObj:
11851179
fObj.write(root_cert_pem.decode("ascii"))
11861180

@@ -1209,9 +1203,13 @@ def test_load_verify_directory_capath(
12091203
"""
12101204
if pathtype == "unicode_path":
12111205
tmpfile += NON_ASCII.encode(getfilesystemencoding())
1206+
12121207
if argtype == "unicode_arg":
1213-
tmpfile = tmpfile.decode(getfilesystemencoding())
1214-
self._load_verify_directory_locations_capath(tmpfile)
1208+
self._load_verify_directory_locations_capath(
1209+
tmpfile.decode(getfilesystemencoding())
1210+
)
1211+
else:
1212+
self._load_verify_directory_locations_capath(tmpfile)
12151213

12161214
def test_load_verify_locations_wrong_args(self) -> None:
12171215
"""
@@ -1393,7 +1391,14 @@ def test_set_verify_callback_connection_argument(self) -> None:
13931391
serverConnection = Connection(serverContext, None)
13941392

13951393
class VerifyCallback:
1396-
def callback(self, connection: Connection, *args) -> bool:
1394+
def callback(
1395+
self,
1396+
connection: Connection,
1397+
cert: X509,
1398+
err: int,
1399+
depth: int,
1400+
ok: int,
1401+
) -> bool:
13971402
self.connection = connection
13981403
return True
13991404

@@ -1452,7 +1457,9 @@ def test_set_verify_callback_exception(self) -> None:
14521457

14531458
clientContext = Context(TLSv1_2_METHOD)
14541459

1455-
def verify_callback(*args):
1460+
def verify_callback(
1461+
conn: Connection, cert: X509, err: int, depth: int, ok: int
1462+
) -> bool:
14561463
raise Exception("silly verify failure")
14571464

14581465
clientContext.set_verify(VERIFY_PEER, verify_callback)
@@ -1482,7 +1489,7 @@ def test_set_verify_callback_reference(self) -> None:
14821489

14831490
for i in range(5):
14841491

1485-
def verify_callback(*args):
1492+
def verify_callback(*args: object) -> bool:
14861493
return True
14871494

14881495
serverSocket, clientSocket = socket_pair()
@@ -1589,8 +1596,14 @@ def _use_certificate_chain_file_test(self, certdir: str | bytes) -> None:
15891596

15901597
makedirs(certdir)
15911598

1592-
chainFile = join_bytes_or_unicode(certdir, "chain.pem")
1593-
caFile = join_bytes_or_unicode(certdir, "ca.pem")
1599+
chainFile: str | bytes
1600+
caFile: str | bytes
1601+
if isinstance(certdir, str):
1602+
chainFile = os.path.join(certdir, "chain.pem")
1603+
caFile = os.path.join(certdir, "ca.pem")
1604+
else:
1605+
chainFile = os.path.join(certdir, b"chain.pem")
1606+
caFile = os.path.join(certdir, b"ca.pem")
15941607

15951608
# Write out the chain file.
15961609
with open(chainFile, "wb") as fObj:
@@ -1848,9 +1861,9 @@ def replacement(connection: Connection) -> None: # pragma: no cover
18481861
collect()
18491862
collect()
18501863

1851-
callback = tracker()
1852-
if callback is not None:
1853-
referrers = get_referrers(callback)
1864+
callback_ref = tracker()
1865+
if callback_ref is not None:
1866+
referrers = get_referrers(callback_ref)
18541867
assert len(referrers) == 1
18551868

18561869
def test_no_servername(self) -> None:
@@ -2064,7 +2077,9 @@ def test_alpn_no_server_overlap(self) -> None:
20642077
"""
20652078
refusal_args = []
20662079

2067-
def refusal(conn: Connection, options: list[bytes]):
2080+
def refusal(
2081+
conn: Connection, options: list[bytes]
2082+
) -> _NoOverlappingProtocols:
20682083
refusal_args.append((conn, options))
20692084
return NO_OVERLAPPING_PROTOCOLS
20702085

@@ -2218,7 +2233,7 @@ def test_construction(self) -> None:
22182233

22192234

22202235
@pytest.fixture(params=["context", "connection"])
2221-
def ctx_or_conn(request) -> Context | Connection:
2236+
def ctx_or_conn(request: pytest.FixtureRequest) -> Context | Connection:
22222237
ctx = Context(SSLv23_METHOD)
22232238
if request.param == "context":
22242239
return ctx
@@ -2823,9 +2838,9 @@ def callback(
28232838
)
28242839
collect()
28252840
collect()
2826-
callback = tracker()
2827-
if callback is not None: # pragma: nocover
2828-
referrers = get_referrers(callback)
2841+
callback_ref = tracker()
2842+
if callback_ref is not None: # pragma: nocover
2843+
referrers = get_referrers(callback_ref)
28292844
assert len(referrers) == 1
28302845

28312846
def test_get_session_unconnected(self) -> None:
@@ -3862,7 +3877,9 @@ def test_outgoing_overflow(self) -> None:
38623877
# meaningless.
38633878
assert sent < size
38643879

3865-
receiver, received = interact_in_memory(client, server)
3880+
result = interact_in_memory(client, server)
3881+
assert result is not None
3882+
receiver, received = result
38663883
assert receiver is server
38673884

38683885
# We can rely on all of these bytes being received at once because
@@ -4249,7 +4266,7 @@ def test_callbacks_arent_called_by_default(self) -> None:
42494266
called.
42504267
"""
42514268

4252-
def ocsp_callback(*args, **kwargs): # pragma: nocover
4269+
def ocsp_callback(*args: object) -> typing.NoReturn: # pragma: nocover
42534270
pytest.fail("Should not be called")
42544271

42554272
client = self._client_connection(
@@ -4284,7 +4301,7 @@ def test_client_receives_servers_data(self) -> None:
42844301
"""
42854302
calls = []
42864303

4287-
def server_callback(*args, **kwargs):
4304+
def server_callback(*args: object, **kwargs: object) -> bytes:
42884305
return self.sample_ocsp_data
42894306

42904307
def client_callback(
@@ -4307,11 +4324,15 @@ def test_callbacks_are_invoked_with_connections(self) -> None:
43074324
client_calls = []
43084325
server_calls = []
43094326

4310-
def client_callback(conn, *args, **kwargs):
4327+
def client_callback(
4328+
conn: Connection, *args: object, **kwargs: object
4329+
) -> bool:
43114330
client_calls.append(conn)
43124331
return True
43134332

4314-
def server_callback(conn, *args, **kwargs):
4333+
def server_callback(
4334+
conn: Connection, *args: object, **kwargs: object
4335+
) -> bytes:
43154336
server_calls.append(conn)
43164337
return self.sample_ocsp_data
43174338

@@ -4331,11 +4352,11 @@ def test_opaque_data_is_passed_through(self) -> None:
43314352
"""
43324353
calls = []
43334354

4334-
def server_callback(*args):
4355+
def server_callback(*args: object) -> bytes:
43354356
calls.append(args)
43364357
return self.sample_ocsp_data
43374358

4338-
def client_callback(*args):
4359+
def client_callback(*args: object) -> bool:
43394360
calls.append(args)
43404361
return True
43414362

@@ -4360,7 +4381,7 @@ def test_server_returns_empty_string(self) -> None:
43604381
"""
43614382
client_calls = []
43624383

4363-
def server_callback(*args):
4384+
def server_callback(*args: object) -> bytes:
43644385
return b""
43654386

43664387
def client_callback(
@@ -4381,10 +4402,10 @@ def test_client_returns_false_terminates_handshake(self) -> None:
43814402
If the client returns False from its callback, the handshake fails.
43824403
"""
43834404

4384-
def server_callback(*args):
4405+
def server_callback(*args: object) -> bytes:
43854406
return self.sample_ocsp_data
43864407

4387-
def client_callback(*args):
4408+
def client_callback(*args: object) -> bool:
43884409
return False
43894410

43904411
client = self._client_connection(callback=client_callback, data=None)
@@ -4401,10 +4422,10 @@ def test_exceptions_in_client_bubble_up(self) -> None:
44014422
class SentinelException(Exception):
44024423
pass
44034424

4404-
def server_callback(*args):
4425+
def server_callback(*args: object) -> bytes:
44054426
return self.sample_ocsp_data
44064427

4407-
def client_callback(*args):
4428+
def client_callback(*args: object) -> typing.NoReturn:
44084429
raise SentinelException()
44094430

44104431
client = self._client_connection(callback=client_callback, data=None)
@@ -4421,10 +4442,12 @@ def test_exceptions_in_server_bubble_up(self) -> None:
44214442
class SentinelException(Exception):
44224443
pass
44234444

4424-
def server_callback(*args):
4445+
def server_callback(*args: object) -> typing.NoReturn:
44254446
raise SentinelException()
44264447

4427-
def client_callback(*args): # pragma: nocover
4448+
def client_callback(
4449+
*args: object,
4450+
) -> typing.NoReturn: # pragma: nocover
44284451
pytest.fail("Should not be called")
44294452

44304453
client = self._client_connection(callback=client_callback, data=None)
@@ -4438,14 +4461,16 @@ def test_server_must_return_bytes(self) -> None:
44384461
The server callback must return a bytestring, or a TypeError is thrown.
44394462
"""
44404463

4441-
def server_callback(*args):
4464+
def server_callback(*args: object) -> str:
44424465
return self.sample_ocsp_data.decode("ascii")
44434466

4444-
def client_callback(*args): # pragma: nocover
4467+
def client_callback(
4468+
*args: object,
4469+
) -> typing.NoReturn: # pragma: nocover
44454470
pytest.fail("Should not be called")
44464471

44474472
client = self._client_connection(callback=client_callback, data=None)
4448-
server = self._server_connection(callback=server_callback, data=None)
4473+
server = self._server_connection(callback=server_callback, data=None) # type: ignore[arg-type]
44494474

44504475
with pytest.raises(TypeError):
44514476
handshake_in_memory(client, server)

tox.ini

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ extras =
4747
deps =
4848
mypy
4949
commands =
50-
mypy src/ tests/conftest.py tests/test_crypto.py tests/test_debug.py tests/test_rand.py tests/test_util.py tests/util.py
50+
mypy src/ tests/
5151

5252
[testenv:check-manifest]
5353
deps =

0 commit comments

Comments
 (0)