Fix bug where Sliding Sync could get stuck when using workers #17438

Merged · 10 commits · Jul 15, 2024

1 change: 1 addition & 0 deletions changelog.d/17438.bugfix
@@ -0,0 +1 @@
Fix rare bug where `/sync` would break for a user when using workers with multiple stream writers.
11 changes: 9 additions & 2 deletions synapse/handlers/sliding_sync.py
@@ -640,10 +640,17 @@ async def get_sync_room_ids_for_user(
             instance_to_max_stream_ordering_map[instance_name] = stream_ordering

         # Then assemble the `RoomStreamToken`
+        min_stream_pos = min(instance_to_max_stream_ordering_map.values())
         membership_snapshot_token = RoomStreamToken(
-            # Minimum position in the `instance_map`
-            stream=min(instance_to_max_stream_ordering_map.values()),
-            instance_map=immutabledict(instance_to_max_stream_ordering_map),
Comment on lines -645 to -646

@MadLittleMods (Contributor), Jul 15, 2024:

I'm confused how this previous code could result in an `instance_map` that "contained entries from before the minimum token." We're taking the minimum of that map and then using the same map.

I could see how it could include entries at the minimum `stream_ordering` (which would be unnecessary), but not entries from before it.

Did this spot actually cause bugs?

Member Author:

Yeah: if all the instances are equal to the minimum it will cause a problem. The exact bug is that during serialization we filtered out any instance whose value was equal to the minimum, and if that filtered out every instance we wrote an invalid token (e.g. `m54~` instead of `s54`).

+            stream=min_stream_pos,
+            instance_map=immutabledict(
+                {
+                    instance_name: stream_pos
+                    for instance_name, stream_pos in instance_to_max_stream_ordering_map.items()
+                    if stream_pos > min_stream_pos
+                }
+            ),
         )

         # Since we fetched the users room list at some point in time after the from/to
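
The failure mode described in the thread above can be reproduced in a few lines. The sketch below is illustrative, not actual Synapse code: `serialize` is a hypothetical stand-in for the old `RoomStreamToken.to_string` logic, with writer names already mapped to integer instance ids.

    # Hypothetical stand-in for the pre-fix serialization logic.
    def serialize(stream: int, instance_map: dict) -> str:
        # Writers at or below the main stream position were skipped while
        # building the entries...
        entries = [f"{i}.{pos}" for i, pos in instance_map.items() if pos > stream]
        # ...but the "m" form was emitted whenever instance_map was non-empty,
        # even if the filter had just removed every entry.
        return f"m{stream}~" + "~".join(entries)

    # Every writer is level with the minimum, so all entries are filtered out
    # and the malformed token "m54~" is written instead of "s54".
    print(serialize(54, {1: 54, 2: 54}))  # -> "m54~"

Filtering redundant entries at the call site (above) and rejecting them at construction time (the strict check added below) make that state unrepresentable.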
65 changes: 57 additions & 8 deletions synapse/types/__init__.py
@@ -20,6 +20,7 @@
#
#
import abc
import logging
import re
import string
from enum import Enum
@@ -74,6 +75,9 @@
from synapse.storage.databases.main import DataStore, PurgeEventsStore
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore


logger = logging.getLogger(__name__)

# Define a state map type from type/state_key to T (usually an event ID or
# event)
T = TypeVar("T")
@@ -454,6 +458,8 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
    represented by a default `stream` attribute and a map of instance name to
    stream position of any writers that are ahead of the default stream
    position.

    The values in `instance_map` must be greater than the `stream` attribute.
    """

    stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
@@ -468,6 +474,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
        kw_only=True,
    )

    def __attrs_post_init__(self) -> None:
        # Enforce that all instances have a value greater than the min stream
Comment on the line above (Member):

greater than or equal to*

likewise in `test_instance_map_assertion`

Member Author:

Actually, it's probably better to do a strict check.
        # position.
        for i, v in self.instance_map.items():
            if v <= self.stream:
                raise ValueError(
                    f"'instance_map' includes a stream position before the main 'stream' attribute. Instance: {i}"
                )

    @classmethod
    @abc.abstractmethod
    async def parse(cls, store: "DataStore", string: str) -> "Self":
@@ -494,6 +509,9 @@ def copy_and_advance(self, other: "Self") -> "Self":
            for instance in set(self.instance_map).union(other.instance_map)
        }

        # Filter out any redundant entries.
        instance_map = {i: s for i, s in instance_map.items() if s > max_stream}
Member Author:

This is new; I forgot to unstash the change. It's now important since we do strict checks.

        return attr.evolve(
            self, stream=max_stream, instance_map=immutabledict(instance_map)
        )
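
A worked example of the new filter with illustrative values: after taking the per-writer maxima of the two tokens, any writer at or below the advanced default stream position carries no extra information, so it is dropped before the token is rebuilt.

    # Illustrative values only; this mirrors the filter line added above.
    max_stream = max(5, 7)  # advanced default stream position -> 7
    merged = {"A": 8, "B": 7}  # per-writer maxima from the union step
    instance_map = {i: s for i, s in merged.items() if s > max_stream}
    assert instance_map == {"A": 8}  # "B" is implied by max_stream

Without this filter, the strict `__attrs_post_init__` check above would reject the evolved token.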
@@ -539,10 +557,15 @@ def is_before_or_eq(self, other_token: Self) -> bool:
     def bound_stream_token(self, max_stream: int) -> "Self":
         """Bound the stream positions to a maximum value"""

+        min_pos = min(self.stream, max_stream)
         return type(self)(
-            stream=min(self.stream, max_stream),
+            stream=min_pos,
             instance_map=immutabledict(
-                {k: min(s, max_stream) for k, s in self.instance_map.items()}
+                {
+                    k: min(s, max_stream)
+                    for k, s in self.instance_map.items()
+                    if min(s, max_stream) > min_pos
+                }
             ),
         )

@@ -637,6 +660,8 @@ def __attrs_post_init__(self) -> None:
"Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
)

super().__attrs_post_init__()

@classmethod
async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
try:
@@ -651,6 +676,11 @@ async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken

                instance_map = {}
                for part in parts[1:]:
                    if not part:
                        # Handle tokens of the form `m5~`, which were created by
                        # a bug
                        continue

                    key, value = part.split(".")
                    instance_id = int(key)
                    pos = int(value)
@@ -666,7 +696,10 @@ async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken
         except CancelledError:
             raise
         except Exception:
-            pass
+            # We log an exception here as even though this *might* be a client
+            # handing us a bad token, it's more likely that Synapse returned a
+            # bad token (and we really want to catch those!).
+            logger.exception("Failed to parse stream token: %r", string)
         raise SynapseError(400, "Invalid room stream token %r" % (string,))

    @classmethod
@@ -713,6 +746,8 @@ def get_stream_pos_for_instance(self, instance_name: str) -> int:
        return self.instance_map.get(instance_name, self.stream)

    async def to_string(self, store: "DataStore") -> str:
        """See class level docstring for information about the format."""

        if self.topological is not None:
            return "t%d-%d" % (self.topological, self.stream)
        elif self.instance_map:
@@ -727,8 +762,10 @@ async def to_string(self, store: "DataStore") -> str:
                 instance_id = await store.get_id_for_instance(name)
                 entries.append(f"{instance_id}.{pos}")

-            encoded_map = "~".join(entries)
-            return f"m{self.stream}~{encoded_map}"
+            if entries:
+                encoded_map = "~".join(entries)
+                return f"m{self.stream}~{encoded_map}"
+            return f"s{self.stream}"
         else:
             return "s%d" % (self.stream,)

@@ -756,6 +793,11 @@ async def parse(cls, store: "DataStore", string: str) -> "MultiWriterStreamToken

                instance_map = {}
                for part in parts[1:]:
                    if not part:
                        # Handle tokens of the form `m5~`, which were created by
                        # a bug
                        continue

                    key, value = part.split(".")
                    instance_id = int(key)
                    pos = int(value)
@@ -770,10 +812,15 @@ async def parse(cls, store: "DataStore", string: str) -> "MultiWriterStreamToken
         except CancelledError:
             raise
         except Exception:
-            pass
+            # We log an exception here as even though this *might* be a client
+            # handing us a bad token, it's more likely that Synapse returned a
+            # bad token (and we really want to catch those!).
+            logger.exception("Failed to parse stream token: %r", string)
         raise SynapseError(400, "Invalid stream token %r" % (string,))

    async def to_string(self, store: "DataStore") -> str:
        """See class level docstring for information about the format."""

        if self.instance_map:
            entries = []
            for name, pos in self.instance_map.items():
@@ -786,8 +833,10 @@ async def to_string(self, store: "DataStore") -> str:
                 instance_id = await store.get_id_for_instance(name)
                 entries.append(f"{instance_id}.{pos}")

-            encoded_map = "~".join(entries)
-            return f"m{self.stream}~{encoded_map}"
+            if entries:
+                encoded_map = "~".join(entries)
+                return f"m{self.stream}~{encoded_map}"
+            return str(self.stream)
         else:
             return str(self.stream)

71 changes: 71 additions & 0 deletions tests/test_types.py
@@ -19,16 +19,26 @@
#
#

from typing import Type
from unittest import skipUnless

from immutabledict import immutabledict
from parameterized import parameterized_class

from synapse.api.errors import SynapseError
from synapse.types import (
    AbstractMultiWriterStreamToken,
    MultiWriterStreamToken,
    RoomAlias,
    RoomStreamToken,
    UserID,
    get_domain_from_id,
    get_localpart_from_id,
    map_username_to_mxid_localpart,
)

from tests import unittest
from tests.utils import USE_POSTGRES_FOR_TESTS


class IsMineIDTests(unittest.HomeserverTestCase):
@@ -127,3 +137,64 @@ def test_non_ascii(self) -> None:
        # this should work with either a unicode or a bytes
        self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
        self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")


@parameterized_class(
    ("token_type",),
    [
        (MultiWriterStreamToken,),
        (RoomStreamToken,),
    ],
    class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{params_dict['token_type'].__name__}",
)
class MultiWriterTokenTestCase(unittest.HomeserverTestCase):
    """Tests for the different types of multi writer tokens."""

    token_type: Type[AbstractMultiWriterStreamToken]

    def test_basic_token(self) -> None:
        """Test that a simple stream token can be serialized and unserialized"""
        store = self.hs.get_datastores().main

        token = self.token_type(stream=5)

        string_token = self.get_success(token.to_string(store))

        if isinstance(token, RoomStreamToken):
            self.assertEqual(string_token, "s5")
        else:
            self.assertEqual(string_token, "5")

        parsed_token = self.get_success(self.token_type.parse(store, string_token))
        self.assertEqual(parsed_token, token)

    @skipUnless(USE_POSTGRES_FOR_TESTS, "Requires Postgres")
    def test_instance_map(self) -> None:
        """Test for stream token with instance map"""
        store = self.hs.get_datastores().main

        token = self.token_type(stream=5, instance_map=immutabledict({"foo": 6}))

        string_token = self.get_success(token.to_string(store))
        self.assertEqual(string_token, "m5~1.6")

        parsed_token = self.get_success(self.token_type.parse(store, string_token))
        self.assertEqual(parsed_token, token)

    def test_instance_map_assertion(self) -> None:
        """Test that we assert values in the instance map are greater than the
        min stream position"""

        with self.assertRaises(ValueError):
            self.token_type(stream=5, instance_map=immutabledict({"foo": 4}))

        with self.assertRaises(ValueError):
            self.token_type(stream=5, instance_map=immutabledict({"foo": 5}))

    def test_parse_bad_token(self) -> None:
        """Test that we can parse tokens produced by a bug in Synapse of the
        form `m5~`"""
        store = self.hs.get_datastores().main

        parsed_token = self.get_success(self.token_type.parse(store, "m5~"))
        self.assertEqual(parsed_token, self.token_type(stream=5))