Skip to content

Commit 1531496

Browse files
authored
Merge pull request #13 from justmobilize/close-all-and-counts
Close all and counts
2 parents fc33375 + 601ee66 commit 1531496

8 files changed

+381
-141
lines changed

adafruit_connection_manager.py

+122-85
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636

3737
if not sys.implementation.name == "circuitpython":
38-
from typing import Optional, Tuple
38+
from typing import List, Optional, Tuple
3939

4040
from circuitpython_typing.socket import (
4141
CircuitPythonSocketType,
@@ -64,15 +64,14 @@ def connect(self, address: Tuple[str, int]) -> None:
6464
try:
6565
return self._socket.connect(address, self._mode)
6666
except RuntimeError as error:
67-
raise OSError(errno.ENOMEM) from error
67+
raise OSError(errno.ENOMEM, str(error)) from error
6868

6969

7070
class _FakeSSLContext:
7171
def __init__(self, iface: InterfaceType) -> None:
7272
self._iface = iface
7373

74-
# pylint: disable=unused-argument
75-
def wrap_socket(
74+
def wrap_socket( # pylint: disable=unused-argument
7675
self, socket: CircuitPythonSocketType, server_hostname: Optional[str] = None
7776
) -> _FakeSSLSocket:
7877
"""Return the same socket"""
@@ -99,7 +98,8 @@ def create_fake_ssl_context(
9998
return _FakeSSLContext(iface)
10099

101100

102-
_global_socketpool = {}
101+
_global_connection_managers = {}
102+
_global_socketpools = {}
103103
_global_ssl_contexts = {}
104104

105105

@@ -113,7 +113,7 @@ def get_radio_socketpool(radio):
113113
* Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing)
114114
"""
115115
class_name = radio.__class__.__name__
116-
if class_name not in _global_socketpool:
116+
if class_name not in _global_socketpools:
117117
if class_name == "Radio":
118118
import ssl # pylint: disable=import-outside-toplevel
119119

@@ -151,10 +151,10 @@ def get_radio_socketpool(radio):
151151
else:
152152
raise AttributeError(f"Unsupported radio class: {class_name}")
153153

154-
_global_socketpool[class_name] = pool
154+
_global_socketpools[class_name] = pool
155155
_global_ssl_contexts[class_name] = ssl_context
156156

157-
return _global_socketpool[class_name]
157+
return _global_socketpools[class_name]
158158

159159

160160
def get_radio_ssl_context(radio):
@@ -183,42 +183,75 @@ def __init__(
183183
) -> None:
184184
self._socket_pool = socket_pool
185185
# Hang onto open sockets so that we can reuse them.
186-
self._available_socket = {}
187-
self._open_sockets = {}
188-
189-
def _free_sockets(self) -> None:
190-
available_sockets = []
191-
for socket, free in self._available_socket.items():
192-
if free:
193-
available_sockets.append(socket)
186+
self._available_sockets = set()
187+
self._key_by_managed_socket = {}
188+
self._managed_socket_by_key = {}
194189

190+
def _free_sockets(self, force: bool = False) -> None:
191+
# cloning lists since items are being removed
192+
available_sockets = list(self._available_sockets)
195193
for socket in available_sockets:
196194
self.close_socket(socket)
195+
if force:
196+
open_sockets = list(self._managed_socket_by_key.values())
197+
for socket in open_sockets:
198+
self.close_socket(socket)
197199

198-
def _get_key_for_socket(self, socket):
200+
def _get_connected_socket( # pylint: disable=too-many-arguments
201+
self,
202+
addr_info: List[Tuple[int, int, int, str, Tuple[str, int]]],
203+
host: str,
204+
port: int,
205+
timeout: float,
206+
is_ssl: bool,
207+
ssl_context: Optional[SSLContextType] = None,
208+
):
199209
try:
200-
return next(
201-
key for key, value in self._open_sockets.items() if value == socket
202-
)
203-
except StopIteration:
204-
return None
210+
socket = self._socket_pool.socket(addr_info[0], addr_info[1])
211+
except (OSError, RuntimeError) as exc:
212+
return exc
213+
214+
if is_ssl:
215+
socket = ssl_context.wrap_socket(socket, server_hostname=host)
216+
connect_host = host
217+
else:
218+
connect_host = addr_info[-1][0]
219+
socket.settimeout(timeout) # socket read timeout
220+
221+
try:
222+
socket.connect((connect_host, port))
223+
except (MemoryError, OSError) as exc:
224+
socket.close()
225+
return exc
226+
227+
return socket
228+
229+
@property
230+
def available_socket_count(self) -> int:
231+
"""Get the count of freeable open sockets"""
232+
return len(self._available_sockets)
233+
234+
@property
235+
def managed_socket_count(self) -> int:
236+
"""Get the count of open sockets"""
237+
return len(self._managed_socket_by_key)
205238

206239
def close_socket(self, socket: SocketType) -> None:
207240
"""Close a previously opened socket."""
208-
if socket not in self._open_sockets.values():
241+
if socket not in self._managed_socket_by_key.values():
209242
raise RuntimeError("Socket not managed")
210-
key = self._get_key_for_socket(socket)
211243
socket.close()
212-
del self._available_socket[socket]
213-
del self._open_sockets[key]
244+
key = self._key_by_managed_socket.pop(socket)
245+
del self._managed_socket_by_key[key]
246+
if socket in self._available_sockets:
247+
self._available_sockets.remove(socket)
214248

215249
def free_socket(self, socket: SocketType) -> None:
216250
"""Mark a previously opened socket as available so it can be reused if needed."""
217-
if socket not in self._open_sockets.values():
251+
if socket not in self._managed_socket_by_key.values():
218252
raise RuntimeError("Socket not managed")
219-
self._available_socket[socket] = True
253+
self._available_sockets.add(socket)
220254

221-
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
222255
def get_socket(
223256
self,
224257
host: str,
@@ -234,10 +267,10 @@ def get_socket(
234267
if session_id:
235268
session_id = str(session_id)
236269
key = (host, port, proto, session_id)
237-
if key in self._open_sockets:
238-
socket = self._open_sockets[key]
239-
if self._available_socket[socket]:
240-
self._available_socket[socket] = False
270+
if key in self._managed_socket_by_key:
271+
socket = self._managed_socket_by_key[key]
272+
if socket in self._available_sockets:
273+
self._available_sockets.remove(socket)
241274
return socket
242275

243276
raise RuntimeError(f"Socket already connected to {proto}//{host}:{port}")
@@ -253,64 +286,68 @@ def get_socket(
253286
host, port, 0, self._socket_pool.SOCK_STREAM
254287
)[0]
255288

256-
try_count = 0
257-
socket = None
258-
last_exc = None
259-
while try_count < 2 and socket is None:
260-
try_count += 1
261-
if try_count > 1:
262-
if any(
263-
socket
264-
for socket, free in self._available_socket.items()
265-
if free is True
266-
):
267-
self._free_sockets()
268-
else:
269-
break
270-
271-
try:
272-
socket = self._socket_pool.socket(addr_info[0], addr_info[1])
273-
except OSError as exc:
274-
last_exc = exc
275-
continue
276-
except RuntimeError as exc:
277-
last_exc = exc
278-
continue
279-
280-
if is_ssl:
281-
socket = ssl_context.wrap_socket(socket, server_hostname=host)
282-
connect_host = host
283-
else:
284-
connect_host = addr_info[-1][0]
285-
socket.settimeout(timeout) # socket read timeout
286-
287-
try:
288-
socket.connect((connect_host, port))
289-
except MemoryError as exc:
290-
last_exc = exc
291-
socket.close()
292-
socket = None
293-
except OSError as exc:
294-
last_exc = exc
295-
socket.close()
296-
socket = None
297-
298-
if socket is None:
299-
raise RuntimeError(f"Error connecting socket: {last_exc}") from last_exc
300-
301-
self._available_socket[socket] = False
302-
self._open_sockets[key] = socket
303-
return socket
289+
first_exception = None
290+
result = self._get_connected_socket(
291+
addr_info, host, port, timeout, is_ssl, ssl_context
292+
)
293+
if isinstance(result, Exception):
294+
# Got an error, if there are any available sockets, free them and try again
295+
if self.available_socket_count:
296+
first_exception = result
297+
self._free_sockets()
298+
result = self._get_connected_socket(
299+
addr_info, host, port, timeout, is_ssl, ssl_context
300+
)
301+
if isinstance(result, Exception):
302+
last_result = f", first error: {first_exception}" if first_exception else ""
303+
raise RuntimeError(
304+
f"Error connecting socket: {result}{last_result}"
305+
) from result
306+
307+
self._key_by_managed_socket[result] = key
308+
self._managed_socket_by_key[key] = result
309+
return result
304310

305311

306312
# global helpers
307313

308314

309-
_global_connection_manager = {}
315+
def connection_manager_close_all(
316+
socket_pool: Optional[SocketpoolModuleType] = None, release_references: bool = False
317+
) -> None:
318+
"""Close all open sockets for pool"""
319+
if socket_pool:
320+
socket_pools = [socket_pool]
321+
else:
322+
socket_pools = _global_connection_managers.keys()
323+
324+
for pool in socket_pools:
325+
connection_manager = _global_connection_managers.get(pool, None)
326+
if connection_manager is None:
327+
raise RuntimeError("SocketPool not managed")
328+
329+
connection_manager._free_sockets(force=True) # pylint: disable=protected-access
330+
331+
if release_references:
332+
radio_key = None
333+
for radio_check, pool_check in _global_socketpools.items():
334+
if pool == pool_check:
335+
radio_key = radio_check
336+
break
337+
338+
if radio_key:
339+
if radio_key in _global_socketpools:
340+
del _global_socketpools[radio_key]
341+
342+
if radio_key in _global_ssl_contexts:
343+
del _global_ssl_contexts[radio_key]
344+
345+
if pool in _global_connection_managers:
346+
del _global_connection_managers[pool]
310347

311348

312349
def get_connection_manager(socket_pool: SocketpoolModuleType) -> ConnectionManager:
313350
"""Get the ConnectionManager singleton for the given pool"""
314-
if socket_pool not in _global_connection_manager:
315-
_global_connection_manager[socket_pool] = ConnectionManager(socket_pool)
316-
return _global_connection_manager[socket_pool]
351+
if socket_pool not in _global_connection_managers:
352+
_global_connection_managers[socket_pool] = ConnectionManager(socket_pool)
353+
return _global_connection_managers[socket_pool]

examples/connectionmanager_helpers.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,38 @@
2424

2525
# get request session
2626
requests = adafruit_requests.Session(pool, ssl_context)
27+
connection_manager = adafruit_connection_manager.get_connection_manager(pool)
28+
print("-" * 40)
29+
print("Nothing yet opened")
30+
print(f"Open Sockets: {connection_manager.managed_socket_count}")
31+
print(f"Freeable Open Sockets: {connection_manager.available_socket_count}")
2732

2833
# make request
2934
print("-" * 40)
30-
print(f"Fetching from {TEXT_URL}")
35+
print(f"Fetching from {TEXT_URL} in a context handler")
36+
with requests.get(TEXT_URL) as response:
37+
response_text = response.text
38+
print(f"Text Response {response_text}")
39+
40+
print("-" * 40)
41+
print("1 request, opened and freed")
42+
print(f"Open Sockets: {connection_manager.managed_socket_count}")
43+
print(f"Freeable Open Sockets: {connection_manager.available_socket_count}")
3144

45+
print("-" * 40)
46+
print(f"Fetching from {TEXT_URL} not in a context handler")
3247
response = requests.get(TEXT_URL)
33-
response_text = response.text
34-
response.close()
3548

36-
print(f"Text Response {response_text}")
3749
print("-" * 40)
50+
print("1 request, opened but not freed")
51+
print(f"Open Sockets: {connection_manager.managed_socket_count}")
52+
print(f"Freeable Open Sockets: {connection_manager.available_socket_count}")
53+
54+
print("-" * 40)
55+
print("Closing everything in the pool")
56+
adafruit_connection_manager.connection_manager_close_all(pool)
57+
58+
print("-" * 40)
59+
print("Everything closed")
60+
print(f"Open Sockets: {connection_manager.managed_socket_count}")
61+
print(f"Freeable Open Sockets: {connection_manager.available_socket_count}")

tests/close_socket_test.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ def test_close_socket():
2121
socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
2222
key = (mocket.MOCK_HOST_1, 80, "http:", None)
2323
assert socket == mock_socket_1
24-
assert socket in connection_manager._available_socket
25-
assert key in connection_manager._open_sockets
24+
assert socket not in connection_manager._available_sockets
25+
assert key in connection_manager._managed_socket_by_key
2626

2727
# validate socket is no longer tracked
2828
connection_manager.close_socket(socket)
29-
assert socket not in connection_manager._available_socket
30-
assert key not in connection_manager._open_sockets
29+
assert socket not in connection_manager._available_sockets
30+
assert key not in connection_manager._managed_socket_by_key
3131

3232

3333
def test_close_socket_not_managed():

tests/conftest.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,11 @@ def adafruit_wiznet5k_with_ssl_socket_module():
6565
@pytest.fixture(autouse=True)
6666
def reset_connection_manager(monkeypatch):
6767
monkeypatch.setattr(
68-
"adafruit_connection_manager._global_socketpool",
68+
"adafruit_connection_manager._global_connection_managers",
69+
{},
70+
)
71+
monkeypatch.setattr(
72+
"adafruit_connection_manager._global_socketpools",
6973
{},
7074
)
7175
monkeypatch.setattr(

0 commit comments

Comments
 (0)