4
4
# License, v. 2.0. If a copy of the MPL was not distributed with this
5
5
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
6
6
# Partly based on github.com/Cadair/python-appservice-framework (MIT license)
7
- from typing import Optional , Callable , Awaitable , List , Set
7
+ from typing import Optional , Callable , Awaitable , List , Set , Dict , Any
8
8
from json import JSONDecodeError
9
9
from aiohttp import web
10
10
import asyncio
11
11
import logging
12
12
13
- from mautrix .types import JSON , UserID , RoomAlias , Event , EphemeralEvent , SerializerError
13
+ from mautrix .types import (JSON , UserID , RoomAlias , Event , EphemeralEvent , SerializerError ,
14
+ DeviceOTKCount , DeviceLists )
14
15
15
16
QueryFunc = Callable [[web .Request ], Awaitable [Optional [web .Response ]]]
16
17
HandlerFunc = Callable [[Event ], Awaitable ]
@@ -102,6 +103,17 @@ async def _http_query_alias(self, request: web.Request) -> web.Response:
102
103
return web .json_response ({}, status = 404 )
103
104
return web .json_response (response )
104
105
106
+ @staticmethod
107
+ def _get_with_fallback (json : Dict [str , Any ], field : str , unstable_prefix : str ,
108
+ default : Any = None ) -> Any :
109
+ try :
110
+ return json .pop (field )
111
+ except KeyError :
112
+ try :
113
+ return json .pop (f"{ unstable_prefix } .{ field } " )
114
+ except KeyError :
115
+ return default
116
+
105
117
async def _http_handle_transaction (self , request : web .Request ) -> web .Response :
106
118
if not self ._check_token (request ):
107
119
return web .json_response ({"error" : "Invalid auth token" }, status = 401 )
@@ -116,29 +128,30 @@ async def _http_handle_transaction(self, request: web.Request) -> web.Response:
116
128
return web .json_response ({"error" : "Body is not JSON" }, status = 400 )
117
129
118
130
try :
119
- events = json [ "events" ]
131
+ events = json . pop ( "events" )
120
132
except KeyError :
121
133
return web .json_response ({"error" : "Missing events object in body" }, status = 400 )
122
134
123
- if self .ephemeral_events :
124
- try :
125
- ephemeral = json ["ephemeral" ]
126
- except KeyError :
127
- try :
128
- ephemeral = json ["de.sorunome.msc2409.ephemeral" ]
129
- except KeyError :
130
- ephemeral = None
131
- else :
132
- ephemeral = None
135
+ ephemeral = (self ._get_with_fallback (json , "ephemeral" , "de.sorunome.msc2409" )
136
+ if self .ephemeral_events else None )
137
+ device_lists = DeviceLists .deserialize (
138
+ self ._get_with_fallback (json , "device_lists" , "org.matrix.msc3202" ))
139
+ otk_counts = {user_id : DeviceOTKCount .deserialize (count )
140
+ for user_id , count
141
+ in self ._get_with_fallback (json , "device_one_time_keys_count" ,
142
+ "org.matrix.msc3202" , default = {}).items ()}
133
143
134
144
try :
135
- await self .handle_transaction (transaction_id , events = events , ephemeral = ephemeral )
145
+ output = await self .handle_transaction (transaction_id , events = events , extra_data = json ,
146
+ ephemeral = ephemeral , device_lists = device_lists ,
147
+ device_otk_count = otk_counts )
136
148
except Exception :
137
149
self .log .exception ("Exception in transaction handler" )
150
+ output = None
138
151
139
152
self .transactions .add (transaction_id )
140
153
141
- return web .json_response ({})
154
+ return web .json_response (output or {})
142
155
143
156
@staticmethod
144
157
def _fix_prev_content (raw_event : JSON ) -> None :
@@ -150,8 +163,10 @@ def _fix_prev_content(raw_event: JSON) -> None:
150
163
except KeyError :
151
164
pass
152
165
153
- async def handle_transaction (self , txn_id : str , events : List [JSON ],
154
- ephemeral : Optional [List [JSON ]] = None ) -> None :
166
+ async def handle_transaction (self , txn_id : str , * , events : List [JSON ], extra_data : JSON ,
167
+ ephemeral : Optional [List [JSON ]] = None ,
168
+ device_otk_count : Optional [Dict [UserID , DeviceOTKCount ]] = None ,
169
+ device_lists : Optional [DeviceLists ] = None ) -> Optional [JSON ]:
155
170
for raw_edu in ephemeral or []:
156
171
try :
157
172
edu = EphemeralEvent .deserialize (raw_edu )
@@ -167,6 +182,7 @@ async def handle_transaction(self, txn_id: str, events: List[JSON],
167
182
self .log .exception ("Failed to deserialize event %s" , raw_event )
168
183
else :
169
184
self .handle_matrix_event (event )
185
+ return {}
170
186
171
187
def handle_matrix_event (self , event : Event ) -> None :
172
188
if event .type .is_state and event .state_key is None :
0 commit comments