Skip to content

Add DeviceID to client.CSAPI #374

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ type SyncReq struct {
type CSAPI struct {
UserID string
AccessToken string
DeviceID string
BaseURL string
Client *http.Client
// how long are we willing to wait for MustSyncUntil.... calls
Expand Down Expand Up @@ -298,7 +299,7 @@ func (c *CSAPI) MustSyncUntil(t *testing.T, syncReq SyncReq, checks ...SyncCheck

//RegisterUser will register the user with given parameters and
// return user ID & access token, and fail the test on network error
func (c *CSAPI) RegisterUser(t *testing.T, localpart, password string) (userID, accessToken string) {
func (c *CSAPI) RegisterUser(t *testing.T, localpart, password string) (userID, accessToken, deviceID string) {
t.Helper()
reqBody := map[string]interface{}{
"auth": map[string]string{
Expand All @@ -316,12 +317,13 @@ func (c *CSAPI) RegisterUser(t *testing.T, localpart, password string) (userID,

userID = gjson.GetBytes(body, "user_id").Str
accessToken = gjson.GetBytes(body, "access_token").Str
return userID, accessToken
deviceID = gjson.GetBytes(body, "device_id").Str
return userID, accessToken, deviceID
}

// RegisterSharedSecret registers a new account with a shared secret via HMAC
// See https://github.com/matrix-org/synapse/blob/e550ab17adc8dd3c48daf7fedcd09418a73f524b/synapse/_scripts/register_new_matrix_user.py#L40
func (c *CSAPI) RegisterSharedSecret(t *testing.T, user, pass string, isAdmin bool) (userID, password string) {
func (c *CSAPI) RegisterSharedSecret(t *testing.T, user, pass string, isAdmin bool) (userID, accessToken, deviceID string) {
resp := c.DoFunc(t, "GET", []string{"_synapse", "admin", "v1", "register"})
if resp.StatusCode != 200 {
t.Skipf("Homeserver image does not support shared secret registration, /_synapse/admin/v1/register returned HTTP %d", resp.StatusCode)
Expand Down Expand Up @@ -354,7 +356,10 @@ func (c *CSAPI) RegisterSharedSecret(t *testing.T, user, pass string, isAdmin bo
}
resp = c.MustDoFunc(t, "POST", []string{"_synapse", "admin", "v1", "register"}, WithJSONBody(t, reqBody))
body = must.ParseJSON(t, resp.Body)
return gjson.GetBytes(body, "user_id").Str, gjson.GetBytes(body, "access_token").Str
userID = gjson.GetBytes(body, "user_id").Str
accessToken = gjson.GetBytes(body, "access_token").Str
deviceID = gjson.GetBytes(body, "device_id").Str
return userID, accessToken, deviceID
}

// GetCapbabilities queries the server's capabilities
Expand Down
5 changes: 5 additions & 0 deletions internal/docker/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,11 @@ func (d *Builder) construct(bprint b.Blueprint) (errs []error) {
}
}

deviceIDs := runner.DeviceIDs(res.homeserver.Name)
for userID, deviceID := range deviceIDs {
labels["device_id"+userID] = deviceID
}

// Combine the labels for tokens and application services
asLabels := labelsForApplicationServices(res.homeserver)
for k, v := range asLabels {
Expand Down
1 change: 1 addition & 0 deletions internal/docker/deployer.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ func deployImage(
ContainerID: containerID,
AccessTokens: tokensFromLabels(inspect.Config.Labels),
ApplicationServices: asIDToRegistrationFromLabels(inspect.Config.Labels),
DeviceIDs: deviceIDsFromLabels(inspect.Config.Labels),
}
if lastErr != nil {
return d, fmt.Errorf("%s: failed to check server is up. %w", contextStr, lastErr)
Expand Down
13 changes: 10 additions & 3 deletions internal/docker/deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type HomeserverDeployment struct {
ContainerID string // e.g 10de45efba
AccessTokens map[string]string // e.g { "@alice:hs1": "myAcc3ssT0ken" }
ApplicationServices map[string]string // e.g { "my-as-id": "id: xxx\nas_token: xxx ..."} }
DeviceIDs map[string]string // e.g { "@alice:hs1": "myDeviceID" }
}

// Destroy the entire deployment. Destroys all running containers. If `printServerLogs` is true,
Expand All @@ -51,9 +52,14 @@ func (d *Deployment) Client(t *testing.T, hsName, userID string) *client.CSAPI {
t.Fatalf("Deployment.Client - HS name '%s' - user ID '%s' not found", hsName, userID)
return nil
}
deviceID := dep.DeviceIDs[userID]
if deviceID == "" && userID != "" {
t.Logf("WARNING: Deployment.Client - HS name '%s' - user ID '%s' - deviceID not found", hsName, userID)
}
return &client.CSAPI{
UserID: userID,
AccessToken: token,
DeviceID: deviceID,
BaseURL: dep.BaseURL,
Client: client.NewLoggedClient(t, hsName, nil),
SyncUntilTimeout: 5 * time.Second,
Expand All @@ -75,17 +81,18 @@ func (d *Deployment) RegisterUser(t *testing.T, hsName, localpart, password stri
SyncUntilTimeout: 5 * time.Second,
Debug: d.Deployer.debugLogging,
}
var userID, accessToken string
var userID, accessToken, deviceID string
if isAdmin {
userID, accessToken = client.RegisterSharedSecret(t, localpart, password, isAdmin)
userID, accessToken, deviceID = client.RegisterSharedSecret(t, localpart, password, isAdmin)
} else {
userID, accessToken = client.RegisterUser(t, localpart, password)
userID, accessToken, deviceID = client.RegisterUser(t, localpart, password)
}

// remember the token so subsequent calls to deployment.Client return the user
dep.AccessTokens[userID] = accessToken

client.UserID = userID
client.AccessToken = accessToken
client.DeviceID = deviceID
return client
}
10 changes: 10 additions & 0 deletions internal/docker/labels.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,13 @@ func labelsForApplicationServices(hs b.Homeserver) map[string]string {
}
return labels
}

func deviceIDsFromLabels(labels map[string]string) map[string]string {
userIDToToken := make(map[string]string)
for k, v := range labels {
if strings.HasPrefix(k, "device_id") {
userIDToToken[strings.TrimPrefix(k, "device_id")] = v
}
}
return userIDToToken
}
21 changes: 19 additions & 2 deletions internal/instruction/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,21 @@ func (r *Runner) AccessTokens(hsDomain string) map[string]string {
return res
}

// DeviceIDs returns the device ids for all users who were created on the given HS domain.
// Returns a map of user_id => device_id
func (r *Runner) DeviceIDs(hsDomain string) map[string]string {
res := make(map[string]string)
r.lookup.Range(func(k, v interface{}) bool {
key := k.(string)
val := v.(string)
if strings.HasPrefix(key, "device_@") && strings.HasSuffix(key, ":"+hsDomain) {
res[strings.TrimPrefix(key, "device_")] = val
}
return true
})
return res
}

// Load a previously stored value from RunInstructions
func (r *Runner) GetStoredValue(opts RunOpts, key string) string {
fullKey := opts.StoreNamespace + key
Expand Down Expand Up @@ -541,7 +556,8 @@ func instructionRegister(hs b.Homeserver, user b.User) instruction {
accessToken: "",
body: body,
storeResponse: map[string]string{
"user_@" + user.Localpart + ":" + hs.Name: ".access_token",
"user_@" + user.Localpart + ":" + hs.Name: ".access_token",
"device_@" + user.Localpart + ":" + hs.Name: ".device_id",
},
}
}
Expand Down Expand Up @@ -581,7 +597,8 @@ func instructionLogin(hs b.Homeserver, user b.User) instruction {
accessToken: "",
body: body,
storeResponse: map[string]string{
"user_@" + user.Localpart + ":" + hs.Name: ".access_token",
"user_@" + user.Localpart + ":" + hs.Name: ".access_token",
"device_@" + user.Localpart + ":" + hs.Name: ".device_id",
},
}
}
Expand Down