Skip to content

Commit 1a637ca

Browse files
committed
[iam] basic state param encoding
1 parent c2104ab commit 1a637ca

File tree

3 files changed

+105
-19
lines changed

3 files changed

+105
-19
lines changed

components/iam/pkg/oidc/oauth2.go

+17-1
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,16 @@ import (
1313
)
1414

1515
type OAuth2Result struct {
16+
ClientID string
1617
OAuth2Token *oauth2.Token
17-
Redirect string
18+
RedirectURL string
19+
}
20+
21+
type StateParam struct {
22+
// Internal client ID
23+
ClientId string `json:"clientId"`
24+
25+
RedirectURL string `json:"redirectUrl"`
1826
}
1927

2028
type keyOAuth2Result struct{}
@@ -47,6 +55,12 @@ func OAuth2Middleware(next http.Handler) http.Handler {
4755
return
4856
}
4957

58+
state, err := decodeStateParam(stateParam)
59+
if err != nil {
60+
http.Error(rw, "bad state param", http.StatusBadRequest)
61+
return
62+
}
63+
5064
code := r.URL.Query().Get("code")
5165
if code == "" {
5266
http.Error(rw, "code param not found", http.StatusBadRequest)
@@ -61,6 +75,8 @@ func OAuth2Middleware(next http.Handler) http.Handler {
6175

6276
ctx = context.WithValue(ctx, keyOAuth2Result{}, OAuth2Result{
6377
OAuth2Token: oauth2Token,
78+
RedirectURL: state.RedirectURL,
79+
ClientID: state.ClientId,
6480
})
6581
next.ServeHTTP(rw, r.WithContext(ctx))
6682
})

components/iam/pkg/oidc/service.go

+51-13
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@
55
package oidc
66

77
import (
8+
"bytes"
89
"context"
910
"crypto/rand"
1011
"encoding/base64"
12+
"encoding/json"
1113
"errors"
1214
"io"
1315
"net/http"
16+
"strings"
1417

1518
"github.com/coreos/go-oidc/v3/oidc"
16-
"github.com/gitpod-io/gitpod/common-go/log"
1719
"github.com/go-chi/chi/v5"
1820
"golang.org/x/oauth2"
1921
)
@@ -67,13 +69,18 @@ func (service *OIDCService) AddClientConfig(config *OIDCClientConfig) error {
6769
}
6870

