5
5
package oidc
6
6
7
7
import (
8
+ "bytes"
8
9
"context"
9
10
"crypto/rand"
10
11
"encoding/base64"
12
+ "encoding/json"
11
13
"errors"
12
14
"io"
13
15
"net/http"
16
+ "strings"
14
17
15
18
"github.com/coreos/go-oidc/v3/oidc"
16
- "github.com/gitpod-io/gitpod/common-go/log"
17
19
"github.com/go-chi/chi/v5"
18
20
"golang.org/x/oauth2"
19
21
)
@@ -67,13 +69,18 @@ func (service *OIDCService) AddClientConfig(config *OIDCClientConfig) error {
67
69
}
68
70
69
71
func (service * OIDCService ) GetStartParams (config * OIDCClientConfig ) (* OIDCStartParams , error ) {
70
- // TODO(at) state should be a JWT encoding a redirect location
71
- // Using a random string to get the flow running.
72
- state , err := randString (32 )
72
+ // state is supposed to a) be present on client request as cookie header
73
+ // and b) to be mirrored by the IdP on callback requests.
74
+ stateParam := StateParam {
75
+ ClientId : config .ID ,
76
+ RedirectURL : config .OAuth2Config .RedirectURL ,
77
+ }
78
+ state , err := encodeStateParam (stateParam )
73
79
if err != nil {
74
- return nil , errors .New ("failed to create state" )
80
+ return nil , errors .New ("failed to encode state" )
75
81
}
76
82
83
+ // number used once
77
84
nonce , err := randString (32 )
78
85
if err != nil {
79
86
return nil , errors .New ("failed to create nonce" )
@@ -89,6 +96,25 @@ func (service *OIDCService) GetStartParams(config *OIDCClientConfig) (*OIDCStart
89
96
}, nil
90
97
}
91
98
99
+ // TODO(at) state should be a JWT encoding a redirect location
100
+ // For now, just use base64
101
+ func encodeStateParam (state StateParam ) (string , error ) {
102
+ var buf bytes.Buffer
103
+ encoder := base64 .NewEncoder (base64 .StdEncoding , & buf )
104
+ err := json .NewEncoder (encoder ).Encode (state )
105
+ if err != nil {
106
+ return "" , err
107
+ }
108
+ encoder .Close ()
109
+ return buf .String (), nil
110
+ }
111
+
112
+ func decodeStateParam (encoded string ) (StateParam , error ) {
113
+ var result StateParam
114
+ err := json .NewDecoder (base64 .NewDecoder (base64 .StdEncoding , strings .NewReader (encoded ))).Decode (& result )
115
+ return result , err
116
+ }
117
+
92
118
func randString (size int ) (string , error ) {
93
119
b := make ([]byte , size )
94
120
if _ , err := io .ReadFull (rand .Reader , b ); err != nil {
@@ -99,18 +125,30 @@ func randString(size int) (string, error) {
99
125
100
126
func (service * OIDCService ) GetClientConfigFromRequest (r * http.Request ) (* OIDCClientConfig , error ) {
101
127
issuerParam := r .URL .Query ().Get ("issuer" )
102
- if issuerParam == "" {
103
- return nil , errors .New ("issuer param not specified" )
128
+ stateParam := r .URL .Query ().Get ("state" )
129
+ if issuerParam == "" && stateParam == "" {
130
+ return nil , errors .New ("missing request parameters" )
104
131
}
105
- log .WithField ("issuerParam" , issuerParam ).Trace ("GetClientConfigFromRequest" )
106
- log .WithField ("issuer" , issuerParam ).Trace ("at GetClientConfigFromRequest" )
107
132
108
- for id , value := range service .configsById {
109
- log .WithField ("issuer" , value .Issuer ).WithField ("id" , id ).Trace ("GetClientConfigFromRequest (candidate)" )
110
- if value .Issuer == issuerParam {
111
- return value , nil
133
+ if issuerParam != "" {
134
+ for _ , value := range service .configsById {
135
+ if value .Issuer == issuerParam {
136
+ return value , nil
137
+ }
112
138
}
113
139
}
140
+
141
+ if stateParam != "" {
142
+ state , err := decodeStateParam (stateParam )
143
+ if err != nil {
144
+ return nil , errors .New ("bad state param" )
145
+ }
146
+ config := service .configsById [state .ClientId ]
147
+ if config != nil {
148
+ return config , nil
149
+ }
150
+ }
151
+
114
152
return nil , errors .New ("failed to find OIDC config for request" )
115
153
}
116
154
0 commit comments