9
9
10
10
import datetime
11
11
import gc
12
+ import os
12
13
import pathlib
13
14
import select
14
15
import sys
25
26
)
26
27
from gc import collect , get_referrers
27
28
from os import makedirs
28
- from os .path import join
29
29
from socket import (
30
30
AF_INET ,
31
31
AF_INET6 ,
124
124
WantWriteError ,
125
125
ZeroReturnError ,
126
126
_make_requires ,
127
+ _NoOverlappingProtocols ,
127
128
)
128
129
129
130
from .test_crypto import (
@@ -166,25 +167,10 @@ def loopback_address(socket: socket) -> str:
166
167
return "::1"
167
168
168
169
169
- def join_bytes_or_unicode (prefix , suffix ):
170
- """
171
- Join two path components of either ``bytes`` or ``unicode``.
172
-
173
- The return type is the same as the type of ``prefix``.
174
- """
175
- # If the types are the same, nothing special is necessary.
176
- if type (prefix ) is type (suffix ):
177
- return join (prefix , suffix )
178
-
179
- # Otherwise, coerce suffix to the type of prefix.
180
- if isinstance (prefix , str ):
181
- return join (prefix , suffix .decode (getfilesystemencoding ()))
182
- else :
183
- return join (prefix , suffix .encode (getfilesystemencoding ()))
184
-
185
-
186
- def verify_cb (conn , cert , errnum , depth , ok ):
187
- return ok
170
+ def verify_cb (
171
+ conn : Connection , cert : X509 , errnum : int , depth : int , ok : int
172
+ ) -> bool :
173
+ return bool (ok )
188
174
189
175
190
176
def socket_pair () -> tuple [socket , socket ]:
@@ -360,7 +346,7 @@ def loopback(
360
346
361
347
def interact_in_memory (
362
348
client_conn : Connection , server_conn : Connection
363
- ) -> tuple [Connection , bytes ]:
349
+ ) -> tuple [Connection , bytes ] | None :
364
350
"""
365
351
Try to read application bytes from each of the two `Connection` objects.
366
352
Copy bytes back and forth between their send/receive buffers for as long
@@ -405,6 +391,8 @@ def interact_in_memory(
405
391
wrote = True
406
392
write .bio_write (dirty )
407
393
394
+ return None
395
+
408
396
409
397
def handshake_in_memory (
410
398
client_conn : Connection , server_conn : Connection
@@ -1021,9 +1009,9 @@ def info(conn: Connection, where: int, ret: int) -> None:
1021
1009
for (conn , where , ret ) in called
1022
1010
if not isinstance (conn , Connection )
1023
1011
]
1024
- assert (
1025
- [] == notConnections
1026
- ), "Some info callback arguments were not Connection instances."
1012
+ assert [] == notConnections , (
1013
+ "Some info callback arguments were not Connection instances."
1014
+ )
1027
1015
1028
1016
@pytest .mark .skipif (
1029
1017
not getattr (_lib , "Cryptography_HAS_KEYLOG" , None ),
@@ -1168,7 +1156,9 @@ def test_load_verify_invalid_file(self, tmpfile: bytes) -> None:
1168
1156
with pytest .raises (Error ):
1169
1157
clientContext .load_verify_locations (tmpfile )
1170
1158
1171
- def _load_verify_directory_locations_capath (self , capath : bytes ) -> None :
1159
+ def _load_verify_directory_locations_capath (
1160
+ self , capath : str | bytes
1161
+ ) -> None :
1172
1162
"""
1173
1163
Verify that if path to a directory containing certificate files is
1174
1164
passed to ``Context.load_verify_locations`` for the ``capath``
@@ -1180,7 +1170,11 @@ def _load_verify_directory_locations_capath(self, capath: bytes) -> None:
1180
1170
# c_rehash in the test suite. One is from OpenSSL 0.9.8, the other
1181
1171
# from OpenSSL 1.0.0.
1182
1172
for name in [b"c7adac82.0" , b"c3705638.0" ]:
1183
- cafile = join_bytes_or_unicode (capath , name )
1173
+ cafile : str | bytes
1174
+ if isinstance (capath , str ):
1175
+ cafile = os .path .join (capath , name .decode ())
1176
+ else :
1177
+ cafile = os .path .join (capath , name )
1184
1178
with open (cafile , "w" ) as fObj :
1185
1179
fObj .write (root_cert_pem .decode ("ascii" ))
1186
1180
@@ -1209,9 +1203,13 @@ def test_load_verify_directory_capath(
1209
1203
"""
1210
1204
if pathtype == "unicode_path" :
1211
1205
tmpfile += NON_ASCII .encode (getfilesystemencoding ())
1206
+
1212
1207
if argtype == "unicode_arg" :
1213
- tmpfile = tmpfile .decode (getfilesystemencoding ())
1214
- self ._load_verify_directory_locations_capath (tmpfile )
1208
+ self ._load_verify_directory_locations_capath (
1209
+ tmpfile .decode (getfilesystemencoding ())
1210
+ )
1211
+ else :
1212
+ self ._load_verify_directory_locations_capath (tmpfile )
1215
1213
1216
1214
def test_load_verify_locations_wrong_args (self ) -> None :
1217
1215
"""
@@ -1393,7 +1391,14 @@ def test_set_verify_callback_connection_argument(self) -> None:
1393
1391
serverConnection = Connection (serverContext , None )
1394
1392
1395
1393
class VerifyCallback :
1396
- def callback (self , connection : Connection , * args ) -> bool :
1394
+ def callback (
1395
+ self ,
1396
+ connection : Connection ,
1397
+ cert : X509 ,
1398
+ err : int ,
1399
+ depth : int ,
1400
+ ok : int ,
1401
+ ) -> bool :
1397
1402
self .connection = connection
1398
1403
return True
1399
1404
@@ -1452,7 +1457,9 @@ def test_set_verify_callback_exception(self) -> None:
1452
1457
1453
1458
clientContext = Context (TLSv1_2_METHOD )
1454
1459
1455
- def verify_callback (* args ):
1460
+ def verify_callback (
1461
+ conn : Connection , cert : X509 , err : int , depth : int , ok : int
1462
+ ) -> bool :
1456
1463
raise Exception ("silly verify failure" )
1457
1464
1458
1465
clientContext .set_verify (VERIFY_PEER , verify_callback )
@@ -1482,7 +1489,7 @@ def test_set_verify_callback_reference(self) -> None:
1482
1489
1483
1490
for i in range (5 ):
1484
1491
1485
- def verify_callback (* args ) :
1492
+ def verify_callback (* args : object ) -> bool :
1486
1493
return True
1487
1494
1488
1495
serverSocket , clientSocket = socket_pair ()
@@ -1589,8 +1596,14 @@ def _use_certificate_chain_file_test(self, certdir: str | bytes) -> None:
1589
1596
1590
1597
makedirs (certdir )
1591
1598
1592
- chainFile = join_bytes_or_unicode (certdir , "chain.pem" )
1593
- caFile = join_bytes_or_unicode (certdir , "ca.pem" )
1599
+ chainFile : str | bytes
1600
+ caFile : str | bytes
1601
+ if isinstance (certdir , str ):
1602
+ chainFile = os .path .join (certdir , "chain.pem" )
1603
+ caFile = os .path .join (certdir , "ca.pem" )
1604
+ else :
1605
+ chainFile = os .path .join (certdir , b"chain.pem" )
1606
+ caFile = os .path .join (certdir , b"ca.pem" )
1594
1607
1595
1608
# Write out the chain file.
1596
1609
with open (chainFile , "wb" ) as fObj :
@@ -1848,9 +1861,9 @@ def replacement(connection: Connection) -> None: # pragma: no cover
1848
1861
collect ()
1849
1862
collect ()
1850
1863
1851
- callback = tracker ()
1852
- if callback is not None :
1853
- referrers = get_referrers (callback )
1864
+ callback_ref = tracker ()
1865
+ if callback_ref is not None :
1866
+ referrers = get_referrers (callback_ref )
1854
1867
assert len (referrers ) == 1
1855
1868
1856
1869
def test_no_servername (self ) -> None :
@@ -2064,7 +2077,9 @@ def test_alpn_no_server_overlap(self) -> None:
2064
2077
"""
2065
2078
refusal_args = []
2066
2079
2067
- def refusal (conn : Connection , options : list [bytes ]):
2080
+ def refusal (
2081
+ conn : Connection , options : list [bytes ]
2082
+ ) -> _NoOverlappingProtocols :
2068
2083
refusal_args .append ((conn , options ))
2069
2084
return NO_OVERLAPPING_PROTOCOLS
2070
2085
@@ -2218,7 +2233,7 @@ def test_construction(self) -> None:
2218
2233
2219
2234
2220
2235
@pytest .fixture (params = ["context" , "connection" ])
2221
- def ctx_or_conn (request ) -> Context | Connection :
2236
+ def ctx_or_conn (request : pytest . FixtureRequest ) -> Context | Connection :
2222
2237
ctx = Context (SSLv23_METHOD )
2223
2238
if request .param == "context" :
2224
2239
return ctx
@@ -2823,9 +2838,9 @@ def callback(
2823
2838
)
2824
2839
collect ()
2825
2840
collect ()
2826
- callback = tracker ()
2827
- if callback is not None : # pragma: nocover
2828
- referrers = get_referrers (callback )
2841
+ callback_ref = tracker ()
2842
+ if callback_ref is not None : # pragma: nocover
2843
+ referrers = get_referrers (callback_ref )
2829
2844
assert len (referrers ) == 1
2830
2845
2831
2846
def test_get_session_unconnected (self ) -> None :
@@ -3862,7 +3877,9 @@ def test_outgoing_overflow(self) -> None:
3862
3877
# meaningless.
3863
3878
assert sent < size
3864
3879
3865
- receiver , received = interact_in_memory (client , server )
3880
+ result = interact_in_memory (client , server )
3881
+ assert result is not None
3882
+ receiver , received = result
3866
3883
assert receiver is server
3867
3884
3868
3885
# We can rely on all of these bytes being received at once because
@@ -4249,7 +4266,7 @@ def test_callbacks_arent_called_by_default(self) -> None:
4249
4266
called.
4250
4267
"""
4251
4268
4252
- def ocsp_callback (* args , ** kwargs ) : # pragma: nocover
4269
+ def ocsp_callback (* args : object ) -> typing . NoReturn : # pragma: nocover
4253
4270
pytest .fail ("Should not be called" )
4254
4271
4255
4272
client = self ._client_connection (
@@ -4284,7 +4301,7 @@ def test_client_receives_servers_data(self) -> None:
4284
4301
"""
4285
4302
calls = []
4286
4303
4287
- def server_callback (* args , ** kwargs ) :
4304
+ def server_callback (* args : object , ** kwargs : object ) -> bytes :
4288
4305
return self .sample_ocsp_data
4289
4306
4290
4307
def client_callback (
@@ -4307,11 +4324,15 @@ def test_callbacks_are_invoked_with_connections(self) -> None:
4307
4324
client_calls = []
4308
4325
server_calls = []
4309
4326
4310
- def client_callback (conn , * args , ** kwargs ):
4327
+ def client_callback (
4328
+ conn : Connection , * args : object , ** kwargs : object
4329
+ ) -> bool :
4311
4330
client_calls .append (conn )
4312
4331
return True
4313
4332
4314
- def server_callback (conn , * args , ** kwargs ):
4333
+ def server_callback (
4334
+ conn : Connection , * args : object , ** kwargs : object
4335
+ ) -> bytes :
4315
4336
server_calls .append (conn )
4316
4337
return self .sample_ocsp_data
4317
4338
@@ -4331,11 +4352,11 @@ def test_opaque_data_is_passed_through(self) -> None:
4331
4352
"""
4332
4353
calls = []
4333
4354
4334
- def server_callback (* args ) :
4355
+ def server_callback (* args : object ) -> bytes :
4335
4356
calls .append (args )
4336
4357
return self .sample_ocsp_data
4337
4358
4338
- def client_callback (* args ) :
4359
+ def client_callback (* args : object ) -> bool :
4339
4360
calls .append (args )
4340
4361
return True
4341
4362
@@ -4360,7 +4381,7 @@ def test_server_returns_empty_string(self) -> None:
4360
4381
"""
4361
4382
client_calls = []
4362
4383
4363
- def server_callback (* args ) :
4384
+ def server_callback (* args : object ) -> bytes :
4364
4385
return b""
4365
4386
4366
4387
def client_callback (
@@ -4381,10 +4402,10 @@ def test_client_returns_false_terminates_handshake(self) -> None:
4381
4402
If the client returns False from its callback, the handshake fails.
4382
4403
"""
4383
4404
4384
- def server_callback (* args ) :
4405
+ def server_callback (* args : object ) -> bytes :
4385
4406
return self .sample_ocsp_data
4386
4407
4387
- def client_callback (* args ) :
4408
+ def client_callback (* args : object ) -> bool :
4388
4409
return False
4389
4410
4390
4411
client = self ._client_connection (callback = client_callback , data = None )
@@ -4401,10 +4422,10 @@ def test_exceptions_in_client_bubble_up(self) -> None:
4401
4422
class SentinelException (Exception ):
4402
4423
pass
4403
4424
4404
- def server_callback (* args ) :
4425
+ def server_callback (* args : object ) -> bytes :
4405
4426
return self .sample_ocsp_data
4406
4427
4407
- def client_callback (* args ) :
4428
+ def client_callback (* args : object ) -> typing . NoReturn :
4408
4429
raise SentinelException ()
4409
4430
4410
4431
client = self ._client_connection (callback = client_callback , data = None )
@@ -4421,10 +4442,12 @@ def test_exceptions_in_server_bubble_up(self) -> None:
4421
4442
class SentinelException (Exception ):
4422
4443
pass
4423
4444
4424
- def server_callback (* args ) :
4445
+ def server_callback (* args : object ) -> typing . NoReturn :
4425
4446
raise SentinelException ()
4426
4447
4427
- def client_callback (* args ): # pragma: nocover
4448
+ def client_callback (
4449
+ * args : object ,
4450
+ ) -> typing .NoReturn : # pragma: nocover
4428
4451
pytest .fail ("Should not be called" )
4429
4452
4430
4453
client = self ._client_connection (callback = client_callback , data = None )
@@ -4438,14 +4461,16 @@ def test_server_must_return_bytes(self) -> None:
4438
4461
The server callback must return a bytestring, or a TypeError is thrown.
4439
4462
"""
4440
4463
4441
- def server_callback (* args ) :
4464
+ def server_callback (* args : object ) -> str :
4442
4465
return self .sample_ocsp_data .decode ("ascii" )
4443
4466
4444
- def client_callback (* args ): # pragma: nocover
4467
+ def client_callback (
4468
+ * args : object ,
4469
+ ) -> typing .NoReturn : # pragma: nocover
4445
4470
pytest .fail ("Should not be called" )
4446
4471
4447
4472
client = self ._client_connection (callback = client_callback , data = None )
4448
- server = self ._server_connection (callback = server_callback , data = None )
4473
+ server = self ._server_connection (callback = server_callback , data = None ) # type: ignore[arg-type]
4449
4474
4450
4475
with pytest .raises (TypeError ):
4451
4476
handshake_in_memory (client , server )
0 commit comments