Skip to content

Commit d9cadda

Browse files
tallakhgithub-actions[bot]
authored andcommitted
Fix node_pool_class override (#2581)
Co-authored-by: Quentin Pradet <[email protected]> (cherry picked from commit 7e9ea0d)
1 parent 6afb6cb commit d9cadda

File tree

4 files changed

+74
-4
lines changed

4 files changed

+74
-4
lines changed

Diff for: elasticsearch/_async/client/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def __init__(
352352
if node_class is not DEFAULT:
353353
transport_kwargs["node_class"] = node_class
354354
if node_pool_class is not DEFAULT:
355-
transport_kwargs["node_pool_class"] = node_class
355+
transport_kwargs["node_pool_class"] = node_pool_class
356356
if randomize_nodes_in_pool is not DEFAULT:
357357
transport_kwargs["randomize_nodes_in_pool"] = randomize_nodes_in_pool
358358
if node_selector_class is not DEFAULT:

Diff for: elasticsearch/_sync/client/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def __init__(
352352
if node_class is not DEFAULT:
353353
transport_kwargs["node_class"] = node_class
354354
if node_pool_class is not DEFAULT:
355-
transport_kwargs["node_pool_class"] = node_class
355+
transport_kwargs["node_pool_class"] = node_pool_class
356356
if randomize_nodes_in_pool is not DEFAULT:
357357
transport_kwargs["randomize_nodes_in_pool"] = randomize_nodes_in_pool
358358
if node_selector_class is not DEFAULT:

Diff for: test_elasticsearch/test_async/test_transport.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,13 @@
2424
from typing import Any, Dict, Optional, Union
2525

2626
import pytest
27-
from elastic_transport import ApiResponseMeta, BaseAsyncNode, HttpHeaders, NodeConfig
27+
from elastic_transport import (
28+
ApiResponseMeta,
29+
BaseAsyncNode,
30+
HttpHeaders,
31+
NodeConfig,
32+
NodePool,
33+
)
2834
from elastic_transport._node import NodeApiResponse
2935
from elastic_transport.client_utils import DEFAULT
3036

@@ -73,6 +79,14 @@ async def close(self):
7379
self.closed = True
7480

7581

82+
class NoTimeoutConnectionPool(NodePool):
83+
def mark_dead(self, connection):
84+
pass
85+
86+
def mark_live(self, connection):
87+
pass
88+
89+
7690
CLUSTER_NODES = """{
7791
"_nodes" : {
7892
"total" : 1,
@@ -345,6 +359,27 @@ async def test_resurrected_connection_will_be_marked_as_live_on_success(self):
345359
assert len(client.transport.node_pool._alive_nodes) == 1
346360
assert len(client.transport.node_pool._dead_consecutive_failures) == 1
347361

362+
async def test_override_mark_dead_mark_live(self):
363+
client = AsyncElasticsearch(
364+
[
365+
NodeConfig("http", "localhost", 9200),
366+
NodeConfig("http", "localhost", 9201),
367+
],
368+
node_class=DummyNode,
369+
node_pool_class=NoTimeoutConnectionPool,
370+
)
371+
node1 = client.transport.node_pool.get()
372+
node2 = client.transport.node_pool.get()
373+
assert node1 is not node2
374+
client.transport.node_pool.mark_dead(node1)
375+
client.transport.node_pool.mark_dead(node2)
376+
assert len(client.transport.node_pool._alive_nodes) == 2
377+
378+
await client.info()
379+
380+
assert len(client.transport.node_pool._alive_nodes) == 2
381+
assert len(client.transport.node_pool._dead_consecutive_failures) == 0
382+
348383
@pytest.mark.parametrize(
349384
["nodes_info_response", "node_host"],
350385
[(CLUSTER_NODES, "1.1.1.1"), (CLUSTER_NODES_7x_PUBLISH_HOST, "somehost.tld")],

Diff for: test_elasticsearch/test_transport.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,13 @@
2222
from typing import Any, Dict, Optional, Union
2323

2424
import pytest
25-
from elastic_transport import ApiResponseMeta, BaseNode, HttpHeaders, NodeConfig
25+
from elastic_transport import (
26+
ApiResponseMeta,
27+
BaseNode,
28+
HttpHeaders,
29+
NodeConfig,
30+
NodePool,
31+
)
2632
from elastic_transport._node import NodeApiResponse
2733
from elastic_transport.client_utils import DEFAULT
2834

@@ -64,6 +70,14 @@ def perform_request(self, *args, **kwargs):
6470
)
6571

6672

73+
class NoTimeoutConnectionPool(NodePool):
74+
def mark_dead(self, connection):
75+
pass
76+
77+
def mark_live(self, connection):
78+
pass
79+
80+
6781
CLUSTER_NODES = """{
6882
"_nodes" : {
6983
"total" : 1,
@@ -376,6 +390,27 @@ def test_resurrected_connection_will_be_marked_as_live_on_success(self):
376390
assert len(client.transport.node_pool._alive_nodes) == 1
377391
assert len(client.transport.node_pool._dead_consecutive_failures) == 1
378392

393+
def test_override_mark_dead_mark_live(self):
394+
client = Elasticsearch(
395+
[
396+
NodeConfig("http", "localhost", 9200),
397+
NodeConfig("http", "localhost", 9201),
398+
],
399+
node_class=DummyNode,
400+
node_pool_class=NoTimeoutConnectionPool,
401+
)
402+
node1 = client.transport.node_pool.get()
403+
node2 = client.transport.node_pool.get()
404+
assert node1 is not node2
405+
client.transport.node_pool.mark_dead(node1)
406+
client.transport.node_pool.mark_dead(node2)
407+
assert len(client.transport.node_pool._alive_nodes) == 2
408+
409+
client.info()
410+
411+
assert len(client.transport.node_pool._alive_nodes) == 2
412+
assert len(client.transport.node_pool._dead_consecutive_failures) == 0
413+
379414
@pytest.mark.parametrize(
380415
["nodes_info_response", "node_host"],
381416
[(CLUSTER_NODES, "1.1.1.1"), (CLUSTER_NODES_7x_PUBLISH_HOST, "somehost.tld")],

0 commit comments

Comments
 (0)