Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 177b884

Browse files
authored
Lay some foundation work to allow workers to only subscribe to some kinds of messages, reducing replication traffic. (#12672)
1 parent eb4aaa1 commit 177b884

File tree

5 files changed

+173
-24
lines changed

5 files changed

+173
-24
lines changed

changelog.d/12672.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Lay some foundation work to allow workers to only subscribe to some kinds of messages, reducing replication traffic.

synapse/replication/tcp/handler.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Copyright 2017 Vector Creations Ltd
2-
# Copyright 2020 The Matrix.org Foundation C.I.C.
2+
# Copyright 2020, 2022 The Matrix.org Foundation C.I.C.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -101,6 +101,9 @@ def __init__(self, hs: "HomeServer"):
101101
self._instance_id = hs.get_instance_id()
102102
self._instance_name = hs.get_instance_name()
103103

104+
# Additional Redis channel suffixes to subscribe to.
105+
self._channels_to_subscribe_to: List[str] = []
106+
104107
self._is_presence_writer = (
105108
hs.get_instance_name() in hs.config.worker.writers.presence
106109
)
@@ -243,6 +246,31 @@ def __init__(self, hs: "HomeServer"):
243246
# If we're NOT using Redis, this must be handled by the master
244247
self._should_insert_client_ips = hs.get_instance_name() == "master"
245248

249+
if self._is_master or self._should_insert_client_ips:
250+
self.subscribe_to_channel("USER_IP")
251+
252+
def subscribe_to_channel(self, channel_name: str) -> None:
253+
"""
254+
Indicates that we wish to subscribe to a Redis channel by name.
255+
256+
(The name will later be prefixed with the server name; i.e. subscribing
257+
to the 'ABC' channel actually subscribes to 'example.com/ABC' Redis-side.)
258+
259+
Raises:
260+
- If replication has already started, then it's too late to subscribe
261+
to new channels.
262+
"""
263+
264+
if self._factory is not None:
265+
# We don't allow subscribing after the fact to avoid the chance
266+
# of missing an important message because we didn't subscribe in time.
267+
raise RuntimeError(
268+
"Cannot subscribe to more channels after replication started."
269+
)
270+
271+
if channel_name not in self._channels_to_subscribe_to:
272+
self._channels_to_subscribe_to.append(channel_name)
273+
246274
def _add_command_to_stream_queue(
247275
self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
248276
) -> None:
@@ -321,7 +349,9 @@ def start_replication(self, hs: "HomeServer") -> None:
321349

322350
# Now create the factory/connection for the subscription stream.
323351
self._factory = RedisDirectTcpReplicationClientFactory(
324-
hs, outbound_redis_connection
352+
hs,
353+
outbound_redis_connection,
354+
channel_names=self._channels_to_subscribe_to,
325355
)
326356
hs.get_reactor().connectTCP(
327357
hs.config.redis.redis_host,

synapse/replication/tcp/redis.py

+25-10
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import logging
1616
from inspect import isawaitable
17-
from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, cast
17+
from typing import TYPE_CHECKING, Any, Generic, List, Optional, Type, TypeVar, cast
1818

1919
import attr
2020
import txredisapi
@@ -85,14 +85,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
8585
8686
Attributes:
8787
synapse_handler: The command handler to handle incoming commands.
88-
synapse_stream_name: The *redis* stream name to subscribe to and publish
88+
synapse_stream_prefix: The *redis* stream name to subscribe to and publish
8989
from (not anything to do with Synapse replication streams).
9090
synapse_outbound_redis_connection: The connection to redis to use to send
9191
commands.
9292
"""
9393

9494
synapse_handler: "ReplicationCommandHandler"
95-
synapse_stream_name: str
95+
synapse_stream_prefix: str
96+
synapse_channel_names: List[str]
9697
synapse_outbound_redis_connection: txredisapi.ConnectionHandler
9798

9899
def __init__(self, *args: Any, **kwargs: Any):
@@ -117,8 +118,13 @@ async def _send_subscribe(self) -> None:
117118
# it's important to make sure that we only send the REPLICATE command once we
118119
# have successfully subscribed to the stream - otherwise we might miss the
119120
# POSITION response sent back by the other end.
120-
logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
121-
await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
121+
fully_qualified_stream_names = [
122+
f"{self.synapse_stream_prefix}/{stream_suffix}"
123+
for stream_suffix in self.synapse_channel_names
124+
] + [self.synapse_stream_prefix]
125+
logger.info("Sending redis SUBSCRIBE for %r", fully_qualified_stream_names)
126+
await make_deferred_yieldable(self.subscribe(fully_qualified_stream_names))
127+
122128
logger.info(
123129
"Successfully subscribed to redis stream, sending REPLICATE command"
124130
)
@@ -217,7 +223,7 @@ async def _async_send_command(self, cmd: Command) -> None:
217223

218224
await make_deferred_yieldable(
219225
self.synapse_outbound_redis_connection.publish(
220-
self.synapse_stream_name, encoded_string
226+
self.synapse_stream_prefix, encoded_string
221227
)
222228
)
223229

@@ -300,20 +306,27 @@ def format_address(address: IAddress) -> str:
300306

301307
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
302308
"""This is a reconnecting factory that connects to redis and immediately
303-
subscribes to a stream.
309+
subscribes to some streams.
304310
305311
Args:
306312
hs
307313
outbound_redis_connection: A connection to redis that will be used to
308314
send outbound commands (this is separate to the redis connection
309315
used to subscribe).
316+
channel_names: A list of channel names to append to the base channel name
317+
to additionally subscribe to.
318+
e.g. if ['ABC', 'DEF'] is specified then we'll listen to:
319+
example.com; example.com/ABC; and example.com/DEF.
310320
"""
311321

312322
maxDelay = 5
313323
protocol = RedisSubscriber
314324

315325
def __init__(
316-
self, hs: "HomeServer", outbound_redis_connection: txredisapi.ConnectionHandler
326+
self,
327+
hs: "HomeServer",
328+
outbound_redis_connection: txredisapi.ConnectionHandler,
329+
channel_names: List[str],
317330
):
318331

319332
super().__init__(
@@ -326,7 +339,8 @@ def __init__(
326339
)
327340

328341
self.synapse_handler = hs.get_replication_command_handler()
329-
self.synapse_stream_name = hs.hostname
342+
self.synapse_stream_prefix = hs.hostname
343+
self.synapse_channel_names = channel_names
330344

331345
self.synapse_outbound_redis_connection = outbound_redis_connection
332346

@@ -340,7 +354,8 @@ def buildProtocol(self, addr: IAddress) -> RedisSubscriber:
340354
# protocol.
341355
p.synapse_handler = self.synapse_handler
342356
p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
343-
p.synapse_stream_name = self.synapse_stream_name
357+
p.synapse_stream_prefix = self.synapse_stream_prefix
358+
p.synapse_channel_names = self.synapse_channel_names
344359

345360
return p
346361

tests/replication/_base.py

+42-12
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15-
from typing import Any, Dict, List, Optional, Tuple
15+
from collections import defaultdict
16+
from typing import Any, Dict, List, Optional, Set, Tuple
1617

1718
from twisted.internet.address import IPv4Address
1819
from twisted.internet.protocol import Protocol
@@ -32,6 +33,7 @@
3233

3334
from tests import unittest
3435
from tests.server import FakeTransport
36+
from tests.utils import USE_POSTGRES_FOR_TESTS
3537

3638
try:
3739
import hiredis
@@ -475,22 +477,25 @@ class FakeRedisPubSubServer:
475477
"""A fake Redis server for pub/sub."""
476478

477479
def __init__(self):
478-
self._subscribers = set()
480+
self._subscribers_by_channel: Dict[
481+
bytes, Set["FakeRedisPubSubProtocol"]
482+
] = defaultdict(set)
479483

480-
def add_subscriber(self, conn):
484+
def add_subscriber(self, conn, channel: bytes):
481485
"""A connection has called SUBSCRIBE"""
482-
self._subscribers.add(conn)
486+
self._subscribers_by_channel[channel].add(conn)
483487

484488
def remove_subscriber(self, conn):
485-
"""A connection has called UNSUBSCRIBE"""
486-
self._subscribers.discard(conn)
489+
"""A connection has lost connection"""
490+
for subscribers in self._subscribers_by_channel.values():
491+
subscribers.discard(conn)
487492

488-
def publish(self, conn, channel, msg) -> int:
493+
def publish(self, conn, channel: bytes, msg) -> int:
489494
"""A connection want to publish a message to subscribers."""
490-
for sub in self._subscribers:
495+
for sub in self._subscribers_by_channel[channel]:
491496
sub.send(["message", channel, msg])
492497

493-
return len(self._subscribers)
498+
return len(self._subscribers_by_channel)
494499

495500
def buildProtocol(self, addr):
496501
return FakeRedisPubSubProtocol(self)
@@ -531,9 +536,10 @@ def handle_command(self, command, *args):
531536
num_subscribers = self._server.publish(self, channel, message)
532537
self.send(num_subscribers)
533538
elif command == b"SUBSCRIBE":
534-
(channel,) = args
535-
self._server.add_subscriber(self)
536-
self.send(["subscribe", channel, 1])
539+
for idx, channel in enumerate(args):
540+
num_channels = idx + 1
541+
self._server.add_subscriber(self, channel)
542+
self.send(["subscribe", channel, num_channels])
537543

538544
# Since we use SET/GET to cache things we can safely no-op them.
539545
elif command == b"SET":
@@ -576,3 +582,27 @@ def encode(self, obj):
576582

577583
def connectionLost(self, reason):
578584
self._server.remove_subscriber(self)
585+
586+
587+
class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase):
588+
"""
589+
A test case that enables Redis, providing a fake Redis server.
590+
"""
591+
592+
if not hiredis:
593+
skip = "Requires hiredis"
594+
595+
if not USE_POSTGRES_FOR_TESTS:
596+
# Redis replication only takes place on Postgres
597+
skip = "Requires Postgres"
598+
599+
def default_config(self) -> Dict[str, Any]:
600+
"""
601+
Overrides the default config to enable Redis.
602+
Even if the test only uses make_worker_hs, the main process needs Redis
603+
enabled otherwise it won't create a Fake Redis server to listen on the
604+
Redis port and accept fake TCP connections.
605+
"""
606+
base = super().default_config()
607+
base["redis"] = {"enabled": True}
608+
return base

tests/replication/tcp/test_handler.py

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2022 The Matrix.org Foundation C.I.C.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from tests.replication._base import RedisMultiWorkerStreamTestCase
16+
17+
18+
class ChannelsTestCase(RedisMultiWorkerStreamTestCase):
19+
def test_subscribed_to_enough_redis_channels(self) -> None:
20+
# The default main process is subscribed to the USER_IP channel.
21+
self.assertCountEqual(
22+
self.hs.get_replication_command_handler()._channels_to_subscribe_to,
23+
["USER_IP"],
24+
)
25+
26+
def test_background_worker_subscribed_to_user_ip(self) -> None:
27+
# The default main process is subscribed to the USER_IP channel.
28+
worker1 = self.make_worker_hs(
29+
"synapse.app.generic_worker",
30+
extra_config={
31+
"worker_name": "worker1",
32+
"run_background_tasks_on": "worker1",
33+
"redis": {"enabled": True},
34+
},
35+
)
36+
self.assertIn(
37+
"USER_IP",
38+
worker1.get_replication_command_handler()._channels_to_subscribe_to,
39+
)
40+
41+
# Advance so the Redis subscription gets processed
42+
self.pump(0.1)
43+
44+
# The counts are 2 because both the main process and the worker are subscribed.
45+
self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2)
46+
self.assertEqual(
47+
len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 2
48+
)
49+
50+
def test_non_background_worker_not_subscribed_to_user_ip(self) -> None:
51+
# The default main process is subscribed to the USER_IP channel.
52+
worker2 = self.make_worker_hs(
53+
"synapse.app.generic_worker",
54+
extra_config={
55+
"worker_name": "worker2",
56+
"run_background_tasks_on": "worker1",
57+
"redis": {"enabled": True},
58+
},
59+
)
60+
self.assertNotIn(
61+
"USER_IP",
62+
worker2.get_replication_command_handler()._channels_to_subscribe_to,
63+
)
64+
65+
# Advance so the Redis subscription gets processed
66+
self.pump(0.1)
67+
68+
# The count is 2 because both the main process and the worker are subscribed.
69+
self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2)
70+
# For USER_IP, the count is 1 because only the main process is subscribed.
71+
self.assertEqual(
72+
len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 1
73+
)

0 commit comments

Comments
 (0)