Skip to content

Commit 8591084

Browse files
committed
PKCE Support
Adds Code Challenge PKCE support (RFC-7636) and partial Authorization Server Metadata (RFC-8414) for detecting PKCE support. - Introduces new option `--force-code-challenge-method` to force a specific code challenge method (either `S256` or `plain`) for instances when the server has not implemented RFC-8414 in order to detect PKCE support on the discovery document. - In all other cases, if the PKCE support can be determined during discovery then the `code_challenge_methods_supported` is used and S256 is always preferred. - The force command line argument is helpful with some providers like Azure who supports PKCE but does not list it in their discovery document yet. - Initial thought was given to just always attempt PKCE since according to spec additional URL parameters should be dropped by servers which implemented OAuth 2, however other projects found cases in the wild where this causes 500 errors by buggy implementations. See: spring-projects/spring-security#7804 (comment) - Due to the fact that the `code_verifier` must be saved between the redirect and callback, sessions are now created when the redirect takes place with `Authenticated: false`. The session will be recreated and marked as `Authenticated` on callback. - Individual provider implementations can choose to include or ignore code_challenge and code_verifier function parameters passed to them Note: Technically speaking `plain` is not required to be implemented since oauth2-proxy will always be able to handle S256 and servers MUST implement S256 support. > If the client is capable of using "S256", it MUST use "S256", as "S256" > is Mandatory To Implement (MTI) on the server. Clients are permitted > to use "plain" only if they cannot support "S256" for some technical > reason and know via out-of-band configuration that the server supports > "plain". Ref: RFC-7636 Sec 4.2 oauth2-proxy will always use S256 unless the user explicitly forces `plain`. Fixes oauth2-proxy#1361
1 parent b4997c6 commit 8591084

29 files changed

+375
-99
lines changed

.vscode/launch.json

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
// Use IntelliSense to learn about possible attributes.
3+
// Hover to view descriptions of existing attributes.
4+
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5+
"version": "0.2.0",
6+
"configurations": [
7+
{
8+
"name": "Launch Package",
9+
"type": "go",
10+
"request": "launch",
11+
"mode": "auto",
12+
"program": "${fileDirname}",
13+
"args": ["--config", "contrib/local-environment/oauth2-proxy.cfg"]
14+
}
15+
]
16+
}

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## Release Highlights
44

