Skip to content

Commit 9c98314

Browse files
mgyongyosikalleep
andauthored
OAuth: Refactor OAuth parameters handling to support obtaining refresh tokens for Google OAuth (grafana#58782)
* Add ApprovalForce to AuthCodeOptions * Extract access token validity check to a function * Refactor * Oauth: set options internally instead of exposing new function * Align tests * Remove unused function Co-authored-by: Karl Persson <[email protected]>
1 parent d46e391 commit 9c98314

10 files changed

+70
-40
lines changed

Diff for: pkg/api/frontendsettings_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func setupTestEnvironment(t *testing.T, cfg *setting.Cfg, features *featuremgmt.
5858
grafanaUpdateChecker: &updatechecker.GrafanaService{},
5959
AccessControl: accesscontrolmock.New().WithDisabled(),
6060
PluginSettings: pluginSettings.ProvideService(sqlStore, secretsService),
61-
SocialService: social.ProvideService(cfg),
61+
SocialService: social.ProvideService(cfg, features),
6262
}
6363

6464
m := web.New()

Diff for: pkg/api/login_oauth.go

+1-3
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
9797

9898
code := ctx.Query("code")
9999
if code == "" {
100-
// FIXME: access_type is a Google OAuth2 specific thing, consider refactoring this and moving to google_oauth.go
101-
opts := []oauth2.AuthCodeOption{oauth2.AccessTypeOffline}
102-
100+
var opts []oauth2.AuthCodeOption
103101
if provider.UsePKCE {
104102
ascii, pkce, err := genPKCECode()
105103
if err != nil {

Diff for: pkg/api/login_oauth_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ import (
99
"path/filepath"
1010
"testing"
1111

12-
"github.com/grafana/grafana/pkg/infra/db"
13-
"github.com/grafana/grafana/pkg/services/secrets/fakes"
14-
1512
"github.com/stretchr/testify/assert"
1613
"github.com/stretchr/testify/require"
1714

15+
"github.com/grafana/grafana/pkg/infra/db"
1816
"github.com/grafana/grafana/pkg/login/social"
17+
"github.com/grafana/grafana/pkg/services/featuremgmt"
1918
"github.com/grafana/grafana/pkg/services/hooks"
2019
"github.com/grafana/grafana/pkg/services/licensing"
20+
"github.com/grafana/grafana/pkg/services/secrets/fakes"
2121
"github.com/grafana/grafana/pkg/setting"
2222
"github.com/grafana/grafana/pkg/web"
2323
)
@@ -36,7 +36,7 @@ func setupOAuthTest(t *testing.T, cfg *setting.Cfg) *web.Mux {
3636
Cfg: cfg,
3737
License: &licensing.OSSLicensingService{Cfg: cfg},
3838
SQLStore: sqlStore,
39-
SocialService: social.ProvideService(cfg),
39+
SocialService: social.ProvideService(cfg, featuremgmt.WithFeatures()),
4040
HooksService: hooks.ProvideService(),
4141
SecretsService: fakes.NewFakeSecretsService(),
4242
}

Diff for: pkg/login/social/azuread_oauth_test.go

+15-13
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"golang.org/x/oauth2"
1414
"gopkg.in/square/go-jose.v2"
1515
"gopkg.in/square/go-jose.v2/jwt"
16+
17+
"github.com/grafana/grafana/pkg/services/featuremgmt"
1618
)
1719

1820
func trueBoolPtr() *bool {
@@ -54,7 +56,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
5456
ID: "1234",
5557
},
5658
fields: fields{
57-
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false),
59+
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
5860
},
5961
want: &BasicUserInfo{
6062
Id: "1234",
@@ -93,7 +95,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
9395
ID: "1234",
9496
},
9597
fields: fields{
96-
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false),
98+
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
9799
},
98100
want: &BasicUserInfo{
99101
Id: "1234",
@@ -143,7 +145,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
143145
{
144146
name: "Only other roles",
145147
fields: fields{
146-
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false),
148+
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
147149
},
148150
claims: &azureClaims{
149151
@@ -171,7 +173,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
171173
ID: "1234",
172174
},
173175
fields: fields{
174-
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Editor", false),
176+
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Editor", false, *featuremgmt.WithFeatures()),
175177
},
176178
want: &BasicUserInfo{
177179
Id: "1234",
@@ -220,7 +222,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
220222
},
221223
{
222224
name: "Grafana Admin but setting is disabled",
223-
fields: fields{SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Editor", false)},
225+
fields: fields{SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Editor", false, *featuremgmt.WithFeatures())},
224226
claims: &azureClaims{
225227
226228
PreferredUsername: "",
@@ -242,7 +244,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
242244
name: "Editor roles in claim and GrafanaAdminAssignment enabled",
243245
fields: fields{
244246
SocialBase: newSocialBase("azuread",
245-
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false)},
247+
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures())},
246248
claims: &azureClaims{
247249
248250
PreferredUsername: "",
@@ -263,7 +265,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
263265
{
264266
name: "Grafana Admin and Editor roles in claim",
265267
fields: fields{SocialBase: newSocialBase("azuread",
266-
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false)},
268+
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures())},
267269
claims: &azureClaims{
268270
269271
PreferredUsername: "",
@@ -302,7 +304,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
302304
fields: fields{
303305
allowedGroups: []string{"foo", "bar"},
304306
SocialBase: newSocialBase("azuread",
305-
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Viewer", false),
307+
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Viewer", false, *featuremgmt.WithFeatures()),
306308
},
307309
claims: &azureClaims{
308310
@@ -324,7 +326,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
324326
{
325327
name: "Fetch groups when ClaimsNames and ClaimsSources is set",
326328
fields: fields{
327-
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false),
329+
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures()),
328330
},
329331
claims: &azureClaims{
330332
ID: "1",
@@ -349,7 +351,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
349351
{
350352
name: "Fetch groups when forceUseGraphAPI is set",
351353
fields: fields{
352-
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false),
354+
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures()),
353355
forceUseGraphAPI: true,
354356
},
355357
claims: &azureClaims{
@@ -376,7 +378,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
376378
{
377379
name: "Fetch empty role when strict attribute role is true and no match",
378380
fields: fields{
379-
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false),
381+
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false, *featuremgmt.WithFeatures()),
380382
},
381383
claims: &azureClaims{
382384
@@ -392,7 +394,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
392394
{
393395
name: "Fetch empty role when strict attribute role is true and no role claims returned",
394396
fields: fields{
395-
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false),
397+
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false, *featuremgmt.WithFeatures()),
396398
},
397399
claims: &azureClaims{
398400
@@ -416,7 +418,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
416418
}
417419

418420
if tt.fields.SocialBase == nil {
419-
s.SocialBase = newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false)
421+
s.SocialBase = newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures())
420422
}
421423

422424
key := []byte("secret")

Diff for: pkg/login/social/generic_oauth.go

+9
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import (
1414
"strconv"
1515

1616
"golang.org/x/oauth2"
17+
18+
"github.com/grafana/grafana/pkg/services/featuremgmt"
1719
)
1820

1921
type SocialGenericOAuth struct {
@@ -504,3 +506,10 @@ func (s *SocialGenericOAuth) FetchOrganizations(client *http.Client) ([]string,
504506

505507
return logins, true
506508
}
509+
510+
func (s *SocialGenericOAuth) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
511+
if s.features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) {
512+
opts = append(opts, oauth2.AccessTypeOffline)
513+
}
514+
return s.SocialBase.AuthCodeURL(state, opts...)
515+
}

Diff for: pkg/login/social/github_oauth_test.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99

1010
"github.com/stretchr/testify/require"
1111
"golang.org/x/oauth2"
12+
13+
"github.com/grafana/grafana/pkg/services/featuremgmt"
1214
)
1315

1416
const testGHUserTeamsJSON = `[
@@ -202,7 +204,7 @@ func TestSocialGitHub_UserInfo(t *testing.T) {
202204

203205
s := &SocialGithub{
204206
SocialBase: newSocialBase("github", &oauth2.Config{},
205-
&OAuthInfo{RoleAttributePath: tt.roleAttributePath}, tt.autoAssignOrgRole, false),
207+
&OAuthInfo{RoleAttributePath: tt.roleAttributePath}, tt.autoAssignOrgRole, false, *featuremgmt.WithFeatures()),
206208
allowedOrganizations: []string{},
207209
apiUrl: server.URL + "/user",
208210
teamIds: []int{},

Diff for: pkg/login/social/google_oauth.go

+9
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"net/http"
77

88
"golang.org/x/oauth2"
9+
10+
"github.com/grafana/grafana/pkg/services/featuremgmt"
911
)
1012

1113
type SocialGoogle struct {
@@ -38,3 +40,10 @@ func (s *SocialGoogle) UserInfo(client *http.Client, token *oauth2.Token) (*Basi
3840
Login: data.Email,
3941
}, nil
4042
}
43+
44+
func (s *SocialGoogle) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
45+
if s.features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) {
46+
opts = append(opts, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
47+
}
48+
return s.SocialBase.AuthCodeURL(state, opts...)
49+
}

Diff for: pkg/login/social/social.go

+12-9
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"golang.org/x/text/language"
1717

1818
"github.com/grafana/grafana/pkg/infra/log"
19+
"github.com/grafana/grafana/pkg/services/featuremgmt"
1920
"github.com/grafana/grafana/pkg/services/org"
2021
"github.com/grafana/grafana/pkg/setting"
2122
"github.com/grafana/grafana/pkg/util"
@@ -58,7 +59,7 @@ type OAuthInfo struct {
5859
UsePKCE bool
5960
}
6061

61-
func ProvideService(cfg *setting.Cfg) *SocialService {
62+
func ProvideService(cfg *setting.Cfg, features *featuremgmt.FeatureManager) *SocialService {
6263
ss := SocialService{
6364
cfg: cfg,
6465
oAuthProvider: make(map[string]*OAuthInfo),
@@ -139,7 +140,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
139140
// GitHub.
140141
if name == "github" {
141142
ss.socialMap["github"] = &SocialGithub{
142-
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
143+
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
143144
apiUrl: info.ApiUrl,
144145
teamIds: sec.Key("team_ids").Ints(","),
145146
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
@@ -149,7 +150,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
149150
// GitLab.
150151
if name == "gitlab" {
151152
ss.socialMap["gitlab"] = &SocialGitlab{
152-
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
153+
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
153154
apiUrl: info.ApiUrl,
154155
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
155156
}
@@ -158,7 +159,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
158159
// Google.
159160
if name == "google" {
160161
ss.socialMap["google"] = &SocialGoogle{
161-
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
162+
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
162163
hostedDomain: info.HostedDomain,
163164
apiUrl: info.ApiUrl,
164165
}
@@ -167,7 +168,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
167168
// AzureAD.
168169
if name == "azuread" {
169170
ss.socialMap["azuread"] = &SocialAzureAD{
170-
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
171+
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
171172
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
172173
forceUseGraphAPI: sec.Key("force_use_graph_api").MustBool(false),
173174
}
@@ -176,7 +177,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
176177
// Okta
177178
if name == "okta" {
178179
ss.socialMap["okta"] = &SocialOkta{
179-
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
180+
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
180181
apiUrl: info.ApiUrl,
181182
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
182183
}
@@ -185,7 +186,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
185186
// Generic - Uses the same scheme as GitHub.
186187
if name == "generic_oauth" {
187188
ss.socialMap["generic_oauth"] = &SocialGenericOAuth{
188-
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
189+
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
189190
apiUrl: info.ApiUrl,
190191
teamsUrl: info.TeamsUrl,
191192
emailAttributeName: info.EmailAttributeName,
@@ -214,8 +215,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
214215
}
215216

216217
ss.socialMap[grafanaCom] = &SocialGrafanaCom{
217-
SocialBase: newSocialBase(name, &config, info,
218-
cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
218+
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
219219
url: cfg.GrafanaComURL,
220220
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
221221
}
@@ -261,6 +261,7 @@ type SocialBase struct {
261261
roleAttributeStrict bool
262262
autoAssignOrgRole string
263263
skipOrgRoleSync bool
264+
features featuremgmt.FeatureManager
264265
}
265266

266267
type Error struct {
@@ -295,6 +296,7 @@ func newSocialBase(name string,
295296
info *OAuthInfo,
296297
autoAssignOrgRole string,
297298
skipOrgRoleSync bool,
299+
features featuremgmt.FeatureManager,
298300
) *SocialBase {
299301
logger := log.New("oauth." + name)
300302

@@ -308,6 +310,7 @@ func newSocialBase(name string,
308310
roleAttributePath: info.RoleAttributePath,
309311
roleAttributeStrict: info.RoleAttributeStrict,
310312
skipOrgRoleSync: skipOrgRoleSync,
313+
features: features,
311314
}
312315
}
313316

Diff for: pkg/server/server.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ func (s *Server) init() error {
127127
}
128128

129129
login.ProvideService(s.HTTPServer.SQLStore, s.HTTPServer.Login, s.loginAttemptService, s.userService)
130-
social.ProvideService(s.cfg)
130+
social.ProvideService(s.cfg, s.HTTPServer.Features)
131131

132132
if err := s.roleRegistry.RegisterFixedRoles(s.context); err != nil {
133133
return err

Diff for: pkg/services/contexthandler/contexthandler.go

+15-8
Original file line numberDiff line numberDiff line change
@@ -449,20 +449,14 @@ func (h *ContextHandler) initContextWithToken(reqContext *models.ReqContext, org
449449
return false
450450
}
451451

452-
getTime := h.GetTime
453-
if getTime == nil {
454-
getTime = time.Now
455-
}
456-
457452
if h.features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) {
458453
// Check whether the logged in User has a token (whether the User used an OAuth provider to login)
459454
oauthToken, exists, _ := h.oauthTokenService.HasOAuthEntry(ctx, queryResult)
460455
if exists {
461-
// Skip where the OAuthExpiry is default/zero/unset
462-
if !oauthToken.OAuthExpiry.IsZero() && oauthToken.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta).Before(getTime()) {
456+
if h.hasAccessTokenExpired(oauthToken) {
463457
reqContext.Logger.Info("access token expired", "userId", query.UserID, "expiry", fmt.Sprintf("%v", oauthToken.OAuthExpiry))
464458

465-
// If the User doesn't have a refresh_token or refreshing the token was unsuccessful then log out the User and Invalidate the OAuth tokens
459+
// If the User doesn't have a refresh_token or refreshing the token was unsuccessful then log out the User and invalidate the OAuth tokens
466460
if err = h.oauthTokenService.TryTokenRefresh(ctx, oauthToken); err != nil {
467461
if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) {
468462
reqContext.Logger.Error("could not fetch a new access token", "userId", oauthToken.UserId, "error", err)
@@ -732,3 +726,16 @@ func AuthHTTPHeaderListFromContext(c context.Context) *AuthHTTPHeaderList {
732726
}
733727
return nil
734728
}
729+
730+
func (h *ContextHandler) hasAccessTokenExpired(token *models.UserAuth) bool {
731+
if token.OAuthExpiry.IsZero() {
732+
return false
733+
}
734+
735+
getTime := h.GetTime
736+
if getTime == nil {
737+
getTime = time.Now
738+
}
739+
740+
return token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta).Before(getTime())
741+
}

0 commit comments

Comments
 (0)