Skip to content

Commit 11e88a4

Browse files
Merge pull request #16520 from liggitt/oauth-code-expiration
Automatic merge from submit-queue Set access token expiration correctly for code and implicit flows The expiration was only being set during the authorization request, which meant it was only set for implicit flow access tokens, and was incorrectly set on authorization tokens. Fixed up and added tests for both flows Fixes https://bugzilla.redhat.com/show_bug.cgi?id=1493903
2 parents 34ed389 + a202ab4 commit 11e88a4

File tree

4 files changed

+143
-40
lines changed

4 files changed

+143
-40
lines changed

pkg/auth/oauth/handlers/authenticator.go

+13-3
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,12 @@ func (h *AuthorizeAuthenticator) HandleAuthorize(ar *osin.AuthorizeRequest, resp
4646
ar.UserData = info
4747
ar.Authorized = true
4848

49-
if e, ok := ar.Client.(TokenMaxAgeSeconds); ok {
50-
if maxAge := e.GetTokenMaxAgeSeconds(); maxAge != nil {
51-
ar.Expiration = *maxAge
49+
// If requesting a token directly, optionally override the expiration
50+
if ar.Type == osin.TOKEN {
51+
if e, ok := ar.Client.(TokenMaxAgeSeconds); ok {
52+
if maxAge := e.GetTokenMaxAgeSeconds(); maxAge != nil {
53+
ar.Expiration = *maxAge
54+
}
5255
}
5356
}
5457

@@ -101,7 +104,14 @@ func (h *AccessAuthenticator) HandleAccess(ar *osin.AccessRequest, w http.Respon
101104
if info != nil {
102105
ar.AccessData.UserData = info
103106
}
107+
108+
if e, ok := ar.Client.(TokenMaxAgeSeconds); ok {
109+
if maxAge := e.GetTokenMaxAgeSeconds(); maxAge != nil {
110+
ar.Expiration = *maxAge
111+
}
112+
}
104113
}
114+
105115
return nil
106116
}
107117

pkg/client/oauthauthorizetoken.go

+10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package client
22

33
import (
4+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
5+
kapi "k8s.io/kubernetes/pkg/api"
6+
47
oauthapi "github.com/openshift/origin/pkg/oauth/apis/oauth"
58
)
69

@@ -10,6 +13,7 @@ type OAuthAuthorizeTokensInterface interface {
1013

1114
type OAuthAuthorizeTokenInterface interface {
1215
Create(token *oauthapi.OAuthAuthorizeToken) (*oauthapi.OAuthAuthorizeToken, error)
16+
Get(name string, options metav1.GetOptions) (*oauthapi.OAuthAuthorizeToken, error)
1317
Delete(name string) error
1418
}
1519

@@ -33,3 +37,9 @@ func (c *oauthAuthorizeTokenInterface) Create(token *oauthapi.OAuthAuthorizeToke
3337
err = c.r.Post().Resource("oauthauthorizetokens").Body(token).Do().Into(result)
3438
return
3539
}
40+
41+
func (c *oauthAuthorizeTokenInterface) Get(name string, options metav1.GetOptions) (result *oauthapi.OAuthAuthorizeToken, err error) {
42+
result = &oauthapi.OAuthAuthorizeToken{}
43+
err = c.r.Get().Resource("oauthauthorizetokens").Name(name).VersionedParams(&options, kapi.ParameterCodec).Do().Into(result)
44+
return
45+
}

pkg/client/testclient/fake_oauthauthorizetoken.go

+10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package testclient
22

33
import (
4+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
45
"k8s.io/apimachinery/pkg/runtime/schema"
56
clientgotesting "k8s.io/client-go/testing"
67

@@ -26,3 +27,12 @@ func (c *FakeOAuthAuthorizeTokens) Create(inObj *oauthapi.OAuthAuthorizeToken) (
2627

2728
return obj.(*oauthapi.OAuthAuthorizeToken), err
2829
}
30+
31+
func (c *FakeOAuthAuthorizeTokens) Get(name string, options metav1.GetOptions) (*oauthapi.OAuthAuthorizeToken, error) {
32+
obj, err := c.Fake.Invokes(clientgotesting.NewRootGetAction(oAuthAuthorizeTokensResource, name), &oauthapi.OAuthAuthorizeToken{})
33+
if obj == nil {
34+
return nil, err
35+
}
36+
37+
return obj.(*oauthapi.OAuthAuthorizeToken), err
38+
}
+110-37
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package integration
22

33
import (
4+
"net/http"
45
"testing"
56
"time"
67

8+
"golang.org/x/net/context"
9+
"golang.org/x/oauth2"
710
"k8s.io/apimachinery/pkg/api/errors"
811
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
912
"k8s.io/apimachinery/pkg/util/wait"
@@ -43,7 +46,7 @@ func TestOAuthExpiration(t *testing.T) {
4346

4447
{
4548
zero := int32(0)
46-
nonexpiring, err := clusterAdminClient.OAuthClients().Create(&oauthapi.OAuthClient{
49+
client, err := clusterAdminClient.OAuthClients().Create(&oauthapi.OAuthClient{
4750
ObjectMeta: metav1.ObjectMeta{Name: "nonexpiring"},
4851
RespondWithChallenges: true,
4952
RedirectURIs: []string{"http://localhost"},
@@ -54,22 +57,59 @@ func TestOAuthExpiration(t *testing.T) {
5457
t.Fatal(err)
5558
}
5659

57-
nonExpiringTokenOpts := tokencmd.NewRequestTokenOptions(anonConfig, nil, "username", "password")
58-
nonExpiringTokenOpts.ClientID = nonexpiring.Name
59-
nonexpiringToken, err := nonExpiringTokenOpts.RequestToken()
60+
testExpiringOAuthFlows(t, clusterAdminClient, client, anonConfig, 0)
61+
}
62+
63+
{
64+
ten := int32(10)
65+
client, err := clusterAdminClient.OAuthClients().Create(&oauthapi.OAuthClient{
66+
ObjectMeta: metav1.ObjectMeta{Name: "shortexpiring"},
67+
RespondWithChallenges: true,
68+
RedirectURIs: []string{"http://localhost"},
69+
AccessTokenMaxAgeSeconds: &ten,
70+
GrantMethod: oauthapi.GrantHandlerAuto,
71+
})
72+
if err != nil {
73+
t.Fatal(err)
74+
}
75+
76+
token := testExpiringOAuthFlows(t, clusterAdminClient, client, anonConfig, 10)
77+
78+
// Ensure the token goes away after the time expiration
79+
if err := wait.Poll(1*time.Second, time.Minute, func() (bool, error) {
80+
_, err := clusterAdminClient.OAuthAccessTokens().Get(token, metav1.GetOptions{})
81+
if errors.IsNotFound(err) {
82+
return true, nil
83+
}
84+
if err != nil {
85+
return false, err
86+
}
87+
return false, nil
88+
}); err != nil {
89+
t.Fatal(err)
90+
}
91+
}
92+
}
93+
94+
func testExpiringOAuthFlows(t *testing.T, clusterAdminClient *client.Client, oauthclient *oauthapi.OAuthClient, anonConfig *restclient.Config, expectedExpires int) string {
95+
96+
{
97+
tokenOpts := tokencmd.NewRequestTokenOptions(anonConfig, nil, "username", "password")
98+
tokenOpts.ClientID = oauthclient.Name
99+
token, err := tokenOpts.RequestToken()
60100
if err != nil {
61101
t.Fatal(err)
62102
}
63103

64104
// Make sure we can use the token, and it represents who we expect
65-
nonExpiringUserConfig := *anonConfig
66-
nonExpiringUserConfig.BearerToken = nonexpiringToken
67-
nonExpiringUserClient, err := client.New(&nonExpiringUserConfig)
105+
userConfig := *anonConfig
106+
userConfig.BearerToken = token
107+
userClient, err := client.New(&userConfig)
68108
if err != nil {
69109
t.Fatalf("Unexpected error: %v", err)
70110
}
71111

72-
user, err := nonExpiringUserClient.Users().Get("~", metav1.GetOptions{})
112+
user, err := userClient.Users().Get("~", metav1.GetOptions{})
73113
if err != nil {
74114
t.Fatalf("Unexpected error: %v", err)
75115
}
@@ -78,63 +118,96 @@ func TestOAuthExpiration(t *testing.T) {
78118
}
79119

80120
// Make sure the token exists with the overridden time
81-
tokenObj, err := clusterAdminClient.OAuthAccessTokens().Get(nonexpiringToken, metav1.GetOptions{})
121+
tokenObj, err := clusterAdminClient.OAuthAccessTokens().Get(token, metav1.GetOptions{})
82122
if err != nil {
83123
t.Fatal(err)
84124
}
85-
if tokenObj.ExpiresIn != 0 {
86-
t.Fatalf("Expected expiration of 0, got %#v", tokenObj.ExpiresIn)
125+
if tokenObj.ExpiresIn != int64(expectedExpires) {
126+
t.Fatalf("Expected expiration of %d, got %#v", expectedExpires, tokenObj.ExpiresIn)
87127
}
88128
}
89129

90130
{
91-
ten := int32(10)
92-
shortexpiring, err := clusterAdminClient.OAuthClients().Create(&oauthapi.OAuthClient{
93-
ObjectMeta: metav1.ObjectMeta{Name: "shortexpiring"},
94-
RespondWithChallenges: true,
95-
RedirectURIs: []string{"http://localhost"},
96-
AccessTokenMaxAgeSeconds: &ten,
97-
GrantMethod: oauthapi.GrantHandlerAuto,
98-
})
131+
rt, err := restclient.TransportFor(anonConfig)
132+
if err != nil {
133+
t.Fatal(err)
134+
}
135+
136+
conf := &oauth2.Config{
137+
ClientID: oauthclient.Name,
138+
ClientSecret: oauthclient.Secret,
139+
RedirectURL: oauthclient.RedirectURIs[0],
140+
Endpoint: oauth2.Endpoint{
141+
AuthURL: anonConfig.Host + "/oauth/authorize",
142+
TokenURL: anonConfig.Host + "/oauth/token",
143+
},
144+
}
145+
146+
// get code
147+
req, err := http.NewRequest("GET", conf.AuthCodeURL(""), nil)
148+
if err != nil {
149+
t.Fatal(err)
150+
}
151+
req.SetBasicAuth("username", "password")
152+
resp, err := rt.RoundTrip(req)
153+
if err != nil {
154+
t.Fatal(err)
155+
}
156+
if resp.StatusCode != http.StatusFound {
157+
t.Fatalf("unexpected status %v", resp.StatusCode)
158+
}
159+
location, err := resp.Location()
160+
if err != nil {
161+
t.Fatal(err)
162+
}
163+
code := location.Query().Get("code")
164+
if len(code) == 0 {
165+
t.Fatalf("Unexpected response: %v", location)
166+
}
167+
168+
// Make sure the code exists with the default time
169+
codeObj, err := clusterAdminClient.OAuthAuthorizeTokens().Get(code, metav1.GetOptions{})
99170
if err != nil {
100171
t.Fatal(err)
101172
}
173+
if codeObj.ExpiresIn != (5 * 60) {
174+
t.Fatalf("Expected expiration of %d, got %#v", (5 * 60), codeObj.ExpiresIn)
175+
}
102176

103-
expiringTokenOpts := tokencmd.NewRequestTokenOptions(anonConfig, nil, "username", "password")
104-
expiringTokenOpts.ClientID = shortexpiring.Name
105-
expiringToken, err := expiringTokenOpts.RequestToken()
177+
// Use the custom HTTP client when requesting a token.
178+
httpClient := &http.Client{Transport: rt}
179+
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
180+
oauthToken, err := conf.Exchange(ctx, code)
106181
if err != nil {
107182
t.Fatal(err)
108183
}
184+
token := oauthToken.AccessToken
109185

110186
// Make sure we can use the token, and it represents who we expect
111-
expiringUserConfig := *anonConfig
112-
expiringUserConfig.BearerToken = expiringToken
113-
expiringUserClient, err := client.New(&expiringUserConfig)
187+
userConfig := *anonConfig
188+
userConfig.BearerToken = token
189+
userClient, err := client.New(&userConfig)
114190
if err != nil {
115191
t.Fatalf("Unexpected error: %v", err)
116192
}
117193

118-
user, err := expiringUserClient.Users().Get("~", metav1.GetOptions{})
194+
user, err := userClient.Users().Get("~", metav1.GetOptions{})
119195
if err != nil {
120196
t.Fatalf("Unexpected error: %v", err)
121197
}
122198
if user.Name != "username" {
123199
t.Fatalf("Expected username as the user, got %v", user)
124200
}
125201

126-
// Ensure the token goes away after the time expiration
127-
if err := wait.Poll(1*time.Second, time.Minute, func() (bool, error) {
128-
_, err := clusterAdminClient.OAuthAccessTokens().Get(expiringToken, metav1.GetOptions{})
129-
if errors.IsNotFound(err) {
130-
return true, nil
131-
}
132-
if err != nil {
133-
return false, err
134-
}
135-
return false, nil
136-
}); err != nil {
202+
// Make sure the token exists with the overridden time
203+
tokenObj, err := clusterAdminClient.OAuthAccessTokens().Get(token, metav1.GetOptions{})
204+
if err != nil {
137205
t.Fatal(err)
138206
}
207+
if tokenObj.ExpiresIn != int64(expectedExpires) {
208+
t.Fatalf("Expected expiration of %d, got %#v", expectedExpires, tokenObj.ExpiresIn)
209+
}
210+
211+
return token
139212
}
140213
}

0 commit comments

Comments
 (0)