Skip to content

Commit c41b515

Browse files
committed
Add support for MSC3202 in appservice module
1 parent e2ce035 commit c41b515

File tree

2 files changed

+44
-19
lines changed

2 files changed

+44
-19
lines changed

mautrix/appservice/as_handler.py

+33-17
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
# License, v. 2.0. If a copy of the MPL was not distributed with this
55
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
66
# Partly based on github.com/Cadair/python-appservice-framework (MIT license)
7-
from typing import Optional, Callable, Awaitable, List, Set
7+
from typing import Optional, Callable, Awaitable, List, Set, Dict, Any
88
from json import JSONDecodeError
99
from aiohttp import web
1010
import asyncio
1111
import logging
1212

13-
from mautrix.types import JSON, UserID, RoomAlias, Event, EphemeralEvent, SerializerError
13+
from mautrix.types import (JSON, UserID, RoomAlias, Event, EphemeralEvent, SerializerError,
14+
DeviceOTKCount, DeviceLists)
1415

1516
QueryFunc = Callable[[web.Request], Awaitable[Optional[web.Response]]]
1617
HandlerFunc = Callable[[Event], Awaitable]
@@ -102,6 +103,17 @@ async def _http_query_alias(self, request: web.Request) -> web.Response:
102103
return web.json_response({}, status=404)
103104
return web.json_response(response)
104105

106+
@staticmethod
107+
def _get_with_fallback(json: Dict[str, Any], field: str, unstable_prefix: str,
108+
default: Any = None) -> Any:
109+
try:
110+
return json.pop(field)
111+
except KeyError:
112+
try:
113+
return json.pop(f"{unstable_prefix}.{field}")
114+
except KeyError:
115+
return default
116+
105117
async def _http_handle_transaction(self, request: web.Request) -> web.Response:
106118
if not self._check_token(request):
107119
return web.json_response({"error": "Invalid auth token"}, status=401)
@@ -116,29 +128,30 @@ async def _http_handle_transaction(self, request: web.Request) -> web.Response:
116128
return web.json_response({"error": "Body is not JSON"}, status=400)
117129

118130
try:
119-
events = json["events"]
131+
events = json.pop("events")
120132
except KeyError:
121133
return web.json_response({"error": "Missing events object in body"}, status=400)
122134

123-
if self.ephemeral_events:
124-
try:
125-
ephemeral = json["ephemeral"]
126-
except KeyError:
127-
try:
128-
ephemeral = json["de.sorunome.msc2409.ephemeral"]
129-
except KeyError:
130-
ephemeral = None
131-
else:
132-
ephemeral = None
135+
ephemeral = (self._get_with_fallback(json, "ephemeral", "de.sorunome.msc2409")
136+
if self.ephemeral_events else None)
137+
device_lists = DeviceLists.deserialize(
138+
self._get_with_fallback(json, "device_lists", "org.matrix.msc3202"))
139+
otk_counts = {user_id: DeviceOTKCount.deserialize(count)
140+
for user_id, count
141+
in self._get_with_fallback(json, "device_one_time_keys_count",
142+
"org.matrix.msc3202", default={}).items()}
133143

134144
try:
135-
await self.handle_transaction(transaction_id, events=events, ephemeral=ephemeral)
145+
output = await self.handle_transaction(transaction_id, events=events, extra_data=json,
146+
ephemeral=ephemeral, device_lists=device_lists,
147+
device_otk_count=otk_counts)
136148
except Exception:
137149
self.log.exception("Exception in transaction handler")
150+
output = None
138151

139152
self.transactions.add(transaction_id)
140153

141-
return web.json_response({})
154+
return web.json_response(output or {})
142155

143156
@staticmethod
144157
def _fix_prev_content(raw_event: JSON) -> None:
@@ -150,8 +163,10 @@ def _fix_prev_content(raw_event: JSON) -> None:
150163
except KeyError:
151164
pass
152165

153-
async def handle_transaction(self, txn_id: str, events: List[JSON],
154-
ephemeral: Optional[List[JSON]] = None) -> None:
166+
async def handle_transaction(self, txn_id: str, *, events: List[JSON], extra_data: JSON,
167+
ephemeral: Optional[List[JSON]] = None,
168+
device_otk_count: Optional[Dict[UserID, DeviceOTKCount]] = None,
169+
device_lists: Optional[DeviceLists] = None) -> Optional[JSON]:
155170
for raw_edu in ephemeral or []:
156171
try:
157172
edu = EphemeralEvent.deserialize(raw_edu)
@@ -167,6 +182,7 @@ async def handle_transaction(self, txn_id: str, events: List[JSON],
167182
self.log.exception("Failed to deserialize event %s", raw_event)
168183
else:
169184
self.handle_matrix_event(event)
185+
return {}
170186

171187
def handle_matrix_event(self, event: Event) -> None:
172188
if event.type.is_state and event.state_key is None:

mautrix/types/misc.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,17 @@
1313
from .util import SerializableAttrs
1414
from .event import Event
1515

16-
DeviceLists = NamedTuple("DeviceLists", changed=List[UserID], left=List[UserID])
17-
DeviceOTKCount = NamedTuple("DeviceOTKCount", curve25519=int, signed_curve25519=int)
16+
17+
@dataclass
18+
class DeviceLists(SerializableAttrs):
19+
changed: List[UserID] = attr.ib(factory=lambda: [])
20+
left: List[UserID] = attr.ib(factory=lambda: [])
21+
22+
23+
@dataclass
24+
class DeviceOTKCount(SerializableAttrs):
25+
curve25519: int
26+
signed_curve25519: int
1827

1928

2029
class RoomCreatePreset(Enum):

0 commit comments

Comments
 (0)