|
1 | 1 | # Copyright 2017 Vector Creations Ltd
|
2 |
| -# Copyright 2020 The Matrix.org Foundation C.I.C. |
| 2 | +# Copyright 2020, 2022 The Matrix.org Foundation C.I.C. |
3 | 3 | #
|
4 | 4 | # Licensed under the Apache License, Version 2.0 (the "License");
|
5 | 5 | # you may not use this file except in compliance with the License.
|
@@ -101,6 +101,9 @@ def __init__(self, hs: "HomeServer"):
|
101 | 101 | self._instance_id = hs.get_instance_id()
|
102 | 102 | self._instance_name = hs.get_instance_name()
|
103 | 103 |
|
| 104 | + # Additional Redis channel suffixes to subscribe to. |
| 105 | + self._channels_to_subscribe_to: List[str] = [] |
| 106 | + |
104 | 107 | self._is_presence_writer = (
|
105 | 108 | hs.get_instance_name() in hs.config.worker.writers.presence
|
106 | 109 | )
|
@@ -243,6 +246,31 @@ def __init__(self, hs: "HomeServer"):
|
243 | 246 | # If we're NOT using Redis, this must be handled by the master
|
244 | 247 | self._should_insert_client_ips = hs.get_instance_name() == "master"
|
245 | 248 |
|
| 249 | + if self._is_master or self._should_insert_client_ips: |
| 250 | + self.subscribe_to_channel("USER_IP") |
| 251 | + |
| 252 | + def subscribe_to_channel(self, channel_name: str) -> None: |
| 253 | + """ |
| 254 | + Indicates that we wish to subscribe to a Redis channel by name. |
| 255 | +
|
| 256 | + (The name will later be prefixed with the server name; i.e. subscribing |
| 257 | + to the 'ABC' channel actually subscribes to 'example.com/ABC' Redis-side.) |
| 258 | +
|
| 259 | + Raises: |
| 260 | + - If replication has already started, then it's too late to subscribe |
| 261 | + to new channels. |
| 262 | + """ |
| 263 | + |
| 264 | + if self._factory is not None: |
| 265 | + # We don't allow subscribing after the fact to avoid the chance |
| 266 | + # of missing an important message because we didn't subscribe in time. |
| 267 | + raise RuntimeError( |
| 268 | + "Cannot subscribe to more channels after replication started." |
| 269 | + ) |
| 270 | + |
| 271 | + if channel_name not in self._channels_to_subscribe_to: |
| 272 | + self._channels_to_subscribe_to.append(channel_name) |
| 273 | + |
246 | 274 | def _add_command_to_stream_queue(
|
247 | 275 | self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
|
248 | 276 | ) -> None:
|
@@ -323,9 +351,7 @@ def start_replication(self, hs: "HomeServer") -> None:
|
323 | 351 | self._factory = RedisDirectTcpReplicationClientFactory(
|
324 | 352 | hs,
|
325 | 353 | outbound_redis_connection,
|
326 |
| - channel_names=RedisDirectTcpReplicationClientFactory.channels_to_subscribe_to_for_config( |
327 |
| - hs.config |
328 |
| - ), |
| 354 | + channel_names=self._channels_to_subscribe_to, |
329 | 355 | )
|
330 | 356 | hs.get_reactor().connectTCP(
|
331 | 357 | hs.config.redis.redis_host,
|
|
0 commit comments