From d33d897285a423ab338622ce3593a8ea0f166bd3 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Mon, 15 May 2023 12:15:59 +0300 Subject: [PATCH 01/10] sharded pubsub --- redis/client.py | 90 +++++++++++++++++++++++++++++--- redis/cluster.py | 18 +++++++ redis/commands/core.py | 9 ++++ redis/parsers/socket.py | 6 ++- tests/test_pubsub.py | 113 ++++++++++++++++++++++++++++++++++++++-- 5 files changed, 222 insertions(+), 14 deletions(-) diff --git a/redis/client.py b/redis/client.py index c303dbde38..8414be81c7 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1440,8 +1440,8 @@ class PubSub: will be returned and it's safe to start listening again. """ - PUBLISH_MESSAGE_TYPES = ("message", "pmessage") - UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe") + PUBLISH_MESSAGE_TYPES = ("message", "pmessage", "smessage") + UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe", "sunsubscribe") HEALTH_CHECK_MESSAGE = "redis-py-health-check" def __init__( @@ -1493,11 +1493,13 @@ def reset(self): self.connection.clear_connect_callbacks() self.connection_pool.release(self.connection) self.connection = None - self.channels = {} self.health_check_response_counter = 0 + self.channels = {} self.pending_unsubscribe_channels = set() self.patterns = {} self.pending_unsubscribe_patterns = set() + self.shard_channels = {} + self.pending_unsubscribe_shard_channels = set() self.subscribed_event.clear() def close(self): @@ -1510,6 +1512,7 @@ def on_connect(self, connection): # before passing them to [p]subscribe. self.pending_unsubscribe_channels.clear() self.pending_unsubscribe_patterns.clear() + self.pending_unsubscribe_shard_channels.clear() if self.channels: channels = {} for k, v in self.channels.items(): @@ -1520,6 +1523,11 @@ def on_connect(self, connection): for k, v in self.patterns.items(): patterns[self.encoder.decode(k, force=True)] = v self.psubscribe(**patterns) + if self.shard_channels: + shard_channels = {} + for k, v in self.shard_channels.items(): + shard_channels[self.encoder.decode(k, force=True)] = v + self.ssubscribe(**shard_channels) @property def subscribed(self): @@ -1590,7 +1598,7 @@ def _execute(self, conn, command, *args, **kwargs): lambda error: self._disconnect_raise_connect(conn, error), ) - def parse_response(self, block=True, timeout=0): + def parse_response(self, block=True, timeout=0, **kwargs): """Parse the response from a publish/subscribe command""" conn = self.connection if conn is None: @@ -1603,7 +1611,10 @@ def parse_response(self, block=True, timeout=0): def try_read(): if not block: - if not conn.can_read(timeout=timeout): + print("###################") + can_read = conn.can_read(timeout=timeout) + print(can_read) + if not can_read: return None else: conn.connect() @@ -1728,6 +1739,57 @@ def unsubscribe(self, *args): self.pending_unsubscribe_channels.update(channels) return self.execute_command("UNSUBSCRIBE", *args) + def ssubscribe(self, *args, target_node=None, **kwargs): + """ + # TODO: update docstring + Subscribe to channels. Channels supplied as keyword arguments expect + a channel name as the key and a callable as the value. A channel's + callable will be invoked automatically when a message is received on + that channel rather than producing a message via ``listen()`` or + ``get_message()``. + """ + if args: + args = list_or_args(args[0], args[1:]) + new_s_channels = dict.fromkeys(args) + new_s_channels.update(kwargs) + try: + # cluster mode + ret_val = self.execute_command("SSUBSCRIBE", *new_s_channels.keys(), node=target_node) + except TypeError: + # standalone mode + ret_val = self.execute_command("SSUBSCRIBE", *new_s_channels.keys()) + # update the s_channels dict AFTER we send the command. we don't want to + # subscribe twice to these channels, once for the command and again + # for the reconnection. + new_s_channels = self._normalize_keys(new_s_channels) + self.shard_channels.update(new_s_channels) + if not self.subscribed: + # Set the subscribed_event flag to True + self.subscribed_event.set() + # Clear the health check counter + self.health_check_response_counter = 0 + self.pending_unsubscribe_shard_channels.difference_update(new_s_channels) + return ret_val + + def sunsubscribe(self, *args, target_node=None): + """ + # TODO: update docstring + Unsubscribe from the supplied channels. If empty, unsubscribe from + all channels + """ + if args: + args = list_or_args(args[0], args[1:]) + s_channels = self._normalize_keys(dict.fromkeys(args)) + else: + s_channels = self.shard_channels + self.pending_unsubscribe_shard_channels.update(s_channels) + try: + # cluster mode + return self.execute_command("SUNSUBSCRIBE", *args, node=target_node) + except TypeError: + # standalone mode + return self.execute_command("SUNSUBSCRIBE", *args) + def listen(self): "Listen for messages on channels this client has been subscribed to" while self.subscribed: @@ -1735,7 +1797,7 @@ def listen(self): if response is not None: yield response - def get_message(self, ignore_subscribe_messages=False, timeout=0.0): + def get_message(self, ignore_subscribe_messages=False, timeout=0.0, node=None): """ Get the next message if one is available, otherwise None. @@ -1757,7 +1819,7 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0): # so no messages are available return None - response = self.parse_response(block=(timeout is None), timeout=timeout) + response = self.parse_response(block=(timeout is None), timeout=timeout, node=node) if response: return self.handle_message(response, ignore_subscribe_messages) return None @@ -1809,12 +1871,17 @@ def handle_message(self, response, ignore_subscribe_messages=False): if pattern in self.pending_unsubscribe_patterns: self.pending_unsubscribe_patterns.remove(pattern) self.patterns.pop(pattern, None) + elif message_type == "sunsubscribe": + s_channel = response[1] + if s_channel in self.pending_unsubscribe_shard_channels: + self.pending_unsubscribe_shard_channels.remove(s_channel) + self.shard_channels.pop(s_channel, None) else: channel = response[1] if channel in self.pending_unsubscribe_channels: self.pending_unsubscribe_channels.remove(channel) self.channels.pop(channel, None) - if not self.channels and not self.patterns: + if not self.channels and not self.patterns and not self.shard_channels: # There are no subscriptions anymore, set subscribed_event flag # to false self.subscribed_event.clear() @@ -1823,6 +1890,8 @@ def handle_message(self, response, ignore_subscribe_messages=False): # if there's a message handler, invoke it if message_type == "pmessage": handler = self.patterns.get(message["pattern"], None) + elif message_type == "smessage": + handler = self.shard_channels.get(message["channel"], None) else: handler = self.channels.get(message["channel"], None) if handler: @@ -1843,6 +1912,11 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None): for pattern, handler in self.patterns.items(): if handler is None: raise PubSubError(f"Pattern: '{pattern}' has no handler registered") + for s_channel, handler in self.shard_channels.items(): + if handler is None: + raise PubSubError( + f"Shard Channel: '{s_channel}' has no handler registered" + ) thread = PubSubWorkerThread( self, sleep_time, daemon=daemon, exception_handler=exception_handler diff --git a/redis/cluster.py b/redis/cluster.py index 182ec6d733..6448115aae 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1687,6 +1687,16 @@ def execute_command(self, *args, **kwargs): # NOTE: don't parse the response in this function -- it could pull a # legitimate message off the stack if the connection is already # subscribed to one or more channels + node = kwargs.get("node") + if node is not None: + self.node = node + self.connection_pool = ( + self.cluster.get_redis_connection(self.node).connection_pool + ) + self.connection = self.connection_pool.get_connection( + "pubsub", self.shard_hint + ) + self.connection.register_connect_callback(self.on_connect) if self.connection is None: if self.connection_pool is None: @@ -1720,6 +1730,14 @@ def get_redis_connection(self): if self.node is not None: return self.node.redis_connection + def parse_response(self, block=True, timeout=0, node=None): + if node is not None: + self.node = node + self.connection_pool = node.redis_connection.connection_pool + self.connection = self.connection_pool.get_connection( + "pubsub", self.shard_hint + ) + return super().parse_response(block=block, timeout=timeout) class ClusterPipeline(RedisCluster): """ diff --git a/redis/commands/core.py b/redis/commands/core.py index e2cabb85fa..54289ed77c 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -5103,6 +5103,15 @@ def publish(self, channel: ChannelT, message: EncodableT, **kwargs) -> ResponseT """ return self.execute_command("PUBLISH", channel, message, **kwargs) + def spublish(self, shard_channel: ChannelT, message: EncodableT) -> ResponseT: + """ + Posts a message to the given shard channel. + Returns the number of clients that received the message + + For more information see https://redis.io/commands/spublish + """ + return self.execute_command("SPUBLISH", shard_channel, message) + def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: """ Return a list of channels that have at least one subscriber diff --git a/redis/parsers/socket.py b/redis/parsers/socket.py index 8147243bba..9026fab78f 100644 --- a/redis/parsers/socket.py +++ b/redis/parsers/socket.py @@ -92,9 +92,13 @@ def _read_from_socket( sock.settimeout(self.socket_timeout) def can_read(self, timeout: float) -> bool: - return bool(self.unread_bytes()) or self._read_from_socket( + read = self._read_from_socket( timeout=timeout, raise_on_timeout=False ) + _bytes = bool(self.unread_bytes()) + print("@@@@@@@@@@@@@@@@@@@@@@") + print(read, _bytes) + return _bytes or read def read(self, length: int) -> bytes: length = length + 2 # make sure to read the \r\n terminator diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index e1e4311511..3b321ca310 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -20,13 +20,18 @@ ) -def wait_for_message(pubsub, timeout=0.5, ignore_subscribe_messages=False): +def wait_for_message(pubsub, timeout=0.5, ignore_subscribe_messages=False, node=None): now = time.time() timeout = now + timeout while now < timeout: - message = pubsub.get_message( - ignore_subscribe_messages=ignore_subscribe_messages - ) + if node: + message = pubsub.get_message( + ignore_subscribe_messages=ignore_subscribe_messages, node=node + ) + else: + message = pubsub.get_message( + ignore_subscribe_messages=ignore_subscribe_messages + ) if message is not None: return message time.sleep(0.01) @@ -53,6 +58,15 @@ def make_subscribe_test_data(pubsub, type): "unsub_func": pubsub.unsubscribe, "keys": ["foo", "bar", "uni" + chr(4456) + "code"], } + elif type == "shard_channel": + return { + "p": pubsub, + "sub_type": "ssubscribe", + "unsub_type": "sunsubscribe", + "sub_func": pubsub.ssubscribe, + "unsub_func": pubsub.sunsubscribe, + "keys": ["foo", "bar", "uni" + chr(4456) + "code"], + } elif type == "pattern": return { "p": pubsub, @@ -71,7 +85,7 @@ def _test_subscribe_unsubscribe( ): for key in keys: assert sub_func(key) is None - + print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&") # should be a message for each channel/pattern we just subscribed to for i, key in enumerate(keys): assert wait_for_message(p) == make_message(sub_type, key, i + 1) @@ -93,6 +107,37 @@ def test_pattern_subscribe_unsubscribe(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_subscribe_unsubscribe(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_subscribe_unsubscribe(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_subscribe_unsubscribe(**kwargs) + + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_subscribe_unsubscribe_cluster(self, r): + p = r.pubsub() + keys = { + "foo": r.get_node_from_key("foo"), + # "bar": r.get_node_from_key("bar"), + # "uni" + chr(4456) + "code": r.get_node_from_key("uni" + chr(4456) + "code"), + } + for key, node in keys.items(): + assert p.ssubscribe(key, target_node=node) is None + print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&") + # should be a message for each channel/pattern we just subscribed to + for i, (key, node) in enumerate(keys.items()): + assert wait_for_message(p,node=node) == make_message("ssubscribe", key, i + 1) + + for key, node in keys.items(): + assert p.sunsubscribe(key, node) is None + + # should be a message for each channel/pattern we just unsubscribed + # from + for i, key in enumerate(keys.keys()): + i = len(keys) - 1 - i + assert wait_for_message(p) == make_message("sunsubscribe", key, i) + def _test_resubscribe_on_reconnection( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): @@ -136,6 +181,12 @@ def test_resubscribe_to_patterns_on_reconnection(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_resubscribe_on_reconnection(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_resubscribe_to_shard_channels_on_reconnection(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_resubscribe_on_reconnection(**kwargs) + def _test_subscribed_property( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): @@ -192,6 +243,12 @@ def test_subscribe_property_with_patterns(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_subscribed_property(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_subscribe_property_with_shard_channels(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_subscribed_property(**kwargs) + def test_ignore_all_subscribe_messages(self, r): p = r.pubsub(ignore_subscribe_messages=True) @@ -236,6 +293,12 @@ def test_sub_unsub_resub_patterns(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_sub_unsub_resub(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_sub_unsub_resub_shard_channels(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_sub_unsub_resub(**kwargs) + def _test_sub_unsub_resub( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): @@ -258,6 +321,12 @@ def test_sub_unsub_all_resub_patterns(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_sub_unsub_all_resub(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_sub_unsub_all_resub_shard_channels(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_sub_unsub_all_resub(**kwargs) + def _test_sub_unsub_all_resub( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): @@ -290,6 +359,18 @@ def test_published_message_to_channel(self, r): assert isinstance(message, dict) assert message == make_message("message", "foo", "test message") + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_published_message_to_shard_channel(self, r): + p = r.pubsub() + p.ssubscribe("foo") + assert wait_for_message(p) == make_message("ssubscribe", "foo", 1) + assert r.spublish("foo", "test message") == 1 + + message = wait_for_message(p) + assert isinstance(message, dict) + assert message == make_message("smessage", "foo", "test message") + def test_published_message_to_pattern(self, r): p = r.pubsub() p.subscribe("foo") @@ -321,6 +402,16 @@ def test_channel_message_handler(self, r): assert wait_for_message(p) is None assert self.message == make_message("message", "foo", "test message") + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.ssubscribe(foo=self.message_handler) + assert wait_for_message(p) is None + assert r.spublish("foo", "test message") == 1 + assert wait_for_message(p) is None + assert self.message == make_message("smessage", "foo", "test message") + @pytest.mark.onlynoncluster def test_pattern_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) @@ -342,6 +433,18 @@ def test_unicode_channel_message_handler(self, r): assert wait_for_message(p) is None assert self.message == make_message("message", channel, "test message") + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_unicode_shard_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + channel = "uni" + chr(4456) + "code" + channels = {channel: self.message_handler} + p.ssubscribe(**channels) + assert wait_for_message(p) is None + assert r.spublish(channel, "test message") == 1 + assert wait_for_message(p) is None + assert self.message == make_message("smessage", channel, "test message") + @pytest.mark.onlynoncluster # see: https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html # #known-limitations-with-pubsub From 8a60b97b269b32905a395c4fb56f5ef236902bfa Mon Sep 17 00:00:00 2001 From: dvora-h Date: Mon, 22 May 2023 00:24:13 +0300 Subject: [PATCH 02/10] sharded pubsub Co-authored-by: Leibale Eidelman --- redis/client.py | 20 ++----- redis/cluster.py | 87 ++++++++++++++++++++++++------- redis/parsers/socket.py | 2 - tests/test_asyncio/test_pubsub.py | 5 -- tests/test_pubsub.py | 31 ++++++----- 5 files changed, 88 insertions(+), 57 deletions(-) diff --git a/redis/client.py b/redis/client.py index 8414be81c7..b92067c5c5 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1611,9 +1611,7 @@ def parse_response(self, block=True, timeout=0, **kwargs): def try_read(): if not block: - print("###################") can_read = conn.can_read(timeout=timeout) - print(can_read) if not can_read: return None else: @@ -1752,12 +1750,7 @@ def ssubscribe(self, *args, target_node=None, **kwargs): args = list_or_args(args[0], args[1:]) new_s_channels = dict.fromkeys(args) new_s_channels.update(kwargs) - try: - # cluster mode - ret_val = self.execute_command("SSUBSCRIBE", *new_s_channels.keys(), node=target_node) - except TypeError: - # standalone mode - ret_val = self.execute_command("SSUBSCRIBE", *new_s_channels.keys()) + ret_val = self.execute_command("SSUBSCRIBE", *new_s_channels.keys()) # update the s_channels dict AFTER we send the command. we don't want to # subscribe twice to these channels, once for the command and again # for the reconnection. @@ -1783,12 +1776,7 @@ def sunsubscribe(self, *args, target_node=None): else: s_channels = self.shard_channels self.pending_unsubscribe_shard_channels.update(s_channels) - try: - # cluster mode - return self.execute_command("SUNSUBSCRIBE", *args, node=target_node) - except TypeError: - # standalone mode - return self.execute_command("SUNSUBSCRIBE", *args) + return self.execute_command("SUNSUBSCRIBE", *args) def listen(self): "Listen for messages on channels this client has been subscribed to" @@ -1797,7 +1785,7 @@ def listen(self): if response is not None: yield response - def get_message(self, ignore_subscribe_messages=False, timeout=0.0, node=None): + def get_message(self, ignore_subscribe_messages=False, timeout=0.0): """ Get the next message if one is available, otherwise None. @@ -1819,7 +1807,7 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0, node=None): # so no messages are available return None - response = self.parse_response(block=(timeout is None), timeout=timeout, node=node) + response = self.parse_response(block=(timeout is None), timeout=timeout) if response: return self.handle_message(response, ignore_subscribe_messages) return None diff --git a/redis/cluster.py b/redis/cluster.py index 6448115aae..32109b75ad 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -9,6 +9,7 @@ from redis.backoff import default_backoff from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan from redis.commands import READ_COMMANDS, RedisClusterCommands +from redis.commands.helpers import list_or_args from redis.connection import ConnectionPool, DefaultParser, parse_url from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.exceptions import ( @@ -1625,6 +1626,8 @@ def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs): else redis_cluster.get_redis_connection(self.node).connection_pool ) self.cluster = redis_cluster + self.node_pubsub_mapping = {} + self._pubsubs_generator = self._pubsubs_generator() super().__init__( **kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder ) @@ -1678,25 +1681,15 @@ def _raise_on_invalid_node(self, redis_cluster, node, host, port): f"Node {host}:{port} doesn't exist in the cluster" ) - def execute_command(self, *args, **kwargs): + def execute_command(self, *args, ): """ - Execute a publish/subscribe command. + Execute a subscribe/unsubscribe command. Taken code from redis-py and tweak to make it work within a cluster. """ # NOTE: don't parse the response in this function -- it could pull a # legitimate message off the stack if the connection is already # subscribed to one or more channels - node = kwargs.get("node") - if node is not None: - self.node = node - self.connection_pool = ( - self.cluster.get_redis_connection(self.node).connection_pool - ) - self.connection = self.connection_pool.get_connection( - "pubsub", self.shard_hint - ) - self.connection.register_connect_callback(self.on_connect) if self.connection is None: if self.connection_pool is None: @@ -1723,6 +1716,68 @@ def execute_command(self, *args, **kwargs): connection = self.connection self._execute(connection, connection.send_command, *args) + def _get_node_pubsub(self, node): + try: + return self.node_pubsub_mapping[node.name] + except KeyError: + pubsub = node.redis_connection.pubsub() + self.node_pubsub_mapping[node.name] = pubsub + return pubsub + + def _sharded_message_generator(self, ignore_subscribe_messages=False): + while True: + pubsub = next(self._pubsubs_generator) + message = pubsub.get_message(ignore_subscribe_messages=ignore_subscribe_messages) + if message is not None: + return message + + def _pubsubs_generator(self): + while True: + for pubsub in self.node_pubsub_mapping.values(): + yield pubsub + + def get_sharded_message( + self, ignore_subscribe_messages=False, timeout=0.0, target_node=None + ): + if target_node: + message = self.node_pubsub_mapping[target_node.name].get_message( + ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout + ) + else: + message = self._sharded_message_generator(ignore_subscribe_messages=ignore_subscribe_messages) + if message is None: + return None + elif str_if_bytes(message["type"]) == "sunsubscribe": + self.shard_channels.pop(message["channel"], None) + if not self.channels and not self.patterns and not self.shard_channels: + # There are no subscriptions anymore, set subscribed_event flag + # to false + self.subscribed_event.clear() + return message + + def ssubscribe(self, *args, **kwargs): + if args: + args = list_or_args(args[0], args[1:]) + for s_channel in args: + node = self.cluster.get_node_from_key(s_channel) + pubsub = self._get_node_pubsub(node) + pubsub.ssubscribe(s_channel) + # self.subscribed = self.subscribed or self._get_node_pubsub(node).subscribed + self.shard_channels.update(pubsub.shard_channels) + if pubsub.subscribed and not self.subscribed: + self.subscribed_event.set() + self.health_check_response_counter = 0 + + def sunsubscribe(self, *args): + if args: + args = list_or_args(args[0], args[1:]) + else: + args = self.shard_channels + + for s_channel in args: + node = self.cluster.get_node_from_key(s_channel) + self._get_node_pubsub(node).sunsubscribe(s_channel) + def get_redis_connection(self): """ Get the Redis connection of the pubsub connected node. @@ -1730,14 +1785,6 @@ def get_redis_connection(self): if self.node is not None: return self.node.redis_connection - def parse_response(self, block=True, timeout=0, node=None): - if node is not None: - self.node = node - self.connection_pool = node.redis_connection.connection_pool - self.connection = self.connection_pool.get_connection( - "pubsub", self.shard_hint - ) - return super().parse_response(block=block, timeout=timeout) class ClusterPipeline(RedisCluster): """ diff --git a/redis/parsers/socket.py b/redis/parsers/socket.py index 9026fab78f..833b5fab2b 100644 --- a/redis/parsers/socket.py +++ b/redis/parsers/socket.py @@ -96,8 +96,6 @@ def can_read(self, timeout: float) -> bool: timeout=timeout, raise_on_timeout=False ) _bytes = bool(self.unread_bytes()) - print("@@@@@@@@@@@@@@@@@@@@@@") - print(read, _bytes) return _bytes or read def read(self, length: int) -> bytes: diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 8cd5cf6fba..412398f37b 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -675,18 +675,15 @@ async def loop(): nonlocal interrupt await pubsub.subscribe("foo") while True: - # print("loop") try: try: await pubsub.connect() await loop_step() - # print("succ") except redis.ConnectionError: await asyncio.sleep(0.1) except asyncio.CancelledError: # we use a cancel to interrupt the "listen" # when we perform a disconnect - # print("cancel", interrupt) if interrupt: interrupt = False else: @@ -919,7 +916,6 @@ async def loop(self): try: if self.state == 4: break - # print("state a ", self.state) got_msg = await self.get_message() assert got_msg if self.state in (1, 2): @@ -937,7 +933,6 @@ async def loop(self): async def loop_step_get_message(self): # get a single message via get_message message = await self.pubsub.get_message(timeout=0.1) - # print(message) if message is not None: await self.messages.put(message) return True diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 3b321ca310..c950e12e09 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -25,8 +25,8 @@ def wait_for_message(pubsub, timeout=0.5, ignore_subscribe_messages=False, node= timeout = now + timeout while now < timeout: if node: - message = pubsub.get_message( - ignore_subscribe_messages=ignore_subscribe_messages, node=node + message = pubsub.get_sharded_message( + ignore_subscribe_messages=ignore_subscribe_messages, target_node=node ) else: message = pubsub.get_message( @@ -85,7 +85,7 @@ def _test_subscribe_unsubscribe( ): for key in keys: assert sub_func(key) is None - print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&") + # should be a message for each channel/pattern we just subscribed to for i, key in enumerate(keys): assert wait_for_message(p) == make_message(sub_type, key, i + 1) @@ -119,24 +119,27 @@ def test_shard_channel_subscribe_unsubscribe_cluster(self, r): p = r.pubsub() keys = { "foo": r.get_node_from_key("foo"), - # "bar": r.get_node_from_key("bar"), - # "uni" + chr(4456) + "code": r.get_node_from_key("uni" + chr(4456) + "code"), + "bar": r.get_node_from_key("bar"), + "uni" + chr(4456) + "code": r.get_node_from_key("uni" + chr(4456) + "code"), } - for key, node in keys.items(): - assert p.ssubscribe(key, target_node=node) is None - print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&") + + for key in keys.keys(): + assert p.ssubscribe(key) is None # should be a message for each channel/pattern we just subscribed to + data = [1, 1, 2] for i, (key, node) in enumerate(keys.items()): - assert wait_for_message(p,node=node) == make_message("ssubscribe", key, i + 1) + assert wait_for_message(p, node=node) == make_message("ssubscribe", key, data[i]) - for key, node in keys.items(): - assert p.sunsubscribe(key, node) is None + for key in keys.keys(): + assert p.sunsubscribe(key) is None # should be a message for each channel/pattern we just unsubscribed # from - for i, key in enumerate(keys.keys()): - i = len(keys) - 1 - i - assert wait_for_message(p) == make_message("sunsubscribe", key, i) + data = [0, 1, 0] + breakpoint() + for i, (key, node) in enumerate(keys.items()): + assert wait_for_message(p, node=node) == make_message("sunsubscribe", key, data[i]) + breakpoint() def _test_resubscribe_on_reconnection( self, p, sub_type, unsub_type, sub_func, unsub_func, keys From c6c8a0485f49c110fc8a07cf876e9a1046de1ede Mon Sep 17 00:00:00 2001 From: dvora-h Date: Mon, 22 May 2023 16:20:39 +0300 Subject: [PATCH 03/10] Shrded Pubsub TestPubSubSubscribeUnsubscribe --- dev_requirements.txt | 1 + redis/cluster.py | 52 +++++++++---- redis/parsers/socket.py | 4 +- tests/test_pubsub.py | 164 ++++++++++++++++++++++++++++++++++------ 4 files changed, 177 insertions(+), 44 deletions(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index 8285b0456f..8ffb1e944f 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -15,4 +15,5 @@ pytest-cov>=4.0.0 vulture>=2.3.0 ujson>=4.2.0 wheel>=0.30.0 +urllib3<2 uvloop diff --git a/redis/cluster.py b/redis/cluster.py index 32109b75ad..a748280c07 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1681,7 +1681,7 @@ def _raise_on_invalid_node(self, redis_cluster, node, host, port): f"Node {host}:{port} doesn't exist in the cluster" ) - def execute_command(self, *args, ): + def execute_command(self, *args): """ Execute a subscribe/unsubscribe command. @@ -1723,36 +1723,41 @@ def _get_node_pubsub(self, node): pubsub = node.redis_connection.pubsub() self.node_pubsub_mapping[node.name] = pubsub return pubsub - - def _sharded_message_generator(self, ignore_subscribe_messages=False): - while True: + + def _sharded_message_generator(self): + for _ in range(len(self.node_pubsub_mapping)): pubsub = next(self._pubsubs_generator) - message = pubsub.get_message(ignore_subscribe_messages=ignore_subscribe_messages) + message = pubsub.get_message() if message is not None: return message + return None def _pubsubs_generator(self): while True: for pubsub in self.node_pubsub_mapping.values(): yield pubsub - + def get_sharded_message( - self, ignore_subscribe_messages=False, timeout=0.0, target_node=None - ): + self, ignore_subscribe_messages=False, timeout=0.0, target_node=None + ): if target_node: message = self.node_pubsub_mapping[target_node.name].get_message( ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout ) else: - message = self._sharded_message_generator(ignore_subscribe_messages=ignore_subscribe_messages) + message = self._sharded_message_generator() if message is None: return None elif str_if_bytes(message["type"]) == "sunsubscribe": - self.shard_channels.pop(message["channel"], None) + if message["channel"] in self.pending_unsubscribe_shard_channels: + self.pending_unsubscribe_shard_channels.remove(message["channel"]) + self.shard_channels.pop(message["channel"], None) if not self.channels and not self.patterns and not self.shard_channels: - # There are no subscriptions anymore, set subscribed_event flag - # to false - self.subscribed_event.clear() + # There are no subscriptions anymore, set subscribed_event flag + # to false + self.subscribed_event.clear() + if self.ignore_subscribe_messages or ignore_subscribe_messages: + return None return message def ssubscribe(self, *args, **kwargs): @@ -1762,12 +1767,14 @@ def ssubscribe(self, *args, **kwargs): node = self.cluster.get_node_from_key(s_channel) pubsub = self._get_node_pubsub(node) pubsub.ssubscribe(s_channel) - # self.subscribed = self.subscribed or self._get_node_pubsub(node).subscribed self.shard_channels.update(pubsub.shard_channels) + self.pending_unsubscribe_shard_channels.difference_update( + self._normalize_keys({s_channel: None}) + ) if pubsub.subscribed and not self.subscribed: self.subscribed_event.set() self.health_check_response_counter = 0 - + def sunsubscribe(self, *args): if args: args = list_or_args(args[0], args[1:]) @@ -1776,7 +1783,11 @@ def sunsubscribe(self, *args): for s_channel in args: node = self.cluster.get_node_from_key(s_channel) - self._get_node_pubsub(node).sunsubscribe(s_channel) + p = self._get_node_pubsub(node) + p.sunsubscribe(s_channel) + self.pending_unsubscribe_shard_channels.update( + p.pending_unsubscribe_shard_channels + ) def get_redis_connection(self): """ @@ -1785,6 +1796,15 @@ def get_redis_connection(self): if self.node is not None: return self.node.redis_connection + def disconnect(self): + """ + Disconnect the pubsub connection. + """ + if self.connection: + self.connection.disconnect() + for pubsub in self.node_pubsub_mapping.values(): + pubsub.connection.disconnect() + class ClusterPipeline(RedisCluster): """ diff --git a/redis/parsers/socket.py b/redis/parsers/socket.py index 833b5fab2b..f1b2109cd4 100644 --- a/redis/parsers/socket.py +++ b/redis/parsers/socket.py @@ -92,9 +92,7 @@ def _read_from_socket( sock.settimeout(self.socket_timeout) def can_read(self, timeout: float) -> bool: - read = self._read_from_socket( - timeout=timeout, raise_on_timeout=False - ) + read = self._read_from_socket(timeout=timeout, raise_on_timeout=False) _bytes = bool(self.unread_bytes()) return _bytes or read diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index c950e12e09..a901f47c05 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -3,6 +3,7 @@ import socket import threading import time +from collections import defaultdict from unittest import mock from unittest.mock import patch @@ -20,7 +21,9 @@ ) -def wait_for_message(pubsub, timeout=0.5, ignore_subscribe_messages=False, node=None): +def wait_for_message( + pubsub, timeout=0.5, ignore_subscribe_messages=False, node=None, func=None +): now = time.time() timeout = now + timeout while now < timeout: @@ -28,6 +31,8 @@ def wait_for_message(pubsub, timeout=0.5, ignore_subscribe_messages=False, node= message = pubsub.get_sharded_message( ignore_subscribe_messages=ignore_subscribe_messages, target_node=node ) + elif func: + message = func(ignore_subscribe_messages=ignore_subscribe_messages) else: message = pubsub.get_message( ignore_subscribe_messages=ignore_subscribe_messages @@ -116,6 +121,7 @@ def test_shard_channel_subscribe_unsubscribe(self, r): @pytest.mark.onlycluster @skip_if_server_version_lt("7.0.0") def test_shard_channel_subscribe_unsubscribe_cluster(self, r): + node_channels = defaultdict(int) p = r.pubsub() keys = { "foo": r.get_node_from_key("foo"), @@ -123,23 +129,26 @@ def test_shard_channel_subscribe_unsubscribe_cluster(self, r): "uni" + chr(4456) + "code": r.get_node_from_key("uni" + chr(4456) + "code"), } - for key in keys.keys(): + for key, node in keys.items(): assert p.ssubscribe(key) is None - # should be a message for each channel/pattern we just subscribed to - data = [1, 1, 2] - for i, (key, node) in enumerate(keys.items()): - assert wait_for_message(p, node=node) == make_message("ssubscribe", key, data[i]) + + # should be a message for each shard_channel we just subscribed to + for key, node in keys.items(): + node_channels[node.name] += 1 + assert wait_for_message(p, node=node) == make_message( + "ssubscribe", key, node_channels[node.name] + ) for key in keys.keys(): assert p.sunsubscribe(key) is None - # should be a message for each channel/pattern we just unsubscribed + # should be a message for each shard_channel we just unsubscribed # from - data = [0, 1, 0] - breakpoint() - for i, (key, node) in enumerate(keys.items()): - assert wait_for_message(p, node=node) == make_message("sunsubscribe", key, data[i]) - breakpoint() + for key, node in keys.items(): + node_channels[node.name] -= 1 + assert wait_for_message(p, node=node) == make_message( + "sunsubscribe", key, node_channels[node.name] + ) def _test_resubscribe_on_reconnection( self, p, sub_type, unsub_type, sub_func, unsub_func, keys @@ -154,7 +163,7 @@ def _test_resubscribe_on_reconnection( # manually disconnect p.connection.disconnect() - + # breakpoint() # calling get_message again reconnects and resubscribes # note, we may not re-subscribe to channels in exactly the same order # so we have to do some extra checks to make sure we got them all @@ -252,38 +261,103 @@ def test_subscribe_property_with_shard_channels(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") self._test_subscribed_property(**kwargs) + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_subscribe_property_with_shard_channels_cluster(self, r): + p = r.pubsub() + keys = ["foo", "bar", "uni" + chr(4456) + "code"] + nodes = [r.get_node_from_key(key) for key in keys] + assert p.subscribed is False + p.ssubscribe(keys[0]) + # we're now subscribed even though we haven't processed the + # reply from the server just yet + assert p.subscribed is True + assert wait_for_message(p, node=nodes[0]) == make_message( + "ssubscribe", keys[0], 1 + ) + # we're still subscribed + assert p.subscribed is True + + # unsubscribe from all shard_channels + p.sunsubscribe() + # we're still technically subscribed until we process the + # response messages from the server + assert p.subscribed is True + assert wait_for_message(p, node=nodes[0]) == make_message( + "sunsubscribe", keys[0], 0 + ) + # now we're no longer subscribed as no more messages can be delivered + # to any channels we were listening to + assert p.subscribed is False + + # subscribing again flips the flag back + p.ssubscribe(keys[0]) + assert p.subscribed is True + assert wait_for_message(p, node=nodes[0]) == make_message( + "ssubscribe", keys[0], 1 + ) + + # unsubscribe again + p.sunsubscribe() + assert p.subscribed is True + # subscribe to another shard_channel before reading the unsubscribe response + p.ssubscribe(keys[1]) + assert p.subscribed is True + # read the unsubscribe for key1 + assert wait_for_message(p, node=nodes[0]) == make_message( + "sunsubscribe", keys[0], 0 + ) + # we're still subscribed to key2, so subscribed should still be True + assert p.subscribed is True + # read the key2 subscribe message + assert wait_for_message(p, node=nodes[1]) == make_message( + "ssubscribe", keys[1], 1 + ) + p.sunsubscribe() + # haven't read the message yet, so we're still subscribed + assert p.subscribed is True + assert wait_for_message(p, node=nodes[1]) == make_message( + "sunsubscribe", keys[1], 0 + ) + # now we're finally unsubscribed + assert p.subscribed is False + def test_ignore_all_subscribe_messages(self, r): p = r.pubsub(ignore_subscribe_messages=True) checks = ( - (p.subscribe, "foo"), - (p.unsubscribe, "foo"), - (p.psubscribe, "f*"), - (p.punsubscribe, "f*"), + (p.subscribe, "foo", p.get_message), + (p.unsubscribe, "foo", p.get_message), + (p.psubscribe, "f*", p.get_message), + (p.punsubscribe, "f*", p.get_message), + (p.ssubscribe, "foo", p.get_sharded_message), + (p.sunsubscribe, "foo", p.get_sharded_message), ) assert p.subscribed is False - for func, channel in checks: + for func, channel, get_func in checks: assert func(channel) is None assert p.subscribed is True - assert wait_for_message(p) is None + assert wait_for_message(p, func=get_func) is None assert p.subscribed is False def test_ignore_individual_subscribe_messages(self, r): p = r.pubsub() checks = ( - (p.subscribe, "foo"), - (p.unsubscribe, "foo"), - (p.psubscribe, "f*"), - (p.punsubscribe, "f*"), + (p.subscribe, "foo", p.get_message), + (p.unsubscribe, "foo", p.get_message), + (p.psubscribe, "f*", p.get_message), + (p.punsubscribe, "f*", p.get_message), + (p.ssubscribe, "foo", p.get_sharded_message), + (p.sunsubscribe, "foo", p.get_sharded_message), ) assert p.subscribed is False - for func, channel in checks: + for func, channel, get_func in checks: assert func(channel) is None assert p.subscribed is True - message = wait_for_message(p, ignore_subscribe_messages=True) + message = wait_for_message(p, ignore_subscribe_messages=True, func=get_func) assert message is None assert p.subscribed is False @@ -316,6 +390,26 @@ def _test_sub_unsub_resub( assert wait_for_message(p) == make_message(sub_type, key, 1) assert p.subscribed is True + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_sub_unsub_resub_shard_channels_cluster(self, r): + p = r.pubsub() + key = "foo" + p.ssubscribe(key) + p.sunsubscribe(key) + p.ssubscribe(key) + assert p.subscribed is True + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", key, 1 + ) + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "sunsubscribe", key, 0 + ) + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", key, 1 + ) + assert p.subscribed is True + def test_sub_unsub_all_resub_channels(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "channel") self._test_sub_unsub_all_resub(**kwargs) @@ -344,6 +438,26 @@ def _test_sub_unsub_all_resub( assert wait_for_message(p) == make_message(sub_type, key, 1) assert p.subscribed is True + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_sub_unsub_all_resub_shard_channels_cluster(self, r): + p = r.pubsub() + key = "foo" + p.ssubscribe(key) + p.sunsubscribe() + p.ssubscribe(key) + assert p.subscribed is True + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", key, 1 + ) + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "sunsubscribe", key, 0 + ) + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", key, 1 + ) + assert p.subscribed is True + class TestPubSubMessages: def setup_method(self, method): From d243bff4c733afbbecaf1a342b1f005a14e37319 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Mon, 22 May 2023 16:48:15 +0300 Subject: [PATCH 04/10] fix TestPubSubSubscribeUnsubscribe --- redis/client.py | 2 ++ tests/test_pubsub.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/redis/client.py b/redis/client.py index b92067c5c5..5f134c631a 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1811,6 +1811,8 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0): if response: return self.handle_message(response, ignore_subscribe_messages) return None + + get_sharded_message = get_message def ping(self, message=None): """ diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index a901f47c05..a72ead1160 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -163,7 +163,7 @@ def _test_resubscribe_on_reconnection( # manually disconnect p.connection.disconnect() - # breakpoint() + # calling get_message again reconnects and resubscribes # note, we may not re-subscribe to channels in exactly the same order # so we have to do some extra checks to make sure we got them all @@ -322,6 +322,7 @@ def test_subscribe_property_with_shard_channels_cluster(self, r): # now we're finally unsubscribed assert p.subscribed is False + @skip_if_server_version_lt("7.0.0") def test_ignore_all_subscribe_messages(self, r): p = r.pubsub(ignore_subscribe_messages=True) @@ -341,6 +342,7 @@ def test_ignore_all_subscribe_messages(self, r): assert wait_for_message(p, func=get_func) is None assert p.subscribed is False + @skip_if_server_version_lt("7.0.0") def test_ignore_individual_subscribe_messages(self, r): p = r.pubsub() From fe415dc4a0a4ac363358166034dc141d0bf7e12d Mon Sep 17 00:00:00 2001 From: dvora-h Date: Tue, 23 May 2023 00:53:26 +0300 Subject: [PATCH 05/10] more tests --- redis/client.py | 2 +- redis/cluster.py | 9 ++++-- redis/parsers/commands.py | 2 +- tests/test_pubsub.py | 64 +++++++++++++++++++++++++++++++++++---- 4 files changed, 67 insertions(+), 10 deletions(-) diff --git a/redis/client.py b/redis/client.py index 5f134c631a..a56480ecf6 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1811,7 +1811,7 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0): if response: return self.handle_message(response, ignore_subscribe_messages) return None - + get_sharded_message = get_message def ping(self, message=None): diff --git a/redis/cluster.py b/redis/cluster.py index a748280c07..597217bd40 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1763,10 +1763,15 @@ def get_sharded_message( def ssubscribe(self, *args, **kwargs): if args: args = list_or_args(args[0], args[1:]) - for s_channel in args: + s_channels = dict.fromkeys(args) + s_channels.update(kwargs) + for s_channel, handler in s_channels.items(): node = self.cluster.get_node_from_key(s_channel) pubsub = self._get_node_pubsub(node) - pubsub.ssubscribe(s_channel) + if handler: + pubsub.ssubscribe(**{s_channel: handler}) + else: + pubsub.ssubscribe(s_channel) self.shard_channels.update(pubsub.shard_channels) self.pending_unsubscribe_shard_channels.difference_update( self._normalize_keys({s_channel: None}) diff --git a/redis/parsers/commands.py b/redis/parsers/commands.py index 2ea29a75ae..93145ea093 100644 --- a/redis/parsers/commands.py +++ b/redis/parsers/commands.py @@ -161,7 +161,7 @@ def _get_pubsub_keys(self, *args): # format example: # SUBSCRIBE channel [channel ...] keys = list(args[1:]) - elif command == "PUBLISH": + elif command in ["PUBLISH", "SPUBLISH"]: # format example: # PUBLISH channel message keys = [args[1]] diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index a72ead1160..20eb73e2e2 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -490,6 +490,18 @@ def test_published_message_to_shard_channel(self, r): assert isinstance(message, dict) assert message == make_message("smessage", "foo", "test message") + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_published_message_to_shard_channel_cluster(self, r): + p = r.pubsub() + p.ssubscribe("foo") + assert wait_for_message(p, func=p.get_sharded_message) == make_message("ssubscribe", "foo", 1) + assert r.spublish("foo", "test message") == 1 + + message = wait_for_message(p, func=p.get_sharded_message) + assert isinstance(message, dict) + assert message == make_message("smessage", "foo", "test message") + def test_published_message_to_pattern(self, r): p = r.pubsub() p.subscribe("foo") @@ -521,14 +533,13 @@ def test_channel_message_handler(self, r): assert wait_for_message(p) is None assert self.message == make_message("message", "foo", "test message") - @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.0.0") def test_shard_channel_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) p.ssubscribe(foo=self.message_handler) - assert wait_for_message(p) is None + assert wait_for_message(p, func=p.get_sharded_message) is None assert r.spublish("foo", "test message") == 1 - assert wait_for_message(p) is None + assert wait_for_message(p, func=p.get_sharded_message) is None assert self.message == make_message("smessage", "foo", "test message") @pytest.mark.onlynoncluster @@ -552,16 +563,15 @@ def test_unicode_channel_message_handler(self, r): assert wait_for_message(p) is None assert self.message == make_message("message", channel, "test message") - @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.0.0") def test_unicode_shard_channel_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) channel = "uni" + chr(4456) + "code" channels = {channel: self.message_handler} p.ssubscribe(**channels) - assert wait_for_message(p) is None + assert wait_for_message(p, func=p.get_sharded_message) is None assert r.spublish(channel, "test message") == 1 - assert wait_for_message(p) is None + assert wait_for_message(p, func=p.get_sharded_message) is None assert self.message == make_message("smessage", channel, "test message") @pytest.mark.onlynoncluster @@ -633,6 +643,15 @@ def test_pattern_subscribe_unsubscribe(self, r): p.punsubscribe(self.pattern) assert wait_for_message(p) == self.make_message("punsubscribe", self.pattern, 0) + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_subscribe_unsubscribe(self, r): + p = r.pubsub() + p.ssubscribe(self.channel) + assert wait_for_message(p, func=p.get_sharded_message) == self.make_message("ssubscribe", self.channel, 1) + + p.sunsubscribe(self.channel) + assert wait_for_message(p, func=p.get_sharded_message) == self.make_message("sunsubscribe", self.channel, 0) + def test_channel_publish(self, r): p = r.pubsub() p.subscribe(self.channel) @@ -652,6 +671,16 @@ def test_pattern_publish(self, r): "pmessage", self.channel, self.data, pattern=self.pattern ) + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_publish(self, r): + p = r.pubsub() + p.ssubscribe(self.channel) + assert wait_for_message(p, func=p.get_sharded_message) == self.make_message("ssubscribe", self.channel, 1) + r.spublish(self.channel, self.data) + assert wait_for_message(p, func=p.get_sharded_message) == self.make_message( + "smessage", self.channel, self.data + ) + def test_channel_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) p.subscribe(**{self.channel: self.message_handler}) @@ -690,6 +719,29 @@ def test_pattern_message_handler(self, r): "pmessage", self.channel, new_data, pattern=self.pattern ) + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.ssubscribe(**{self.channel: self.message_handler}) + assert wait_for_message(p, func=p.get_sharded_message) is None + r.spublish(self.channel, self.data) + assert wait_for_message(p, func=p.get_sharded_message) is None + assert self.message == self.make_message("smessage", self.channel, self.data) + + # test that we reconnected to the correct channel + self.message = None + try: + # cluster mode + p.disconnect() + except AttributeError: + # standalone mode + p.connection.disconnect() + assert wait_for_message(p, func=p.get_sharded_message) is None # should reconnect + new_data = self.data + "new data" + r.spublish(self.channel, new_data) + assert wait_for_message(p, func=p.get_sharded_message) is None + assert self.message == self.make_message("smessage", self.channel, new_data) + def test_context_manager(self, r): with r.pubsub() as pubsub: pubsub.subscribe("foo") From 3ec1e45e8e71b0bf6e3c88d39a67d1674d995287 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Tue, 23 May 2023 00:56:41 +0300 Subject: [PATCH 06/10] linters --- tests/test_pubsub.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 20eb73e2e2..7b52174788 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -495,7 +495,9 @@ def test_published_message_to_shard_channel(self, r): def test_published_message_to_shard_channel_cluster(self, r): p = r.pubsub() p.ssubscribe("foo") - assert wait_for_message(p, func=p.get_sharded_message) == make_message("ssubscribe", "foo", 1) + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", "foo", 1 + ) assert r.spublish("foo", "test message") == 1 message = wait_for_message(p, func=p.get_sharded_message) @@ -647,10 +649,14 @@ def test_pattern_subscribe_unsubscribe(self, r): def test_shard_channel_subscribe_unsubscribe(self, r): p = r.pubsub() p.ssubscribe(self.channel) - assert wait_for_message(p, func=p.get_sharded_message) == self.make_message("ssubscribe", self.channel, 1) + assert wait_for_message(p, func=p.get_sharded_message) == self.make_message( + "ssubscribe", self.channel, 1 + ) p.sunsubscribe(self.channel) - assert wait_for_message(p, func=p.get_sharded_message) == self.make_message("sunsubscribe", self.channel, 0) + assert wait_for_message(p, func=p.get_sharded_message) == self.make_message( + "sunsubscribe", self.channel, 0 + ) def test_channel_publish(self, r): p = r.pubsub() @@ -675,7 +681,9 @@ def test_pattern_publish(self, r): def test_shard_channel_publish(self, r): p = r.pubsub() p.ssubscribe(self.channel) - assert wait_for_message(p, func=p.get_sharded_message) == self.make_message("ssubscribe", self.channel, 1) + assert wait_for_message(p, func=p.get_sharded_message) == self.make_message( + "ssubscribe", self.channel, 1 + ) r.spublish(self.channel, self.data) assert wait_for_message(p, func=p.get_sharded_message) == self.make_message( "smessage", self.channel, self.data @@ -736,7 +744,8 @@ def test_shard_channel_message_handler(self, r): except AttributeError: # standalone mode p.connection.disconnect() - assert wait_for_message(p, func=p.get_sharded_message) is None # should reconnect + # should reconnect + assert wait_for_message(p, func=p.get_sharded_message) is None new_data = self.data + "new data" r.spublish(self.channel, new_data) assert wait_for_message(p, func=p.get_sharded_message) is None From ec34d7d9babc55f047d7060f513815f423a30de9 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Wed, 24 May 2023 12:06:16 +0300 Subject: [PATCH 07/10] TestPubSubSubcommands --- redis/client.py | 1 + redis/cluster.py | 11 ++++++-- redis/commands/core.py | 17 ++++++++++++ redis/parsers/commands.py | 2 +- tests/test_pubsub.py | 58 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 85 insertions(+), 4 deletions(-) diff --git a/redis/client.py b/redis/client.py index a56480ecf6..82028a1a53 100755 --- a/redis/client.py +++ b/redis/client.py @@ -833,6 +833,7 @@ class AbstractRedis: "QUIT": bool_ok, "STRALGO": parse_stralgo, "PUBSUB NUMSUB": parse_pubsub_numsub, + "PUBSUB SHARDNUMSUB": parse_pubsub_numsub, "RANDOMKEY": lambda r: r and r or None, "RESET": str_if_bytes, "SCAN": parse_scan, diff --git a/redis/cluster.py b/redis/cluster.py index 597217bd40..da238b8c54 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -223,7 +223,7 @@ class AbstractRedisCluster: "PUBSUB CHANNELS", "PUBSUB NUMPAT", "PUBSUB NUMSUB", - "PING", + "PUBSUB SHARDCHANNELS" "PUBSUB SHARDNUMSUB" "PING", "INFO", "SHUTDOWN", "KEYS", @@ -347,11 +347,13 @@ class AbstractRedisCluster: } RESULT_CALLBACKS = dict_merge( - list_keys_to_dict(["PUBSUB NUMSUB"], parse_pubsub_numsub), + list_keys_to_dict(["PUBSUB NUMSUB", "PUBSUB SHARDNUMSUB"], parse_pubsub_numsub), list_keys_to_dict( ["PUBSUB NUMPAT"], lambda command, res: sum(list(res.values())) ), - list_keys_to_dict(["KEYS", "PUBSUB CHANNELS"], merge_result), + list_keys_to_dict( + ["KEYS", "PUBSUB CHANNELS", "PUBSUB SHARDCHANNELS"], merge_result + ), list_keys_to_dict( [ "PING", @@ -1752,6 +1754,9 @@ def get_sharded_message( if message["channel"] in self.pending_unsubscribe_shard_channels: self.pending_unsubscribe_shard_channels.remove(message["channel"]) self.shard_channels.pop(message["channel"], None) + node = self.cluster.get_node_from_key(message["channel"]) + if self.node_pubsub_mapping[node.name].subscribed is False: + self.node_pubsub_mapping.pop(node.name) if not self.channels and not self.patterns and not self.shard_channels: # There are no subscriptions anymore, set subscribed_event flag # to false diff --git a/redis/commands/core.py b/redis/commands/core.py index 54289ed77c..6676ea8d71 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -5120,6 +5120,14 @@ def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: """ return self.execute_command("PUBSUB CHANNELS", pattern, **kwargs) + def pubsub_shardchannels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: + """ + Return a list of shard_channels that have at least one subscriber + + For more information see https://redis.io/commands/pubsub-shardchannels + """ + return self.execute_command("PUBSUB SHARDCHANNELS", pattern, **kwargs) + def pubsub_numpat(self, **kwargs) -> ResponseT: """ Returns the number of subscriptions to patterns @@ -5137,6 +5145,15 @@ def pubsub_numsub(self, *args: ChannelT, **kwargs) -> ResponseT: """ return self.execute_command("PUBSUB NUMSUB", *args, **kwargs) + def pubsub_shardnumsub(self, *args: ChannelT, **kwargs) -> ResponseT: + """ + Return a list of (shard_channel, number of subscribers) tuples + for each channel given in ``*args`` + + For more information see https://redis.io/commands/pubsub-shardnumsub + """ + return self.execute_command("PUBSUB SHARDNUMSUB", *args, **kwargs) + AsyncPubSubCommands = PubSubCommands diff --git a/redis/parsers/commands.py b/redis/parsers/commands.py index 93145ea093..d3b4a99ed3 100644 --- a/redis/parsers/commands.py +++ b/redis/parsers/commands.py @@ -155,7 +155,7 @@ def _get_pubsub_keys(self, *args): # the second argument is a part of the command name, e.g. # ['PUBSUB', 'NUMSUB', 'foo']. pubsub_type = args[1].upper() - if pubsub_type in ["CHANNELS", "NUMSUB"]: + if pubsub_type in ["CHANNELS", "NUMSUB", "SHARDCHANNELS", "SHARDNUMSUB"]: keys = args[2:] elif command in ["SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE"]: # format example: diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 7b52174788..2f6b4bad80 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -780,6 +780,38 @@ def test_pubsub_channels(self, r): expected = [b"bar", b"baz", b"foo", b"quux"] assert all([channel in r.pubsub_channels() for channel in expected]) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_pubsub_shardchannels(self, r): + p = r.pubsub() + p.ssubscribe("foo", "bar", "baz", "quux") + for i in range(4): + assert wait_for_message(p)["type"] == "ssubscribe" + expected = [b"bar", b"baz", b"foo", b"quux"] + assert all([channel in r.pubsub_shardchannels() for channel in expected]) + + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_pubsub_shardchannels_cluster(self, r): + channels = { + b"foo": r.get_node_from_key("foo"), + b"bar": r.get_node_from_key("bar"), + b"baz": r.get_node_from_key("baz"), + b"quux": r.get_node_from_key("quux"), + } + p = r.pubsub() + p.ssubscribe("foo", "bar", "baz", "quux") + for node in channels.values(): + assert wait_for_message(p, node=node)["type"] == "ssubscribe" + for channel, node in channels.items(): + assert channel in r.pubsub_shardchannels(target_nodes=node) + assert all( + [ + channel in r.pubsub_shardchannels(target_nodes="all") + for channel in channels.keys() + ] + ) + @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.8.0") def test_pubsub_numsub(self, r): @@ -806,6 +838,32 @@ def test_pubsub_numpat(self, r): assert wait_for_message(p)["type"] == "psubscribe" assert r.pubsub_numpat() == 3 + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_pubsub_shardnumsub(self, r): + channels = { + b"foo": r.get_node_from_key("foo"), + b"bar": r.get_node_from_key("bar"), + b"baz": r.get_node_from_key("baz"), + } + p1 = r.pubsub() + p1.ssubscribe(*channels.keys()) + for node in channels.values(): + assert wait_for_message(p1, node=node)["type"] == "ssubscribe" + p2 = r.pubsub() + p2.ssubscribe("bar", "baz") + for i in range(2): + assert ( + wait_for_message(p2, func=p2.get_sharded_message)["type"] + == "ssubscribe" + ) + p3 = r.pubsub() + p3.ssubscribe("baz") + assert wait_for_message(p3, node=channels[b"baz"])["type"] == "ssubscribe" + + channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)] + assert r.pubsub_shardnumsub("foo", "bar", "baz", target_nodes="all") == channels + class TestPubSubPings: @skip_if_server_version_lt("3.0.0") From dae5d0ee3dd80b9a722121378d06448d2cb51c2f Mon Sep 17 00:00:00 2001 From: dvora-h Date: Wed, 24 May 2023 12:48:34 +0300 Subject: [PATCH 08/10] fix @leibale comments --- redis/client.py | 21 +++++++++------------ redis/parsers/socket.py | 6 +++--- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/redis/client.py b/redis/client.py index 82028a1a53..e4b42c3fff 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1599,7 +1599,7 @@ def _execute(self, conn, command, *args, **kwargs): lambda error: self._disconnect_raise_connect(conn, error), ) - def parse_response(self, block=True, timeout=0, **kwargs): + def parse_response(self, block=True, timeout=0): """Parse the response from a publish/subscribe command""" conn = self.connection if conn is None: @@ -1612,8 +1612,7 @@ def parse_response(self, block=True, timeout=0, **kwargs): def try_read(): if not block: - can_read = conn.can_read(timeout=timeout) - if not can_read: + if not conn.can_read(timeout=timeout): return None else: conn.connect() @@ -1740,12 +1739,11 @@ def unsubscribe(self, *args): def ssubscribe(self, *args, target_node=None, **kwargs): """ - # TODO: update docstring - Subscribe to channels. Channels supplied as keyword arguments expect - a channel name as the key and a callable as the value. A channel's - callable will be invoked automatically when a message is received on - that channel rather than producing a message via ``listen()`` or - ``get_message()``. + Subscribes the client to the specified shard channels. + Channels supplied as keyword arguments expect a channel name as the key + and a callable as the value. A channel's callable will be invoked automatically + when a message is received on that channel rather than producing a message via + ``listen()`` or ``get_sharded_message()``. """ if args: args = list_or_args(args[0], args[1:]) @@ -1767,9 +1765,8 @@ def ssubscribe(self, *args, target_node=None, **kwargs): def sunsubscribe(self, *args, target_node=None): """ - # TODO: update docstring - Unsubscribe from the supplied channels. If empty, unsubscribe from - all channels + Unsubscribe from the supplied shard_channels. If empty, unsubscribe from + all shard_channels """ if args: args = list_or_args(args[0], args[1:]) diff --git a/redis/parsers/socket.py b/redis/parsers/socket.py index f1b2109cd4..8147243bba 100644 --- a/redis/parsers/socket.py +++ b/redis/parsers/socket.py @@ -92,9 +92,9 @@ def _read_from_socket( sock.settimeout(self.socket_timeout) def can_read(self, timeout: float) -> bool: - read = self._read_from_socket(timeout=timeout, raise_on_timeout=False) - _bytes = bool(self.unread_bytes()) - return _bytes or read + return bool(self.unread_bytes()) or self._read_from_socket( + timeout=timeout, raise_on_timeout=False + ) def read(self, length: int) -> bytes: length = length + 2 # make sure to read the \r\n terminator From 195f3d2087bedf0d3c2cf40d2367bccda05b2dc7 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Wed, 24 May 2023 12:51:21 +0300 Subject: [PATCH 09/10] linters --- redis/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/client.py b/redis/client.py index e4b42c3fff..e94e6e54a5 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1740,7 +1740,7 @@ def unsubscribe(self, *args): def ssubscribe(self, *args, target_node=None, **kwargs): """ Subscribes the client to the specified shard channels. - Channels supplied as keyword arguments expect a channel name as the key + Channels supplied as keyword arguments expect a channel name as the key and a callable as the value. A channel's callable will be invoked automatically when a message is received on that channel rather than producing a message via ``listen()`` or ``get_sharded_message()``. From 5bd425db367071b3d0c572c2c8b299ca538e6808 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Wed, 24 May 2023 15:38:27 +0300 Subject: [PATCH 10/10] fix @chayim comments --- redis/client.py | 23 ++++++++++++----------- redis/cluster.py | 4 +++- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/redis/client.py b/redis/client.py index e94e6e54a5..ef327b5922 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1497,10 +1497,10 @@ def reset(self): self.health_check_response_counter = 0 self.channels = {} self.pending_unsubscribe_channels = set() - self.patterns = {} - self.pending_unsubscribe_patterns = set() self.shard_channels = {} self.pending_unsubscribe_shard_channels = set() + self.patterns = {} + self.pending_unsubscribe_patterns = set() self.subscribed_event.clear() def close(self): @@ -1515,19 +1515,20 @@ def on_connect(self, connection): self.pending_unsubscribe_patterns.clear() self.pending_unsubscribe_shard_channels.clear() if self.channels: - channels = {} - for k, v in self.channels.items(): - channels[self.encoder.decode(k, force=True)] = v + channels = { + self.encoder.decode(k, force=True): v for k, v in self.channels.items() + } self.subscribe(**channels) if self.patterns: - patterns = {} - for k, v in self.patterns.items(): - patterns[self.encoder.decode(k, force=True)] = v + patterns = { + self.encoder.decode(k, force=True): v for k, v in self.patterns.items() + } self.psubscribe(**patterns) if self.shard_channels: - shard_channels = {} - for k, v in self.shard_channels.items(): - shard_channels[self.encoder.decode(k, force=True)] = v + shard_channels = { + self.encoder.decode(k, force=True): v + for k, v in self.shard_channels.items() + } self.ssubscribe(**shard_channels) @property diff --git a/redis/cluster.py b/redis/cluster.py index da238b8c54..d3956e45f5 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -223,7 +223,9 @@ class AbstractRedisCluster: "PUBSUB CHANNELS", "PUBSUB NUMPAT", "PUBSUB NUMSUB", - "PUBSUB SHARDCHANNELS" "PUBSUB SHARDNUMSUB" "PING", + "PUBSUB SHARDCHANNELS", + "PUBSUB SHARDNUMSUB", + "PING", "INFO", "SHUTDOWN", "KEYS",