13
13
Optional ,
14
14
Type ,
15
15
TypeVar ,
16
+ Tuple ,
16
17
Union ,
17
18
)
18
19
@@ -250,6 +251,7 @@ def __init__(
250
251
ssl_certfile : Optional [str ] = None ,
251
252
ssl_check_hostname : bool = False ,
252
253
ssl_keyfile : Optional [str ] = None ,
254
+ host_port_remap : List [Dict [str , Any ]] = [],
253
255
) -> None :
254
256
if db :
255
257
raise RedisClusterException (
@@ -337,7 +339,12 @@ def __init__(
337
339
if host and port :
338
340
startup_nodes .append (ClusterNode (host , port , ** self .connection_kwargs ))
339
341
340
- self .nodes_manager = NodesManager (startup_nodes , require_full_coverage , kwargs )
342
+ self .nodes_manager = NodesManager (
343
+ startup_nodes ,
344
+ require_full_coverage ,
345
+ kwargs ,
346
+ host_port_remap = host_port_remap ,
347
+ )
341
348
self .encoder = Encoder (encoding , encoding_errors , decode_responses )
342
349
self .read_from_replicas = read_from_replicas
343
350
self .reinitialize_steps = reinitialize_steps
@@ -1044,17 +1051,20 @@ class NodesManager:
1044
1051
"require_full_coverage" ,
1045
1052
"slots_cache" ,
1046
1053
"startup_nodes" ,
1054
+ "host_port_remap" ,
1047
1055
)
1048
1056
1049
1057
def __init__ (
1050
1058
self ,
1051
1059
startup_nodes : List ["ClusterNode" ],
1052
1060
require_full_coverage : bool ,
1053
1061
connection_kwargs : Dict [str , Any ],
1062
+ host_port_remap : List [Dict [str , Any ]] = [],
1054
1063
) -> None :
1055
1064
self .startup_nodes = {node .name : node for node in startup_nodes }
1056
1065
self .require_full_coverage = require_full_coverage
1057
1066
self .connection_kwargs = connection_kwargs
1067
+ self .host_port_remap = host_port_remap
1058
1068
1059
1069
self .default_node : "ClusterNode" = None
1060
1070
self .nodes_cache : Dict [str , "ClusterNode" ] = {}
@@ -1213,6 +1223,7 @@ async def initialize(self) -> None:
1213
1223
if host == "" :
1214
1224
host = startup_node .host
1215
1225
port = int (primary_node [1 ])
1226
+ host , port = self .remap_host_port (host , port )
1216
1227
1217
1228
target_node = tmp_nodes_cache .get (get_node_name (host , port ))
1218
1229
if not target_node :
@@ -1231,6 +1242,7 @@ async def initialize(self) -> None:
1231
1242
for replica_node in replica_nodes :
1232
1243
host = replica_node [0 ]
1233
1244
port = replica_node [1 ]
1245
+ host , port = self .remap_host_port (host , port )
1234
1246
1235
1247
target_replica_node = tmp_nodes_cache .get (
1236
1248
get_node_name (host , port )
@@ -1304,6 +1316,30 @@ async def close(self, attr: str = "nodes_cache") -> None:
1304
1316
)
1305
1317
)
1306
1318
1319
+ def remap_host_port (self , host : str , port : int ) -> Tuple [str , int ]:
1320
+ """
1321
+ Remap the host and port returned from the cluster to a different
1322
+ internal value. Useful if the client is not connecting directly
1323
+ to the cluster.
1324
+ """
1325
+ for map_entry in self .host_port_remap :
1326
+ mapped = False
1327
+ if "from_host" in map_entry :
1328
+ if host != map_entry ["from_host" ]:
1329
+ continue
1330
+ else :
1331
+ host = map_entry ["to_host" ]
1332
+ mapped = True
1333
+ if "from_port" in map_entry :
1334
+ if port != map_entry ["from_port" ]:
1335
+ continue
1336
+ else :
1337
+ port = map_entry ["to_port" ]
1338
+ mapped = True
1339
+ if mapped :
1340
+ break
1341
+ return host , port
1342
+
1307
1343
1308
1344
class ClusterPipeline (AbstractRedis , AbstractRedisCluster , AsyncRedisClusterCommands ):
1309
1345
"""
0 commit comments