5+
- [#1361](https://github.com/oauth2-proxy/oauth2-proxy/issues/1361) PKCE Code Challenge Support - RFC-7636 (@braunsonm)
6+
- The `--force-code-challenge-method` flag can be used when the providers discovery document does not advertise PKCE support.
7+
- Parital support for OAuth2 Authorization Server Metadata for detecting code challenge methods (@braunsonm)
8+
59
## Important Notes
610

711
## Breaking Changes

docs/docs/configuration/overview.md

+1
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/
104104
| `--flush-interval` | duration | period between flushing response buffers when streaming responses | `"1s"` |
105105
| `--force-https` | bool | enforce https redirect | `false` |
106106
| `--force-json-errors` | bool | force JSON errors instead of HTTP error pages or redirects | `false` |
107+
| `--force-code-challenge-method` | will force PKCE code challenges with the specified method (if not automatically detected by the discovery document). Either 'plain' or 'S256' | |
107108
| `--banner` | string | custom (html) banner string. Use `"-"` to disable default banner. | |
108109
| `--footer` | string | custom (html) footer string. Use `"-"` to disable default footer. | |
109110
| `--github-org` | string | restrict logins to members of this organisation | |

oauthproxy.go

+87-4
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@ package main
22

33
import (
44
"context"
5+
"crypto/rand"
6+
"crypto/sha256"
7+
"encoding/base64"
58
"encoding/json"
69
"errors"
710
"fmt"
11+
"math/big"
812
"net"
913
"net/http"
1014
"net/url"
@@ -26,6 +30,7 @@ import (
2630
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
2731
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
2832
proxyhttp "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/http"
33+
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation"
2934

3035
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
3136
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
@@ -118,6 +123,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
118123
if err != nil {
119124
return nil, fmt.Errorf("error intiailising provider: %v", err)
120125
}
126+
provider.Data().CodeChallengeMethod = providers.ParseCodeChallengeMethod(opts)
121127

122128
pageWriter, err := pagewriter.NewWriter(pagewriter.Opts{
123129
TemplatesPath: opts.Templates.Path,
@@ -279,6 +285,7 @@ func (p *OAuthProxy) buildServeMux(proxyPrefix string) {
279285
r := mux.NewRouter().UseEncodedPath()
280286
// Everything served by the router must go through the preAuthChain first.
281287
r.Use(p.preAuthChain.Then)
288+
r.Use(p.sessionChain.Then)
282289

283290
// Register the robots path writer
284291
r.Path(robotsPath).HandlerFunc(p.pageWriter.WriteRobotsTxt)
@@ -686,11 +693,31 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
686693
return
687694
}
688695

696+
var codeChallenge, codeVerifier, codeChallengeMethod string
697+
if p.provider.Data().CodeChallengeMethod != "" {
698+
codeChallengeMethod = p.provider.Data().CodeChallengeMethod
699+
codeVerifier, err = generateRandomString(128)
700+
if err != nil {
701+
logger.Errorf("Unable to build random string: %v", err)
702+
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
703+
return
704+
}
705+
706+
codeChallenge, err = generateCodeChallenge(p.provider.Data().CodeChallengeMethod, codeVerifier)
707+
if err != nil {
708+
logger.Errorf("Error creating code challenge: %v", err)
709+
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
710+
return
711+
}
712+
}
713+
689714
callbackRedirect := p.getOAuthRedirectURI(req)
690715
loginURL := p.provider.GetLoginURL(
691716
callbackRedirect,
692717
encodeState(csrf.HashOAuthState(), appRedirect),
693718
csrf.HashOIDCNonce(),
719+
codeChallenge,
720+
codeChallengeMethod,
694721
)
695722

696723
if _, err := csrf.SetCookie(rw, req); err != nil {
@@ -699,6 +726,16 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
699726
return
700727
}
701728

729+
// A session is created with `Authenticated: false` to store the code verifier
730+
// for token redemption
731+
session := &sessionsapi.SessionState{CodeVerifier: codeVerifier}
732+
err = p.SaveSession(rw, req, session)
733+
if err != nil {
734+
logger.Errorf("Error saving session state for %s: %v", req.RemoteAddr, err)
735+
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
736+
return
737+
}
738+
702739
http.Redirect(rw, req, loginURL, http.StatusFound)
703740
}
704741

@@ -723,7 +760,20 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
723760
return
724761
}
725762

726-
session, err := p.redeemCode(req)
763+
session := middlewareapi.GetRequestScope(req).Session
764+
var codeVerifier string
765+
if session == nil {
766+
logger.Errorf("Error retrieving session containing code verifier")
767+
if p.provider.Data().CodeChallengeMethod != "" {
768+
// Only error the whole callback if code verifier is required
769+
p.ErrorPage(rw, req, http.StatusInternalServerError, "")
770+
return
771+
}
772+
} else {
773+
codeVerifier = session.CodeVerifier
774+
}
775+
776+
session, err = p.redeemCode(req, codeVerifier)
727777
if err != nil {
728778
logger.Errorf("Error redeeming code during OAuth2 callback: %v", err)
729779
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
@@ -790,14 +840,14 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
790840
}
791841
}
792842

793-
func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, error) {
843+
func (p *OAuthProxy) redeemCode(req *http.Request, codeVerifier string) (*sessionsapi.SessionState, error) {
794844
code := req.Form.Get("code")
795845
if code == "" {
796846
return nil, providers.ErrMissingCode
797847
}
798848

799849
redirectURI := p.getOAuthRedirectURI(req)
800-
s, err := p.provider.Redeem(req.Context(), redirectURI, code)
850+
s, err := p.provider.Redeem(req.Context(), redirectURI, code, codeVerifier)
801851
if err != nil {
802852
return nil, err
803853
}
@@ -809,10 +859,43 @@ func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, e
809859
if s.ExpiresOn == nil {
810860
s.ExpiresIn(p.CookieOptions.Expire)
811861
}
862+
s.Authenticated = true
812863

813864
return s, nil
814865
}
815866

867+
func generateCodeChallenge(method, codeVerifier string) (string, error) {
868+
switch method {
869+
case validation.CodeChallengeMethodPlain:
870+
return codeVerifier, nil
871+
case validation.CodeChallengeMethodS256:
872+
shaSum := sha256.Sum256([]byte(codeVerifier))
873+
return base64.RawURLEncoding.EncodeToString(shaSum[:]), nil
874+
default:
875+
return "", fmt.Errorf("unknown challenge method: %v", method)
876+
}
877+
}
878+
879+
const runes string = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.-_~"
880+
881+
// generateRandomString returns a securely generated random ASCII string.
882+
// It reads random numbers from crypto/rand and searches for printable characters.
883+
// It will return an error if the system's secure random number generator fails to
884+
// function correctly, in which case the caller must not continue.
885+
// From: https://gist.github.com/dopey/c69559607800d2f2f90b1b1ed4e550fb
886+
func generateRandomString(n int) (string, error) {
887+
ret := make([]byte, n)
888+
for i := 0; i < n; i++ {
889+
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(runes))))
890+
if err != nil {
891+
return "", err
892+
}
893+
ret[i] = runes[num.Int64()]
894+
}
895+
896+
return string(ret), nil
897+
}
898+
816899
func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error {
817900
var err error
818901
if s.Email == "" {
@@ -951,7 +1034,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
9511034
return session, nil
9521035
}
9531036

954-
if session == nil {
1037+
if session == nil || !session.Authenticated {
9551038
return nil, ErrNeedsLogin
9561039
}
9571040

oauthproxy_test.go

+43-34
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ func Test_redeemCode(t *testing.T) {
115115
}
116116

117117
req := httptest.NewRequest(http.MethodGet, "/", nil)
118-
_, err = proxy.redeemCode(req)
118+
_, err = proxy.redeemCode(req, "")
119119
assert.Equal(t, providers.ErrMissingCode, err)
120120
}
121121

@@ -242,7 +242,8 @@ func TestBasicAuthPassword(t *testing.T) {
242242
rw := httptest.NewRecorder()
243243
req, _ := http.NewRequest("GET", "/", nil)
244244
err = proxy.sessionStore.Save(rw, req, &sessions.SessionState{
245-
Email: emailAddress,
245+
Email: emailAddress,
246+
Authenticated: true,
246247
})
247248
assert.NoError(t, err)
248249

@@ -285,11 +286,12 @@ func TestPassGroupsHeadersWithGroups(t *testing.T) {
285286
groups := []string{"a", "b"}
286287
created := time.Now()
287288
session := &sessions.SessionState{
288-
User: userName,
289-
Groups: groups,
290-
Email: emailAddress,
291-
AccessToken: "oauth_token",
292-
CreatedAt: &created,
289+
User: userName,
290+
Groups: groups,
291+
Email: emailAddress,
292+
AccessToken: "oauth_token",
293+
CreatedAt: &created,
294+
Authenticated: true,
293295
}
294296

295297
proxy, err := NewOAuthProxy(opts, func(email string) bool {
@@ -930,28 +932,31 @@ func TestUserInfoEndpointAccepted(t *testing.T) {
930932
{
931933
name: "Full session",
932934
session: &sessions.SessionState{
933-
User: "john.doe",
934-
935-
Groups: []string{"example", "groups"},
936-
AccessToken: "my_access_token",
935+
User: "john.doe",
936+
937+
Groups: []string{"example", "groups"},
938+
AccessToken: "my_access_token",
939+
Authenticated: true,
937940
},
938941
expectedResponse: "{\"user\":\"john.doe\",\"email\":\"[email protected]\",\"groups\":[\"example\",\"groups\"]}\n",
939942
},
940943
{
941944
name: "Minimal session",
942945
session: &sessions.SessionState{
943-
User: "john.doe",
944-
945-
Groups: []string{"example", "groups"},
946+
User: "john.doe",
947+
948+
Groups: []string{"example", "groups"},
949+
Authenticated: true,
946950
},
947951
expectedResponse: "{\"user\":\"john.doe\",\"email\":\"[email protected]\",\"groups\":[\"example\",\"groups\"]}\n",
948952
},
949953
{
950954
name: "No groups",
951955
session: &sessions.SessionState{
952-
User: "john.doe",
953-
954-
AccessToken: "my_access_token",
956+
User: "john.doe",
957+
958+
AccessToken: "my_access_token",
959+
Authenticated: true,
955960
},
956961
expectedResponse: "{\"user\":\"john.doe\",\"email\":\"[email protected]\"}\n",
957962
},
@@ -963,6 +968,7 @@ func TestUserInfoEndpointAccepted(t *testing.T) {
963968
964969
Groups: []string{"example", "groups"},
965970
AccessToken: "my_access_token",
971+
Authenticated: true,
966972
},
967973
expectedResponse: "{\"user\":\"john.doe\",\"email\":\"[email protected]\",\"groups\":[\"example\",\"groups\"],\"preferredUsername\":\"john\"}\n",
968974
},
@@ -1024,7 +1030,7 @@ func TestAuthOnlyEndpointAccepted(t *testing.T) {
10241030

10251031
created := time.Now()
10261032
startSession := &sessions.SessionState{
1027-
Email: "[email protected]", AccessToken: "my_access_token", CreatedAt: &created}
1033+
Email: "[email protected]", AccessToken: "my_access_token", CreatedAt: &created, Authenticated: true}
10281034
err = test.SaveSession(startSession)
10291035
assert.NoError(t, err)
10301036

@@ -1154,7 +1160,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
11541160

11551161
created := time.Now()
11561162
startSession := &sessions.SessionState{
1157-
User: "oauth_user", Groups: []string{"oauth_groups"}, Email: "[email protected]", AccessToken: "oauth_token", CreatedAt: &created}
1163+
User: "oauth_user", Groups: []string{"oauth_groups"}, Email: "[email protected]", AccessToken: "oauth_token", CreatedAt: &created, Authenticated: true}
11581164
err = pcTest.SaveSession(startSession)
11591165
assert.NoError(t, err)
11601166

@@ -1247,7 +1253,7 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
12471253

12481254
created := time.Now()
12491255
startSession := &sessions.SessionState{
1250-
User: "oauth_user", Email: "[email protected]", AccessToken: "oauth_token", CreatedAt: &created}
1256+
User: "oauth_user", Email: "[email protected]", AccessToken: "oauth_token", CreatedAt: &created, Authenticated: true}
12511257
err = pcTest.SaveSession(startSession)
12521258
assert.NoError(t, err)
12531259

@@ -1327,7 +1333,7 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
13271333

13281334
created := time.Now()
13291335
startSession := &sessions.SessionState{
1330-
User: "oauth_user", Email: "[email protected]", AccessToken: "oauth_token", CreatedAt: &created}
1336+
User: "oauth_user", Email: "[email protected]", AccessToken: "oauth_token", CreatedAt: &created, Authenticated: true}
13311337
err = pcTest.SaveSession(startSession)
13321338
assert.NoError(t, err)
13331339

@@ -1502,7 +1508,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) er
15021508
req.Header = st.header
15031509

15041510
state := &sessions.SessionState{
1505-
Email: "[email protected]", AccessToken: "my_access_token"}
1511+
Email: "[email protected]", AccessToken: "my_access_token", Authenticated: true}
15061512
err = proxy.SaveSession(st.rw, req, state)
15071513
if err != nil {
15081514
return err
@@ -2451,10 +2457,11 @@ func TestProxyAllowedGroups(t *testing.T) {
24512457
created := time.Now()
24522458

24532459
session := &sessions.SessionState{
2454-
Groups: tt.groups,
2455-
Email: emailAddress,
2456-
AccessToken: "oauth_token",
2457-
CreatedAt: &created,
2460+
Groups: tt.groups,
2461+
Email: emailAddress,
2462+
AccessToken: "oauth_token",
2463+
CreatedAt: &created,
2464+
Authenticated: true,
24582465
}
24592466

24602467
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -2587,10 +2594,11 @@ func TestAuthOnlyAllowedGroups(t *testing.T) {
25872594
created := time.Now()
25882595

25892596
session := &sessions.SessionState{
2590-
Groups: tc.groups,
2591-
Email: emailAddress,
2592-
AccessToken: "oauth_token",
2593-
CreatedAt: &created,
2597+
Groups: tc.groups,
2598+
Email: emailAddress,
2599+
AccessToken: "oauth_token",
2600+
CreatedAt: &created,
2601+
Authenticated: true,
25942602
}
25952603

25962604
test, err := NewAuthOnlyEndpointTest(tc.querystring, func(opts *options.Options) {
@@ -2683,10 +2691,11 @@ func TestAuthOnlyAllowedGroupsWithSkipMethods(t *testing.T) {
26832691
if tc.withSession {
26842692
created := time.Now()
26852693
session := &sessions.SessionState{
2686-
Groups: tc.groups,
2687-
Email: "test",
2688-
AccessToken: "oauth_token",
2689-
CreatedAt: &created,
2694+
Groups: tc.groups,
2695+
Email: "test",
2696+
AccessToken: "oauth_token",
2697+
CreatedAt: &created,
2698+
Authenticated: true,
26902699
}
26912700
err = test.SaveSession(session)
26922701
}

0 commit comments

Comments
 (0)