35
35
36
36
37
37
if not sys .implementation .name == "circuitpython" :
38
- from typing import Optional , Tuple
38
+ from typing import List , Optional , Tuple
39
39
40
40
from circuitpython_typing .socket import (
41
41
CircuitPythonSocketType ,
@@ -64,15 +64,14 @@ def connect(self, address: Tuple[str, int]) -> None:
64
64
try :
65
65
return self ._socket .connect (address , self ._mode )
66
66
except RuntimeError as error :
67
- raise OSError (errno .ENOMEM ) from error
67
+ raise OSError (errno .ENOMEM , str ( error ) ) from error
68
68
69
69
70
70
class _FakeSSLContext :
71
71
def __init__ (self , iface : InterfaceType ) -> None :
72
72
self ._iface = iface
73
73
74
- # pylint: disable=unused-argument
75
- def wrap_socket (
74
+ def wrap_socket ( # pylint: disable=unused-argument
76
75
self , socket : CircuitPythonSocketType , server_hostname : Optional [str ] = None
77
76
) -> _FakeSSLSocket :
78
77
"""Return the same socket"""
@@ -99,7 +98,8 @@ def create_fake_ssl_context(
99
98
return _FakeSSLContext (iface )
100
99
101
100
102
- _global_socketpool = {}
101
+ _global_connection_managers = {}
102
+ _global_socketpools = {}
103
103
_global_ssl_contexts = {}
104
104
105
105
@@ -113,7 +113,7 @@ def get_radio_socketpool(radio):
113
113
* Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing)
114
114
"""
115
115
class_name = radio .__class__ .__name__
116
- if class_name not in _global_socketpool :
116
+ if class_name not in _global_socketpools :
117
117
if class_name == "Radio" :
118
118
import ssl # pylint: disable=import-outside-toplevel
119
119
@@ -151,10 +151,10 @@ def get_radio_socketpool(radio):
151
151
else :
152
152
raise AttributeError (f"Unsupported radio class: { class_name } " )
153
153
154
- _global_socketpool [class_name ] = pool
154
+ _global_socketpools [class_name ] = pool
155
155
_global_ssl_contexts [class_name ] = ssl_context
156
156
157
- return _global_socketpool [class_name ]
157
+ return _global_socketpools [class_name ]
158
158
159
159
160
160
def get_radio_ssl_context (radio ):
@@ -183,42 +183,75 @@ def __init__(
183
183
) -> None :
184
184
self ._socket_pool = socket_pool
185
185
# Hang onto open sockets so that we can reuse them.
186
- self ._available_socket = {}
187
- self ._open_sockets = {}
188
-
189
- def _free_sockets (self ) -> None :
190
- available_sockets = []
191
- for socket , free in self ._available_socket .items ():
192
- if free :
193
- available_sockets .append (socket )
186
+ self ._available_sockets = set ()
187
+ self ._key_by_managed_socket = {}
188
+ self ._managed_socket_by_key = {}
194
189
190
+ def _free_sockets (self , force : bool = False ) -> None :
191
+ # cloning lists since items are being removed
192
+ available_sockets = list (self ._available_sockets )
195
193
for socket in available_sockets :
196
194
self .close_socket (socket )
195
+ if force :
196
+ open_sockets = list (self ._managed_socket_by_key .values ())
197
+ for socket in open_sockets :
198
+ self .close_socket (socket )
197
199
198
- def _get_key_for_socket (self , socket ):
200
+ def _get_connected_socket ( # pylint: disable=too-many-arguments
201
+ self ,
202
+ addr_info : List [Tuple [int , int , int , str , Tuple [str , int ]]],
203
+ host : str ,
204
+ port : int ,
205
+ timeout : float ,
206
+ is_ssl : bool ,
207
+ ssl_context : Optional [SSLContextType ] = None ,
208
+ ):
199
209
try :
200
- return next (
201
- key for key , value in self ._open_sockets .items () if value == socket
202
- )
203
- except StopIteration :
204
- return None
210
+ socket = self ._socket_pool .socket (addr_info [0 ], addr_info [1 ])
211
+ except (OSError , RuntimeError ) as exc :
212
+ return exc
213
+
214
+ if is_ssl :
215
+ socket = ssl_context .wrap_socket (socket , server_hostname = host )
216
+ connect_host = host
217
+ else :
218
+ connect_host = addr_info [- 1 ][0 ]
219
+ socket .settimeout (timeout ) # socket read timeout
220
+
221
+ try :
222
+ socket .connect ((connect_host , port ))
223
+ except (MemoryError , OSError ) as exc :
224
+ socket .close ()
225
+ return exc
226
+
227
+ return socket
228
+
229
+ @property
230
+ def available_socket_count (self ) -> int :
231
+ """Get the count of freeable open sockets"""
232
+ return len (self ._available_sockets )
233
+
234
+ @property
235
+ def managed_socket_count (self ) -> int :
236
+ """Get the count of open sockets"""
237
+ return len (self ._managed_socket_by_key )
205
238
206
239
def close_socket (self , socket : SocketType ) -> None :
207
240
"""Close a previously opened socket."""
208
- if socket not in self ._open_sockets .values ():
241
+ if socket not in self ._managed_socket_by_key .values ():
209
242
raise RuntimeError ("Socket not managed" )
210
- key = self ._get_key_for_socket (socket )
211
243
socket .close ()
212
- del self ._available_socket [socket ]
213
- del self ._open_sockets [key ]
244
+ key = self ._key_by_managed_socket .pop (socket )
245
+ del self ._managed_socket_by_key [key ]
246
+ if socket in self ._available_sockets :
247
+ self ._available_sockets .remove (socket )
214
248
215
249
def free_socket (self , socket : SocketType ) -> None :
216
250
"""Mark a previously opened socket as available so it can be reused if needed."""
217
- if socket not in self ._open_sockets .values ():
251
+ if socket not in self ._managed_socket_by_key .values ():
218
252
raise RuntimeError ("Socket not managed" )
219
- self ._available_socket [ socket ] = True
253
+ self ._available_sockets . add ( socket )
220
254
221
- # pylint: disable=too-many-branches,too-many-locals,too-many-statements
222
255
def get_socket (
223
256
self ,
224
257
host : str ,
@@ -234,10 +267,10 @@ def get_socket(
234
267
if session_id :
235
268
session_id = str (session_id )
236
269
key = (host , port , proto , session_id )
237
- if key in self ._open_sockets :
238
- socket = self ._open_sockets [key ]
239
- if self ._available_socket [ socket ] :
240
- self ._available_socket [ socket ] = False
270
+ if key in self ._managed_socket_by_key :
271
+ socket = self ._managed_socket_by_key [key ]
272
+ if socket in self ._available_sockets :
273
+ self ._available_sockets . remove ( socket )
241
274
return socket
242
275
243
276
raise RuntimeError (f"Socket already connected to { proto } //{ host } :{ port } " )
@@ -253,64 +286,68 @@ def get_socket(
253
286
host , port , 0 , self ._socket_pool .SOCK_STREAM
254
287
)[0 ]
255
288
256
- try_count = 0
257
- socket = None
258
- last_exc = None
259
- while try_count < 2 and socket is None :
260
- try_count += 1
261
- if try_count > 1 :
262
- if any (
263
- socket
264
- for socket , free in self ._available_socket .items ()
265
- if free is True
266
- ):
267
- self ._free_sockets ()
268
- else :
269
- break
270
-
271
- try :
272
- socket = self ._socket_pool .socket (addr_info [0 ], addr_info [1 ])
273
- except OSError as exc :
274
- last_exc = exc
275
- continue
276
- except RuntimeError as exc :
277
- last_exc = exc
278
- continue
279
-
280
- if is_ssl :
281
- socket = ssl_context .wrap_socket (socket , server_hostname = host )
282
- connect_host = host
283
- else :
284
- connect_host = addr_info [- 1 ][0 ]
285
- socket .settimeout (timeout ) # socket read timeout
286
-
287
- try :
288
- socket .connect ((connect_host , port ))
289
- except MemoryError as exc :
290
- last_exc = exc
291
- socket .close ()
292
- socket = None
293
- except OSError as exc :
294
- last_exc = exc
295
- socket .close ()
296
- socket = None
297
-
298
- if socket is None :
299
- raise RuntimeError (f"Error connecting socket: { last_exc } " ) from last_exc
300
-
301
- self ._available_socket [socket ] = False
302
- self ._open_sockets [key ] = socket
303
- return socket
289
+ first_exception = None
290
+ result = self ._get_connected_socket (
291
+ addr_info , host , port , timeout , is_ssl , ssl_context
292
+ )
293
+ if isinstance (result , Exception ):
294
+ # Got an error, if there are any available sockets, free them and try again
295
+ if self .available_socket_count :
296
+ first_exception = result
297
+ self ._free_sockets ()
298
+ result = self ._get_connected_socket (
299
+ addr_info , host , port , timeout , is_ssl , ssl_context
300
+ )
301
+ if isinstance (result , Exception ):
302
+ last_result = f", first error: { first_exception } " if first_exception else ""
303
+ raise RuntimeError (
304
+ f"Error connecting socket: { result } { last_result } "
305
+ ) from result
306
+
307
+ self ._key_by_managed_socket [result ] = key
308
+ self ._managed_socket_by_key [key ] = result
309
+ return result
304
310
305
311
306
312
# global helpers
307
313
308
314
309
- _global_connection_manager = {}
315
+ def connection_manager_close_all (
316
+ socket_pool : Optional [SocketpoolModuleType ] = None , release_references : bool = False
317
+ ) -> None :
318
+ """Close all open sockets for pool"""
319
+ if socket_pool :
320
+ socket_pools = [socket_pool ]
321
+ else :
322
+ socket_pools = _global_connection_managers .keys ()
323
+
324
+ for pool in socket_pools :
325
+ connection_manager = _global_connection_managers .get (pool , None )
326
+ if connection_manager is None :
327
+ raise RuntimeError ("SocketPool not managed" )
328
+
329
+ connection_manager ._free_sockets (force = True ) # pylint: disable=protected-access
330
+
331
+ if release_references :
332
+ radio_key = None
333
+ for radio_check , pool_check in _global_socketpools .items ():
334
+ if pool == pool_check :
335
+ radio_key = radio_check
336
+ break
337
+
338
+ if radio_key :
339
+ if radio_key in _global_socketpools :
340
+ del _global_socketpools [radio_key ]
341
+
342
+ if radio_key in _global_ssl_contexts :
343
+ del _global_ssl_contexts [radio_key ]
344
+
345
+ if pool in _global_connection_managers :
346
+ del _global_connection_managers [pool ]
310
347
311
348
312
349
def get_connection_manager (socket_pool : SocketpoolModuleType ) -> ConnectionManager :
313
350
"""Get the ConnectionManager singleton for the given pool"""
314
- if socket_pool not in _global_connection_manager :
315
- _global_connection_manager [socket_pool ] = ConnectionManager (socket_pool )
316
- return _global_connection_manager [socket_pool ]
351
+ if socket_pool not in _global_connection_managers :
352
+ _global_connection_managers [socket_pool ] = ConnectionManager (socket_pool )
353
+ return _global_connection_managers [socket_pool ]
0 commit comments