6971
func (service *OIDCService) GetStartParams(config *OIDCClientConfig) (*OIDCStartParams, error) {
70-
// TODO(at) state should be a JWT encoding a redirect location
71-
// Using a random string to get the flow running.
72-
state, err := randString(32)
72+
// state is supposed to a) be present on client request as cookie header
73+
// and b) to be mirrored by the IdP on callback requests.
74+
stateParam := StateParam{
75+
ClientId: config.ID,
76+
RedirectURL: config.OAuth2Config.RedirectURL,
77+
}
78+
state, err := encodeStateParam(stateParam)
7379
if err != nil {
74-
return nil, errors.New("failed to create state")
80+
return nil, errors.New("failed to encode state")
7581
}
7682

83+
// number used once
7784
nonce, err := randString(32)
7885
if err != nil {
7986
return nil, errors.New("failed to create nonce")
@@ -89,6 +96,25 @@ func (service *OIDCService) GetStartParams(config *OIDCClientConfig) (*OIDCStart
8996
}, nil
9097
}
9198

99+
// TODO(at) state should be a JWT encoding a redirect location
100+
// For now, just use base64
101+
func encodeStateParam(state StateParam) (string, error) {
102+
var buf bytes.Buffer
103+
encoder := base64.NewEncoder(base64.StdEncoding, &buf)
104+
err := json.NewEncoder(encoder).Encode(state)
105+
if err != nil {
106+
return "", err
107+
}
108+
encoder.Close()
109+
return buf.String(), nil
110+
}
111+
112+
func decodeStateParam(encoded string) (StateParam, error) {
113+
var result StateParam
114+
err := json.NewDecoder(base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))).Decode(&result)
115+
return result, err
116+
}
117+
92118
func randString(size int) (string, error) {
93119
b := make([]byte, size)
94120
if _, err := io.ReadFull(rand.Reader, b); err != nil {
@@ -99,18 +125,30 @@ func randString(size int) (string, error) {
99125

100126
func (service *OIDCService) GetClientConfigFromRequest(r *http.Request) (*OIDCClientConfig, error) {
101127
issuerParam := r.URL.Query().Get("issuer")
102-
if issuerParam == "" {
103-
return nil, errors.New("issuer param not specified")
128+
stateParam := r.URL.Query().Get("state")
129+
if issuerParam == "" && stateParam == "" {
130+
return nil, errors.New("missing request parameters")
104131
}
105-
log.WithField("issuerParam", issuerParam).Trace("GetClientConfigFromRequest")
106-
log.WithField("issuer", issuerParam).Trace("at GetClientConfigFromRequest")
107132

108-
for id, value := range service.configsById {
109-
log.WithField("issuer", value.Issuer).WithField("id", id).Trace("GetClientConfigFromRequest (candidate)")
110-
if value.Issuer == issuerParam {
111-
return value, nil
133+
if issuerParam != "" {
134+
for _, value := range service.configsById {
135+
if value.Issuer == issuerParam {
136+
return value, nil
137+
}
112138
}
113139
}
140+
141+
if stateParam != "" {
142+
state, err := decodeStateParam(stateParam)
143+
if err != nil {
144+
return nil, errors.New("bad state param")
145+
}
146+
config := service.configsById[state.ClientId]
147+
if config != nil {
148+
return config, nil
149+
}
150+
}
151+
114152
return nil, errors.New("failed to find OIDC config for request")
115153
}
116154

components/iam/pkg/oidc/service_test.go

+37-5
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"log"
1111
"net/http"
1212
"net/http/httptest"
13+
"net/url"
1314
"testing"
1415

1516
"github.com/coreos/go-oidc/v3/oidc"
@@ -52,13 +53,29 @@ func TestGetStartParams(t *testing.T) {
5253
require.NotNil(t, params.AuthCodeURL)
5354
require.Contains(t, params.AuthCodeURL, issuerG)
5455
require.Contains(t, params.AuthCodeURL, clientID)
55-
require.Contains(t, params.AuthCodeURL, params.Nonce)
56-
require.Contains(t, params.AuthCodeURL, params.State)
56+
require.Contains(t, params.AuthCodeURL, url.QueryEscape(params.Nonce))
57+
require.Contains(t, params.AuthCodeURL, url.QueryEscape(params.State))
5758
}
5859

5960
func TestGetClientConfigFromRequest(t *testing.T) {
6061
issuer := newFakeIdP(t)
6162

63+
const (
64+
clientID = "google-1"
65+
)
66+
67+
state, err := encodeStateParam(StateParam{
68+
ClientId: clientID,
69+
RedirectURL: "",
70+
})
71+
require.NoError(t, err, "failed encode state param")
72+
73+
state_unknown, err := encodeStateParam(StateParam{
74+
ClientId: "UNKNOWN",
75+
RedirectURL: "",
76+
})
77+
require.NoError(t, err, "failed encode state param")
78+
6279
testCases := []struct {
6380
Location string
6481
ExpectedError bool
@@ -72,18 +89,33 @@ func TestGetClientConfigFromRequest(t *testing.T) {
7289
{
7390
Location: "/start?issuer=" + issuer,
7491
ExpectedError: false,
75-
ExpectedId: "google-1",
92+
ExpectedId: clientID,
7693
},
7794
{
7895
Location: "/start?issuer=UNKNOWN",
7996
ExpectedError: true,
8097
ExpectedId: "",
8198
},
99+
{
100+
Location: "/callback?state=BAD",
101+
ExpectedError: true,
102+
ExpectedId: "",
103+
},
104+
{
105+
Location: "/callback?state=" + state_unknown,
106+
ExpectedError: true,
107+
ExpectedId: "",
108+
},
109+
{
110+
Location: "/callback?state=" + state,
111+
ExpectedError: false,
112+
ExpectedId: clientID,
113+
},
82114
}
83115

84116
service := NewOIDCService()
85-
err := service.AddClientConfig(&OIDCClientConfig{
86-
ID: "google-1",
117+
err = service.AddClientConfig(&OIDCClientConfig{
118+
ID: clientID,
87119
Issuer: issuer,
88120
OIDCConfig: &oidc.Config{},
89121
OAuth2Config: &oauth2.Config{},

0 commit comments

Comments
 (0)