|
8 | 8 | "net/http/httptest"
|
9 | 9 | "net/url"
|
10 | 10 | "os"
|
11 |
| - "regexp" |
| 11 | + "strings" |
12 | 12 | "testing"
|
13 | 13 |
|
14 | 14 | "k8s.io/kubernetes/pkg/client/restclient"
|
@@ -153,8 +153,11 @@ func TestOAuthRequestHeader(t *testing.T) {
|
153 | 153 | t.Fatalf("unexpected error: %v", err)
|
154 | 154 | }
|
155 | 155 |
|
156 |
| - authorizeURL := clientConfig.Host + "/oauth/authorize?client_id=openshift-challenging-client&response_type=token" |
157 |
| - proxyURL := proxyServer.URL + "/oauth/authorize?client_id=openshift-challenging-client&response_type=token" |
| 156 | + state := `{"then": "/index.html?a=1&b=2&c=%2F"}` |
| 157 | + encodedState := (url.Values{"state": []string{state}}).Encode() |
| 158 | + |
| 159 | + authorizeURL := clientConfig.Host + "/oauth/authorize?client_id=openshift-challenging-client&response_type=token&" + encodedState |
| 160 | + proxyURL := proxyServer.URL + "/oauth/authorize?client_id=openshift-challenging-client&response_type=token&" + encodedState |
158 | 161 |
|
159 | 162 | testcases := map[string]struct {
|
160 | 163 | transport http.RoundTripper
|
@@ -245,14 +248,25 @@ func TestOAuthRequestHeader(t *testing.T) {
|
245 | 248 | continue
|
246 | 249 | }
|
247 | 250 |
|
248 |
| - // Extract the access_token |
249 |
| - |
250 |
| - // group #0 is everything. #1 #2 #3 |
251 |
| - accessTokenRedirectRegex := regexp.MustCompile(`(^|&)access_token=([^&]+)($|&)`) |
252 |
| - accessToken := "" |
253 |
| - if matches := accessTokenRedirectRegex.FindStringSubmatch(tokenRedirect.Fragment); matches != nil { |
254 |
| - accessToken = matches[2] |
| 251 | + // Grab the raw fragment ourselves, since the stdlib URL parsing decodes parts of it |
| 252 | + fragment := "" |
| 253 | + if parts := strings.SplitN(authenticatedProxyResponse.Header.Get("Location"), "#", 2); len(parts) == 2 { |
| 254 | + fragment = parts[1] |
| 255 | + } |
| 256 | + // Extract query-encoded values from the fragment |
| 257 | + fragmentValues, err := url.ParseQuery(fragment) |
| 258 | + if err != nil { |
| 259 | + t.Errorf("%s: %v", k, err) |
| 260 | + continue |
| 261 | + } |
| 262 | + // Ensure the state was retrieved correctly |
| 263 | + returnedState := fragmentValues.Get("state") |
| 264 | + if returnedState != state { |
| 265 | + t.Errorf("%s: Expected state\n\t%v\ngot\n\t%v", k, state, returnedState) |
| 266 | + continue |
255 | 267 | }
|
| 268 | + // Ensure the access_token was retrieved correctly |
| 269 | + accessToken := fragmentValues.Get("access_token") |
256 | 270 | if accessToken == "" {
|
257 | 271 | t.Errorf("%s: Expected access token, got %s", k, tokenRedirect.String())
|
258 | 272 | continue
|
|
0 commit comments