Skip to content

Commit f75caaf

Browse files
committed
add cluster "host_port_remap" feature
1 parent 482713a commit f75caaf

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

redis/asyncio/cluster.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
Optional,
1414
Type,
1515
TypeVar,
16+
Tuple,
1617
Union,
1718
)
1819

@@ -250,6 +251,7 @@ def __init__(
250251
ssl_certfile: Optional[str] = None,
251252
ssl_check_hostname: bool = False,
252253
ssl_keyfile: Optional[str] = None,
254+
host_port_remap: List[Dict[str, Any]] = [],
253255
) -> None:
254256
if db:
255257
raise RedisClusterException(
@@ -337,7 +339,12 @@ def __init__(
337339
if host and port:
338340
startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))
339341

340-
self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs)
342+
self.nodes_manager = NodesManager(
343+
startup_nodes,
344+
require_full_coverage,
345+
kwargs,
346+
host_port_remap=host_port_remap,
347+
)
341348
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
342349
self.read_from_replicas = read_from_replicas
343350
self.reinitialize_steps = reinitialize_steps
@@ -1044,17 +1051,20 @@ class NodesManager:
10441051
"require_full_coverage",
10451052
"slots_cache",
10461053
"startup_nodes",
1054+
"host_port_remap",
10471055
)
10481056

10491057
def __init__(
10501058
self,
10511059
startup_nodes: List["ClusterNode"],
10521060
require_full_coverage: bool,
10531061
connection_kwargs: Dict[str, Any],
1062+
host_port_remap: List[Dict[str, Any]] = [],
10541063
) -> None:
10551064
self.startup_nodes = {node.name: node for node in startup_nodes}
10561065
self.require_full_coverage = require_full_coverage
10571066
self.connection_kwargs = connection_kwargs
1067+
self.host_port_remap = host_port_remap
10581068

10591069
self.default_node: "ClusterNode" = None
10601070
self.nodes_cache: Dict[str, "ClusterNode"] = {}
@@ -1213,6 +1223,7 @@ async def initialize(self) -> None:
12131223
if host == "":
12141224
host = startup_node.host
12151225
port = int(primary_node[1])
1226+
host, port = self.remap_host_port(host, port)
12161227

12171228
target_node = tmp_nodes_cache.get(get_node_name(host, port))
12181229
if not target_node:
@@ -1231,6 +1242,7 @@ async def initialize(self) -> None:
12311242
for replica_node in replica_nodes:
12321243
host = replica_node[0]
12331244
port = replica_node[1]
1245+
host, port = self.remap_host_port(host, port)
12341246

12351247
target_replica_node = tmp_nodes_cache.get(
12361248
get_node_name(host, port)
@@ -1304,6 +1316,30 @@ async def close(self, attr: str = "nodes_cache") -> None:
13041316
)
13051317
)
13061318

1319+
def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
1320+
"""
1321+
Remap the host and port returned from the cluster to a different
1322+
internal value. Useful if the client is not connecting directly
1323+
to the cluster.
1324+
"""
1325+
for map_entry in self.host_port_remap:
1326+
mapped = False
1327+
if "from_host" in map_entry:
1328+
if host != map_entry["from_host"]:
1329+
continue
1330+
else:
1331+
host = map_entry["to_host"]
1332+
mapped = True
1333+
if "from_port" in map_entry:
1334+
if port != map_entry["from_port"]:
1335+
continue
1336+
else:
1337+
port = map_entry["to_port"]
1338+
mapped = True
1339+
if mapped:
1340+
break
1341+
return host, port
1342+
13071343

13081344
class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
13091345
"""

0 commit comments

Comments
 (0)