Skip to content

Commit 72aa965

Browse files
committed
Add support for MSC3202 in appservice package
1 parent 18c5531 commit 72aa965

File tree

8 files changed

+176
-59
lines changed

8 files changed

+176
-59
lines changed

appservice/appservice.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import (
3333

3434
// EventChannelSize is the size for the Events channel in Appservice instances.
3535
var EventChannelSize = 64
36+
var OTKChannelSize = 4
3637

3738
// Create a blank appservice instance.
3839
func Create() *AppService {
@@ -83,20 +84,17 @@ type AppService struct {
8384
RegistrationPath string `yaml:"registration"`
8485
Host HostConfig `yaml:"host"`
8586
LogConfig LogConfig `yaml:"logging"`
86-
Sync struct {
87-
Enabled bool `yaml:"enabled"`
88-
FilterID string `yaml:"filter_id"`
89-
NextBatch string `yaml:"next_batch"`
90-
} `yaml:"sync"`
9187

9288
Registration *Registration `yaml:"-"`
9389
Log maulogger.Logger `yaml:"-"`
9490

9591
lastProcessedTransaction string
9692

97-
Events chan *event.Event `yaml:"-"`
98-
QueryHandler QueryHandler `yaml:"-"`
99-
StateStore StateStore `yaml:"-"`
93+
Events chan *event.Event `yaml:"-"`
94+
DeviceLists chan *mautrix.DeviceLists `yaml:"-"`
95+
OTKCounts chan *mautrix.OTKCount `yaml:"-"`
96+
QueryHandler QueryHandler `yaml:"-"`
97+
StateStore StateStore `yaml:"-"`
10098

10199
Router *mux.Router `yaml:"-"`
102100
UserAgent string `yaml:"-"`
@@ -246,6 +244,8 @@ func (as *AppService) BotClient() *mautrix.Client {
246244
// Init initializes the logger and loads the registration of this appservice.
247245
func (as *AppService) Init() (bool, error) {
248246
as.Events = make(chan *event.Event, EventChannelSize)
247+
as.OTKCounts = make(chan *mautrix.OTKCount, OTKChannelSize)
248+
as.DeviceLists = make(chan *mautrix.DeviceLists, EventChannelSize)
249249
as.QueryHandler = &QueryHandlerStub{}
250250

251251
if len(as.UserAgent) == 0 {

appservice/eventprocessor.go

+51-6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212

1313
log "maunium.net/go/maulogger/v2"
1414

15+
"maunium.net/go/mautrix"
1516
"maunium.net/go/mautrix/event"
1617
)
1718

@@ -24,6 +25,8 @@ const (
2425
)
2526

2627
type EventHandler func(evt *event.Event)
28+
type OTKHandler func(otk *mautrix.OTKCount)
29+
type DeviceListHandler func(otk *mautrix.DeviceLists, since string)
2730

2831
type EventProcessor struct {
2932
ExecMode ExecMode
@@ -32,6 +35,9 @@ type EventProcessor struct {
3235
log log.Logger
3336
stop chan struct{}
3437
handlers map[event.Type][]EventHandler
38+
39+
otkHandlers []OTKHandler
40+
deviceListHandlers []DeviceListHandler
3541
}
3642

3743
func NewEventProcessor(as *AppService) *EventProcessor {
@@ -41,6 +47,9 @@ func NewEventProcessor(as *AppService) *EventProcessor {
4147
log: as.Log.Sub("Events"),
4248
stop: make(chan struct{}, 1),
4349
handlers: make(map[event.Type][]EventHandler),
50+
51+
otkHandlers: make([]OTKHandler, 0),
52+
deviceListHandlers: make([]DeviceListHandler, 0),
4453
}
4554
}
4655

@@ -54,16 +63,48 @@ func (ep *EventProcessor) On(evtType event.Type, handler EventHandler) {
5463
ep.handlers[evtType] = handlers
5564
}
5665

66+
func (ep *EventProcessor) OnOTK(handler OTKHandler) {
67+
ep.otkHandlers = append(ep.otkHandlers, handler)
68+
}
69+
70+
func (ep *EventProcessor) OnDeviceList(handler DeviceListHandler) {
71+
ep.deviceListHandlers = append(ep.deviceListHandlers, handler)
72+
}
73+
74+
func (ep *EventProcessor) recoverFunc(data interface{}) {
75+
if err := recover(); err != nil {
76+
d, _ := json.Marshal(data)
77+
ep.log.Errorfln("Panic in Matrix event handler: %v (event content: %s):\n%s", err, string(d), string(debug.Stack()))
78+
}
79+
}
80+
5781
func (ep *EventProcessor) callHandler(handler EventHandler, evt *event.Event) {
58-
defer func() {
59-
if err := recover(); err != nil {
60-
d, _ := json.Marshal(evt)
61-
ep.log.Errorfln("Panic in Matrix event handler: %v (event content: %s):\n%s", err, string(d), string(debug.Stack()))
62-
}
63-
}()
82+
defer ep.recoverFunc(evt)
6483
handler(evt)
6584
}
6685

86+
func (ep *EventProcessor) callOTKHandler(handler OTKHandler, otk *mautrix.OTKCount) {
87+
defer ep.recoverFunc(otk)
88+
handler(otk)
89+
}
90+
91+
func (ep *EventProcessor) callDeviceListHandler(handler DeviceListHandler, dl *mautrix.DeviceLists) {
92+
defer ep.recoverFunc(dl)
93+
handler(dl, "")
94+
}
95+
96+
func (ep *EventProcessor) DispatchOTK(otk *mautrix.OTKCount) {
97+
for _, handler := range ep.otkHandlers {
98+
go ep.callOTKHandler(handler, otk)
99+
}
100+
}
101+
102+
func (ep *EventProcessor) DispatchDeviceList(dl *mautrix.DeviceLists) {
103+
for _, handler := range ep.deviceListHandlers {
104+
go ep.callDeviceListHandler(handler, dl)
105+
}
106+
}
107+
67108
func (ep *EventProcessor) Dispatch(evt *event.Event) {
68109
handlers, ok := ep.handlers[evt.Type]
69110
if !ok {
@@ -92,6 +133,10 @@ func (ep *EventProcessor) Start() {
92133
select {
93134
case evt := <-ep.as.Events:
94135
ep.Dispatch(evt)
136+
case otk := <-ep.as.OTKCounts:
137+
ep.DispatchOTK(otk)
138+
case dl := <-ep.as.DeviceLists:
139+
ep.DispatchDeviceList(dl)
95140
case <-ep.stop:
96141
return
97142
}

appservice/http.go

+44-12
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2020 Tulir Asokan
1+
// Copyright (c) 2021 Tulir Asokan
22
//
33
// This Source Code Form is subject to the terms of the Mozilla Public
44
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -17,11 +17,12 @@ import (
1717

1818
"github.com/gorilla/mux"
1919

20+
"maunium.net/go/mautrix"
2021
"maunium.net/go/mautrix/event"
2122
"maunium.net/go/mautrix/id"
2223
)
2324

24-
// Listen starts the HTTP server that listens for calls from the Matrix homeserver.
25+
// Start starts the HTTP server that listens for calls from the Matrix homeserver.
2526
func (as *AppService) Start() {
2627
as.Router.HandleFunc("/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut)
2728
as.Router.HandleFunc("/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet)
@@ -119,8 +120,8 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) {
119120
return
120121
}
121122

122-
eventList := EventList{}
123-
err = json.Unmarshal(body, &eventList)
123+
var txn Transaction
124+
err = json.Unmarshal(body, &txn)
124125
if err != nil {
125126
as.Log.Warnfln("Failed to parse JSON of transaction %s: %v", txnID, err)
126127
Error{
@@ -130,22 +131,53 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) {
130131
}.Write(w)
131132
} else {
132133
if as.Registration.EphemeralEvents {
133-
if eventList.EphemeralEvents != nil {
134-
as.handleEvents(eventList.EphemeralEvents, event.EphemeralEventType)
135-
} else if eventList.SoruEphemeralEvents != nil {
136-
as.handleEvents(eventList.SoruEphemeralEvents, event.EphemeralEventType)
134+
if txn.EphemeralEvents != nil {
135+
as.handleEvents(txn.EphemeralEvents, event.EphemeralEventType)
136+
} else if txn.MSC2409EphemeralEvents != nil {
137+
as.handleEvents(txn.MSC2409EphemeralEvents, event.EphemeralEventType)
137138
}
138139
}
139-
as.handleEvents(eventList.Events, event.UnknownEventType)
140+
as.handleEvents(txn.Events, event.UnknownEventType)
141+
if txn.DeviceLists != nil {
142+
as.handleDeviceLists(txn.DeviceLists)
143+
} else if txn.MSC3202DeviceLists != nil {
144+
as.handleDeviceLists(txn.MSC3202DeviceLists)
145+
}
146+
if txn.DeviceOTKCount != nil {
147+
as.handleOTKCounts(txn.DeviceOTKCount)
148+
} else if txn.MSC3202DeviceOTKCount != nil {
149+
as.handleOTKCounts(txn.MSC3202DeviceOTKCount)
150+
}
140151
WriteBlankOK(w)
141152
}
142153
as.lastProcessedTransaction = txnID
143154
}
144155

145-
func (as *AppService) handleEvents(evts []*event.Event, typeClass event.TypeClass) {
156+
func (as *AppService) handleOTKCounts(otks map[id.UserID]mautrix.OTKCount) {
157+
for userID, otkCounts := range otks {
158+
otkCounts.UserID = userID
159+
select {
160+
case as.OTKCounts <- &otkCounts:
161+
default:
162+
as.Log.Warnfln("Dropped OTK count update for %s because channel is full", userID)
163+
}
164+
}
165+
}
166+
167+
func (as *AppService) handleDeviceLists(dl *mautrix.DeviceLists) {
168+
select {
169+
case as.DeviceLists <- dl:
170+
default:
171+
as.Log.Warnln("Dropped device list update because channel is full")
172+
}
173+
}
174+
175+
func (as *AppService) handleEvents(evts []*event.Event, defaultTypeClass event.TypeClass) {
146176
for _, evt := range evts {
147-
if typeClass != event.UnknownEventType {
148-
evt.Type.Class = typeClass
177+
if len(evt.ToUserID) > 0 {
178+
evt.Type.Class = event.ToDeviceEventType
179+
} else if defaultTypeClass != event.UnknownEventType {
180+
evt.Type.Class = defaultTypeClass
149181
} else if evt.StateKey != nil {
150182
evt.Type.Class = event.StateEventType
151183
} else {

appservice/protocol.go

+13-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2020 Tulir Asokan
1+
// Copyright (c) 2021 Tulir Asokan
22
//
33
// This Source Code Form is subject to the terms of the Mozilla Public
44
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -10,14 +10,21 @@ import (
1010
"encoding/json"
1111
"net/http"
1212

13+
"maunium.net/go/mautrix"
1314
"maunium.net/go/mautrix/event"
15+
"maunium.net/go/mautrix/id"
1416
)
1517

16-
// EventList contains a list of events.
17-
type EventList struct {
18-
Events []*event.Event `json:"events"`
19-
EphemeralEvents []*event.Event `json:"ephemeral"`
20-
SoruEphemeralEvents []*event.Event `json:"de.sorunome.msc2409.ephemeral"`
18+
// Transaction contains a list of events.
19+
type Transaction struct {
20+
Events []*event.Event `json:"events"`
21+
EphemeralEvents []*event.Event `json:"ephemeral,omitempty"`
22+
DeviceLists *mautrix.DeviceLists `json:"device_lists,omitempty"`
23+
DeviceOTKCount map[id.UserID]mautrix.OTKCount `json:"device_one_time_keys_count,omitempty"`
24+
25+
MSC2409EphemeralEvents []*event.Event `json:"de.sorunome.msc2409.ephemeral,omitempty"`
26+
MSC3202DeviceLists *mautrix.DeviceLists `json:"org.matrix.msc3202.device_lists,omitempty"`
27+
MSC3202DeviceOTKCount map[id.UserID]mautrix.OTKCount `json:"org.matrix.msc3202.device_one_time_keys_count,omitempty"`
2128
}
2229

2330
// EventListener is a function that receives events.

appservice/websocket.go

+3-8
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,6 @@ import (
2020
"maunium.net/go/mautrix/event"
2121
)
2222

23-
type ErrorResponse struct {
24-
ErrorCode ErrorCode `json:"errcode"`
25-
Error string `json:"error"`
26-
}
27-
2823
type WebsocketCommand struct {
2924
ReqID int `json:"id,omitempty"`
3025
Command string `json:"command"`
@@ -34,7 +29,7 @@ type WebsocketCommand struct {
3429
type WebsocketTransaction struct {
3530
Status string `json:"status"`
3631
TxnID string `json:"txn_id"`
37-
EventList
32+
Transaction
3833
}
3934

4035
type WebsocketMessage struct {
@@ -150,12 +145,12 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
150145
"User-Agent": []string{as.BotClient().UserAgent},
151146
})
152147
if resp != nil && resp.StatusCode >= 400 {
153-
var errResp ErrorResponse
148+
var errResp Error
154149
err = json.NewDecoder(resp.Body).Decode(&errResp)
155150
if err != nil {
156151
return fmt.Errorf("websocket request returned HTTP %d with non-JSON body", resp.StatusCode)
157152
} else {
158-
return fmt.Errorf("websocket request returned %s (HTTP %d): %s", errResp.ErrorCode, resp.StatusCode, errResp.Error)
153+
return fmt.Errorf("websocket request returned %s (HTTP %d): %s", errResp.ErrorCode, resp.StatusCode, errResp.Message)
159154
}
160155
} else if err != nil {
161156
return fmt.Errorf("failed to open websocket: %w", err)

crypto/machine.go

+42-12
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"sync"
1313
"time"
1414

15+
"maunium.net/go/mautrix/appservice"
1516
"maunium.net/go/mautrix/crypto/ssss"
1617
"maunium.net/go/mautrix/id"
1718

@@ -153,16 +154,47 @@ func (mach *OlmMachine) OwnIdentity() *DeviceIdentity {
153154
}
154155
}
155156

157+
func (mach *OlmMachine) AddAppserviceListener(ep *appservice.EventProcessor, az *appservice.AppService) {
158+
// ToDeviceForwardedRoomKey and ToDeviceRoomKey should only be present inside encrypted to-device events
159+
ep.On(event.ToDeviceEncrypted, mach.HandleToDeviceEvent)
160+
ep.On(event.ToDeviceRoomKeyRequest, mach.HandleToDeviceEvent)
161+
ep.On(event.ToDeviceRoomKeyWithheld, mach.HandleToDeviceEvent)
162+
ep.On(event.ToDeviceOrgMatrixRoomKeyWithheld, mach.HandleToDeviceEvent)
163+
ep.On(event.ToDeviceVerificationRequest, mach.HandleToDeviceEvent)
164+
ep.On(event.ToDeviceVerificationStart, mach.HandleToDeviceEvent)
165+
ep.On(event.ToDeviceVerificationAccept, mach.HandleToDeviceEvent)
166+
ep.On(event.ToDeviceVerificationKey, mach.HandleToDeviceEvent)
167+
ep.On(event.ToDeviceVerificationMAC, mach.HandleToDeviceEvent)
168+
ep.On(event.ToDeviceVerificationCancel, mach.HandleToDeviceEvent)
169+
ep.OnOTK(mach.HandleOTKCounts)
170+
ep.OnDeviceList(mach.HandleDeviceLists)
171+
}
172+
173+
func (mach *OlmMachine) HandleDeviceLists(dl *mautrix.DeviceLists, since string) {
174+
if len(dl.Changed) > 0 {
175+
mach.Log.Trace("Device list changes in /sync: %v", dl.Changed)
176+
mach.fetchKeys(dl.Changed, since, false)
177+
}
178+
}
179+
180+
func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) {
181+
minCount := mach.account.Internal.MaxNumberOfOneTimeKeys() / 2
182+
if otkCount.SignedCurve25519 < int(minCount) {
183+
mach.Log.Debug("Sync response said we have %d signed curve25519 keys left, sharing new ones...", otkCount.SignedCurve25519)
184+
err := mach.ShareKeys(otkCount.SignedCurve25519)
185+
if err != nil {
186+
mach.Log.Error("Failed to share keys: %v", err)
187+
}
188+
}
189+
}
190+
156191
// ProcessSyncResponse processes a single /sync response.
157192
//
158193
// This can be easily registered into a mautrix client using .OnSync():
159194
//
160195
// client.Syncer.(*mautrix.DefaultSyncer).OnSync(c.crypto.ProcessSyncResponse)
161196
func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string) bool {
162-
if len(resp.DeviceLists.Changed) > 0 {
163-
mach.Log.Trace("Device list changes in /sync: %v", resp.DeviceLists.Changed)
164-
mach.fetchKeys(resp.DeviceLists.Changed, since, false)
165-
}
197+
mach.HandleDeviceLists(&resp.DeviceLists, since)
166198

167199
for _, evt := range resp.ToDevice.Events {
168200
evt.Type.Class = event.ToDeviceEventType
@@ -174,14 +206,7 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string
174206
mach.HandleToDeviceEvent(evt)
175207
}
176208

177-
min := mach.account.Internal.MaxNumberOfOneTimeKeys() / 2
178-
if resp.DeviceOneTimeKeysCount.SignedCurve25519 < int(min) {
179-
mach.Log.Debug("Sync response said we have %d signed curve25519 keys left, sharing new ones...", resp.DeviceOneTimeKeysCount.SignedCurve25519)
180-
err := mach.ShareKeys(resp.DeviceOneTimeKeysCount.SignedCurve25519)
181-
if err != nil {
182-
mach.Log.Error("Failed to share keys: %v", err)
183-
}
184-
}
209+
mach.HandleOTKCounts(&resp.DeviceOTKCount)
185210
return true
186211
}
187212

@@ -222,6 +247,11 @@ func (mach *OlmMachine) HandleMemberEvent(evt *event.Event) {
222247
// HandleToDeviceEvent handles a single to-device event. This is automatically called by ProcessSyncResponse, so you
223248
// don't need to add any custom handlers if you use that method.
224249
func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
250+
if len(evt.ToUserID) > 0 && (evt.ToUserID != mach.Client.UserID || evt.ToDeviceID != mach.Client.DeviceID) {
251+
// TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions
252+
mach.Log.Debug("Dropping to-device event targeted to %s/%s (not us)", evt.ToUserID, evt.ToDeviceID)
253+
return
254+
}
225255
switch content := evt.Content.Parsed.(type) {
226256
case *event.EncryptedEventContent:
227257
mach.Log.Debug("Handling encrypted to-device event from %s/%s", evt.Sender, content.SenderKey)

0 commit comments

Comments
 (0)