From b66c2627dc085a21c3968cc9f306b10d92cde76d Mon Sep 17 00:00:00 2001 From: Sam Wedgwood Date: Tue, 1 Aug 2023 11:27:56 +0100 Subject: [PATCH] use pointers for querying sender ID --- performinvite.go | 4 ++-- performinvite_test.go | 5 +++-- spec/senderid.go | 13 ++++++++++++- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/performinvite.go b/performinvite.go index 6b4c3168..7f591a29 100644 --- a/performinvite.go +++ b/performinvite.go @@ -96,8 +96,8 @@ func PerformInvite(ctx context.Context, input PerformInviteInput, fedClient Fede return nil, err } - if invitedSenderID != "" { - err = abortIfAlreadyJoined(ctx, input.RoomID, invitedSenderID, input.MembershipQuerier) + if invitedSenderID != nil { + err = abortIfAlreadyJoined(ctx, input.RoomID, *invitedSenderID, input.MembershipQuerier) if err != nil { return nil, err } diff --git a/performinvite_test.go b/performinvite_test.go index 35a8d937..f663da0f 100644 --- a/performinvite_test.go +++ b/performinvite_test.go @@ -13,8 +13,9 @@ import ( "golang.org/x/crypto/ed25519" ) -func SenderIDForUserTest(roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { - return spec.SenderID(userID.String()), nil +func SenderIDForUserTest(roomID spec.RoomID, userID spec.UserID) (*spec.SenderID, error) { + senderID := spec.SenderID(userID.String()) + return &senderID, nil } func CreateSenderID(ctx context.Context, userID spec.UserID, roomID spec.RoomID, roomVersion string) (spec.SenderID, ed25519.PrivateKey, error) { diff --git a/spec/senderid.go b/spec/senderid.go index 44954c15..76622faf 100644 --- a/spec/senderid.go +++ b/spec/senderid.go @@ -23,7 +23,7 @@ import ( type SenderID string type UserIDForSender func(roomID RoomID, senderID SenderID) (*UserID, error) -type SenderIDForUser func(roomID RoomID, userID UserID) (SenderID, error) +type SenderIDForUser func(roomID RoomID, userID UserID) (*SenderID, error) // CreateSenderID is a function used to create the pseudoID private key. type CreateSenderID func(ctx context.Context, userID UserID, roomID RoomID, roomVersion string) (SenderID, ed25519.PrivateKey, error) @@ -42,3 +42,14 @@ func (s SenderID) RawBytes() (res Base64Bytes, err error) { } return res, nil } + +func (s SenderID) IsUserID() bool { + // Key is base64, @ is not a valid base64 char + // So if string starts with @, then this sender ID must + // be a user ID + return string(s)[0] == '@' +} + +func (s SenderID) IsPseudoID() bool { + return !s.IsUserID() +}