Skip to content

Commit 809b162

Browse files
authored
Bring high level room joining logic into GMSL (#372)
1 parent 8daeaeb commit 809b162

11 files changed

+646
-21
lines changed

authchain.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ import (
55
"fmt"
66
)
77

8-
// AuthChainProvider returns the requested list of auth events.
9-
type AuthChainProvider func(roomVer RoomVersion, eventIDs []string) ([]*Event, error)
8+
// EventProvider returns the requested list of events.
9+
type EventProvider func(roomVer RoomVersion, eventIDs []string) ([]*Event, error)
1010

1111
// VerifyEventAuthChain will verify that the event is allowed according to its auth_events, and then
1212
// recursively verify each of those auth_events.
@@ -19,7 +19,7 @@ type AuthChainProvider func(roomVer RoomVersion, eventIDs []string) ([]*Event, e
1919
// The `provideEvents` function will only be called for *new* events rather than for everything as it is
2020
// assumed that this function is costly. Failing to provide all the requested events will fail this function.
2121
// Returning an error from `provideEvents` will also fail this function.
22-
func VerifyEventAuthChain(ctx context.Context, eventToVerify *HeaderedEvent, provideEvents AuthChainProvider) error {
22+
func VerifyEventAuthChain(ctx context.Context, eventToVerify *HeaderedEvent, provideEvents EventProvider) error {
2323
eventsByID := make(map[string]*Event) // A lookup table for verifying this auth chain
2424
evv := eventToVerify.Unwrap()
2525
eventsByID[evv.EventID()] = evv

authchain_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func TestVerifyEventAuthChainCascadeFailure(t *testing.T) {
7070
}
7171
}
7272

73-
func provideEvents(t *testing.T, events [][]byte) gomatrixserverlib.AuthChainProvider {
73+
func provideEvents(t *testing.T, events [][]byte) gomatrixserverlib.EventProvider {
7474
eventMap := make(map[string]*gomatrixserverlib.Event)
7575
for _, eventBytes := range events {
7676
ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV1).NewEventFromTrustedJSON(eventBytes, false)

authstate.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ func VerifyAuthRulesAtState(ctx context.Context, sp StateProvider, eventToVerify
163163

164164
func checkAllowedByAuthEvents(
165165
event *Event, eventsByID map[string]*Event,
166-
missingAuth AuthChainProvider,
166+
missingAuth EventProvider,
167167
) error {
168168
authEvents := NewAuthEvents(nil)
169169

@@ -173,7 +173,7 @@ func checkAllowedByAuthEvents(
173173
if !ok {
174174
// We don't have an entry in the eventsByID map - neither an event nor nil.
175175
if missingAuth != nil {
176-
// If we have a AuthChainProvider then ask it for the missing event.
176+
// If we have a EventProvider then ask it for the missing event.
177177
if ev, err := missingAuth(event.Version(), []string{ae}); err == nil && len(ev) > 0 {
178178
// It claims to have returned events - populate the eventsByID
179179
// map and the authEvents provider so that we can retry with the
@@ -193,7 +193,7 @@ func checkAllowedByAuthEvents(
193193
}
194194
goto retryEvent
195195
} else {
196-
// If we didn't have a AuthChainProvider then we can't get the event
196+
// If we didn't have a EventProvider then we can't get the event
197197
// so just carry on without it. If it was important for anything then
198198
// Check() below will catch it.
199199
continue
@@ -206,7 +206,7 @@ func checkAllowedByAuthEvents(
206206
}
207207
} else {
208208
// We had an entry in the map but it contains nil, which means that we tried
209-
// to use the AuthChainProvider to retrieve it and failed, so at this point
209+
// to use the EventProvider to retrieve it and failed, so at this point
210210
// we just have to ignore the event.
211211
continue
212212
}
@@ -229,7 +229,7 @@ func checkAllowedByAuthEvents(
229229
// return parameter). Does not alter any input args.
230230
func CheckStateResponse(
231231
ctx context.Context, r StateResponse, roomVersion RoomVersion,
232-
keyRing JSONVerifier, missingAuth AuthChainProvider,
232+
keyRing JSONVerifier, missingAuth EventProvider,
233233
) ([]*Event, []*Event, error) {
234234
logger := util.GetLogger(ctx)
235235
authEvents := r.GetAuthEvents().UntrustedEvents(roomVersion)
@@ -322,7 +322,7 @@ func CheckStateResponse(
322322
func CheckSendJoinResponse(
323323
ctx context.Context, roomVersion RoomVersion, r StateResponse,
324324
keyRing JSONVerifier, joinEvent *Event,
325-
missingAuth AuthChainProvider,
325+
missingAuth EventProvider,
326326
) (StateResponse, error) {
327327
// First check that the state is valid and that the events in the response
328328
// are correctly signed.

errors.go

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
package gomatrixserverlib
22

3-
import "fmt"
3+
import (
4+
"fmt"
5+
6+
"github.com/matrix-org/gomatrixserverlib/spec"
7+
)
48

59
// MissingAuthEventError refers to a situation where one of the auth
610
// event for a given event was not found.
@@ -27,3 +31,15 @@ func (e BadJSONError) Error() string {
2731
func (e BadJSONError) Unwrap() error {
2832
return e.err
2933
}
34+
35+
// FederationError contains context surrounding why a federation request may have failed.
36+
type FederationError struct {
37+
ServerName spec.ServerName // The server being contacted.
38+
Transient bool // Whether the failure is permanent (will fail if performed again) or not.
39+
Reachable bool // Whether the server could be contacted.
40+
Err error // The underlying error message.
41+
}
42+
43+
func (e FederationError) Error() string {
44+
return fmt.Sprintf("FederationError(t=%v, r=%v): %s", e.Transient, e.Reachable, e.Err.Error())
45+
}

fclient/federationclient.go

+14-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ type FederationClient interface {
2626
// Perform operations
2727
LookupRoomAlias(ctx context.Context, origin, s spec.ServerName, roomAlias string) (res RespDirectory, err error)
2828
Peek(ctx context.Context, origin, s spec.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res RespPeek, err error)
29-
MakeJoin(ctx context.Context, origin, s spec.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res RespMakeJoin, err error)
29+
MakeJoin(ctx context.Context, origin, s spec.ServerName, roomID, userID string) (res RespMakeJoin, err error)
3030
SendJoin(ctx context.Context, origin, s spec.ServerName, event *gomatrixserverlib.Event) (res RespSendJoin, err error)
3131
MakeLeave(ctx context.Context, origin, s spec.ServerName, roomID, userID string) (res RespMakeLeave, err error)
3232
SendLeave(ctx context.Context, origin, s spec.ServerName, event *gomatrixserverlib.Event) (err error)
@@ -180,6 +180,18 @@ func makeVersionQueryString(roomVersions []gomatrixserverlib.RoomVersion) string
180180
return versionQueryString
181181
}
182182

183+
// Takes the map of room version implementations and converts it into a list of
184+
// room version strings.
185+
func roomVersionsToList(
186+
versionsMap map[gomatrixserverlib.RoomVersion]gomatrixserverlib.IRoomVersion,
187+
) []gomatrixserverlib.RoomVersion {
188+
var supportedVersions []gomatrixserverlib.RoomVersion
189+
for version := range versionsMap {
190+
supportedVersions = append(supportedVersions, version)
191+
}
192+
return supportedVersions
193+
}
194+
183195
// MakeJoin makes a join m.room.member event for a room on a remote matrix server.
184196
// This is used to join a room the local server isn't a member of.
185197
// We need to query a remote server because if we aren't in the room we don't
@@ -191,8 +203,8 @@ func makeVersionQueryString(roomVersions []gomatrixserverlib.RoomVersion) string
191203
// See https://matrix.org/docs/spec/server_server/unstable.html#joining-rooms
192204
func (ac *federationClient) MakeJoin(
193205
ctx context.Context, origin, s spec.ServerName, roomID, userID string,
194-
roomVersions []gomatrixserverlib.RoomVersion,
195206
) (res RespMakeJoin, err error) {
207+
roomVersions := roomVersionsToList(gomatrixserverlib.RoomVersions())
196208
versionQueryString := makeVersionQueryString(roomVersions)
197209
path := federationPathPrefixV1 + "/make_join/" +
198210
url.PathEscape(roomID) + "/" +

fclient/federationclient_test.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ func TestSendJoinFallback(t *testing.T) {
8989
if err != nil {
9090
t.Fatalf("SendJoin returned an error: %s", err)
9191
}
92-
if !reflect.DeepEqual(res.StateEvents, wantRes.StateEvents) {
93-
t.Fatalf("SendJoin response got %+v want %+v", res.StateEvents, wantRes.StateEvents)
92+
if !reflect.DeepEqual(res.GetStateEvents(), wantRes.StateEvents) {
93+
t.Fatalf("SendJoin response got %+v want %+v", res.GetStateEvents(), wantRes.StateEvents)
9494
}
9595
}
9696

@@ -150,11 +150,11 @@ func TestSendJoinJSON(t *testing.T) {
150150
}
151151
wantStateEvents := gomatrixserverlib.EventJSONs{[]byte(retEv)}
152152
wantAuthChain := gomatrixserverlib.EventJSONs{[]byte(retEv)}
153-
if !reflect.DeepEqual(res.StateEvents, wantStateEvents) {
154-
t.Fatalf("SendJoin response got state %+v want %+v", jsonify(res.StateEvents), jsonify(wantStateEvents))
153+
if !reflect.DeepEqual(res.GetStateEvents(), wantStateEvents) {
154+
t.Fatalf("SendJoin response got state %+v want %+v", jsonify(res.GetStateEvents()), jsonify(wantStateEvents))
155155
}
156-
if !reflect.DeepEqual(res.AuthEvents, wantAuthChain) {
157-
t.Fatalf("SendJoin response got auth %+v want %+v", jsonify(res.AuthEvents), jsonify(wantAuthChain))
156+
if !reflect.DeepEqual(res.GetAuthEvents(), wantAuthChain) {
157+
t.Fatalf("SendJoin response got auth %+v want %+v", jsonify(res.GetAuthEvents()), jsonify(wantAuthChain))
158158
}
159159
}
160160

fclient/federationtypes.go

+24
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,14 @@ type RespMakeJoin struct {
253253
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"`
254254
}
255255

256+
func (r *RespMakeJoin) GetJoinEvent() gomatrixserverlib.EventBuilder {
257+
return r.JoinEvent
258+
}
259+
260+
func (r *RespMakeJoin) GetRoomVersion() gomatrixserverlib.RoomVersion {
261+
return r.RoomVersion
262+
}
263+
256264
// A RespSendJoin is the content of a response to PUT /_matrix/federation/v2/send_join/{roomID}/{eventID}
257265
type RespSendJoin struct {
258266
// A list of events giving the state of the room before the request event.
@@ -278,6 +286,22 @@ func (r *RespSendJoin) GetAuthEvents() gomatrixserverlib.EventJSONs {
278286
return r.AuthEvents
279287
}
280288

289+
func (r *RespSendJoin) GetOrigin() spec.ServerName {
290+
return r.Origin
291+
}
292+
293+
func (r *RespSendJoin) GetJoinEvent() spec.RawJSON {
294+
return r.Event
295+
}
296+
297+
func (r *RespSendJoin) GetMembersOmitted() bool {
298+
return r.MembersOmitted
299+
}
300+
301+
func (r *RespSendJoin) GetServersInRoom() []string {
302+
return r.ServersInRoom
303+
}
304+
281305
// MarshalJSON implements json.Marshaller
282306
func (r RespSendJoin) MarshalJSON() ([]byte, error) {
283307
fields := respSendJoinFields{

join.go

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package gomatrixserverlib
2+
3+
import (
4+
"context"
5+
6+
"github.com/matrix-org/gomatrixserverlib/spec"
7+
)
8+
9+
type FederatedJoinClient interface {
10+
MakeJoin(ctx context.Context, origin, s spec.ServerName, roomID, userID string) (res MakeJoinResponse, err error)
11+
SendJoin(ctx context.Context, origin, s spec.ServerName, event *Event) (res SendJoinResponse, err error)
12+
}
13+
14+
type MakeJoinResponse interface {
15+
GetJoinEvent() EventBuilder
16+
GetRoomVersion() RoomVersion
17+
}
18+
19+
type SendJoinResponse interface {
20+
GetAuthEvents() EventJSONs
21+
GetStateEvents() EventJSONs
22+
GetOrigin() spec.ServerName
23+
GetJoinEvent() spec.RawJSON
24+
GetMembersOmitted() bool
25+
GetServersInRoom() []string
26+
}

load.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ type EventLoadResult struct {
1818
type EventsLoader struct {
1919
roomVer RoomVersion
2020
keyRing JSONVerifier
21-
provider AuthChainProvider
21+
provider EventProvider
2222
stateProvider StateProvider
2323
// Set to true to do:
2424
// 6. Passes authorization rules based on the current state of the room, otherwise it is "soft failed".
@@ -27,7 +27,7 @@ type EventsLoader struct {
2727
}
2828

2929
// NewEventsLoader returns a new events loader
30-
func NewEventsLoader(roomVer RoomVersion, keyRing JSONVerifier, stateProvider StateProvider, provider AuthChainProvider, performSoftFailCheck bool) *EventsLoader {
30+
func NewEventsLoader(roomVer RoomVersion, keyRing JSONVerifier, stateProvider StateProvider, provider EventProvider, performSoftFailCheck bool) *EventsLoader {
3131
return &EventsLoader{
3232
roomVer: roomVer,
3333
keyRing: keyRing,

0 commit comments

Comments
 (0)