Skip to content

Commit ccd42c8

Browse files
committed
PKCE Support
Adds Code Challenge PKCE support (RFC-7636) and partial Authorization Server Metadata (RFC-8414) for detecting PKCE support. - Introduces new option `--force-code-challenge-method` to force a specific code challenge method (either `S256` or `plain`) for instances when the server has not implemented RFC-8414 in order to detect PKCE support on the discovery document. - In all other cases, if the PKCE support can be determined during discovery then the `code_challenge_methods_supported` is used and S256 is always preferred. - The force command line argument is helpful with some providers like Azure who supports PKCE but does not list it in their discovery document yet. - Initial thought was given to just always attempt PKCE since according to spec additional URL parameters should be dropped by servers which implemented OAuth 2, however other projects found cases in the wild where this causes 500 errors by buggy implementations. See: spring-projects/spring-security#7804 (comment) - Due to the fact that the `code_verifier` must be saved between the redirect and callback, sessions are now created when the redirect takes place with `Authenticated: false`. The session will be recreated and marked as `Authenticated` on callback. - Individual provider implementations can choose to include or ignore code_challenge and code_verifier function parameters passed to them Note: Technically speaking `plain` is not required to be implemented since oauth2-proxy will always be able to handle S256 and servers MUST implement S256 support. > If the client is capable of using "S256", it MUST use "S256", as "S256" > is Mandatory To Implement (MTI) on the server. Clients are permitted > to use "plain" only if they cannot support "S256" for some technical > reason and know via out-of-band configuration that the server supports > "plain". Ref: RFC-7636 Sec 4.2 oauth2-proxy will always use S256 unless the user explicitly forces `plain`. Fixes oauth2-proxy#1361
1 parent 88709d8 commit ccd42c8

24 files changed

+283
-47
lines changed

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## Release Highlights
44

5+
- [#1361](https://github.com/oauth2-proxy/oauth2-proxy/issues/1361) PKCE Code Challenge Support - RFC-7636 (@braunsonm)
6+
- The `--force-code-challenge-method` flag can be used when the providers discovery document does not advertise PKCE support.
7+
- Parital support for OAuth2 Authorization Server Metadata for detecting code challenge methods (@braunsonm)
8+
59
## Important Notes
610

711
## Breaking Changes

docs/docs/configuration/overview.md

+1
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/
104104
| `--flush-interval` | duration | period between flushing response buffers when streaming responses | `"1s"` |
105105
| `--force-https` | bool | enforce https redirect | `false` |
106106
| `--force-json-errors` | force JSON errors instead of HTTP error pages or redirects | `false` |
107+
| `--force-code-challenge-method` | will force PKCE code challenges with the specified method (if not automatically detected by the discovery document). Either 'plain' or 'S256' | |
107108
| `--banner` | string | custom (html) banner string. Use `"-"` to disable default banner. | |
108109
| `--footer` | string | custom (html) footer string. Use `"-"` to disable default footer. | |
109110
| `--github-org` | string | restrict logins to members of this organisation | |

oauthproxy.go

+82-4
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@ package main
22

33
import (
44
"context"
5+
"crypto/rand"
6+
"crypto/sha256"
7+
"encoding/base64"
58
"encoding/json"
69
"errors"
710
"fmt"
11+
"math/big"
812
"net"
913
"net/http"
1014
"net/url"
@@ -26,6 +30,7 @@ import (
2630
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
2731
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
2832
proxyhttp "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/http"
33+
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation"
2934

3035
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
3136
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
@@ -274,6 +279,7 @@ func (p *OAuthProxy) buildServeMux(proxyPrefix string) {
274279
r := mux.NewRouter().UseEncodedPath()
275280
// Everything served by the router must go through the preAuthChain first.
276281
r.Use(p.preAuthChain.Then)
282+
r.Use(p.sessionChain.Then)
277283

278284
// Register the robots path writer
279285
r.Path(robotsPath).HandlerFunc(p.pageWriter.WriteRobotsTxt)
@@ -681,11 +687,31 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
681687
return
682688
}
683689

690+
var codeChallenge, codeVerifier, codeChallengeMethod string
691+
if p.provider.Data().CodeChallengeMethod != "" {
692+
codeChallengeMethod = p.provider.Data().CodeChallengeMethod
693+
codeVerifier, err = generateRandomString(128)
694+
if err != nil {
695+
logger.Errorf("Unable to build random string: %v", err)
696+
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
697+
return
698+
}
699+
700+
codeChallenge, err = generateCodeChallenge(p.provider.Data().CodeChallengeMethod, codeVerifier)
701+
if err != nil {
702+
logger.Errorf("Error creating code challenge: %v", err)
703+
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
704+
return
705+
}
706+
}
707+
684708
callbackRedirect := p.getOAuthRedirectURI(req)
685709
loginURL := p.provider.GetLoginURL(
686710
callbackRedirect,
687711
encodeState(csrf.HashOAuthState(), appRedirect),
688712
csrf.HashOIDCNonce(),
713+
codeChallenge,
714+
codeChallengeMethod,
689715
)
690716

691717
if _, err := csrf.SetCookie(rw, req); err != nil {
@@ -694,6 +720,16 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
694720
return
695721
}
696722

723+
// A session is created with `Authenticated: false` to store the code verifier
724+
// for token redemption
725+
session := &sessionsapi.SessionState{CodeVerifier: codeVerifier}
726+
err = p.SaveSession(rw, req, session)
727+
if err != nil {
728+
logger.Errorf("Error saving session state for %s: %v", req.RemoteAddr, err)
729+
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
730+
return
731+
}
732+
697733
http.Redirect(rw, req, loginURL, http.StatusFound)
698734
}
699735

@@ -717,8 +753,17 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
717753
p.ErrorPage(rw, req, http.StatusForbidden, message, message)
718754
return
719755
}
756+
session := middlewareapi.GetRequestScope(req).Session
757+
if session == nil {
758+
logger.Errorf("Error retrieving session containing code verifier: %v", "")
759+
if p.provider.Data().CodeChallengeMethod != "" {
760+
// Only error the whole callback if code verifier is required
761+
p.ErrorPage(rw, req, http.StatusInternalServerError, "")
762+
return
763+
}
764+
}
720765

721-
session, err := p.redeemCode(req)
766+
session, err = p.redeemCode(req, session.CodeVerifier)
722767
if err != nil {
723768
logger.Errorf("Error redeeming code during OAuth2 callback: %v", err)
724769
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
@@ -785,14 +830,14 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
785830
}
786831
}
787832

788-
func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, error) {
833+
func (p *OAuthProxy) redeemCode(req *http.Request, codeVerifier string) (*sessionsapi.SessionState, error) {
789834
code := req.Form.Get("code")
790835
if code == "" {
791836
return nil, providers.ErrMissingCode
792837
}
793838

794839
redirectURI := p.getOAuthRedirectURI(req)
795-
s, err := p.provider.Redeem(req.Context(), redirectURI, code)
840+
s, err := p.provider.Redeem(req.Context(), redirectURI, code, codeVerifier)
796841
if err != nil {
797842
return nil, err
798843
}
@@ -804,10 +849,43 @@ func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, e
804849
if s.ExpiresOn == nil {
805850
s.ExpiresIn(p.CookieOptions.Expire)
806851
}
852+
s.Authenticated = true
807853

808854
return s, nil
809855
}
810856

857+
func generateCodeChallenge(method, codeVerifier string) (string, error) {
858+
switch method {
859+
case validation.CodeChallengeMethodPlain:
860+
return codeVerifier, nil
861+
case validation.CodeChallengeMethodS256:
862+
shaSum := sha256.Sum256([]byte(codeVerifier))
863+
return base64.RawURLEncoding.EncodeToString(shaSum[:]), nil
864+
default:
865+
return "", fmt.Errorf("unknown challenge method: %v", method)
866+
}
867+
}
868+
869+
const runes string = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.-_~"
870+
871+
// generateRandomString returns a securely generated random ASCII string.
872+
// It reads random numbers from crypto/rand and searches for printable characters.
873+
// It will return an error if the system's secure random number generator fails to
874+
// function correctly, in which case the caller must not continue.
875+
// From: https://gist.github.com/dopey/c69559607800d2f2f90b1b1ed4e550fb
876+
func generateRandomString(n int) (string, error) {
877+
ret := make([]byte, n)
878+
for i := 0; i < n; i++ {
879+
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(runes))))
880+
if err != nil {
881+
return "", err
882+
}
883+
ret[i] = runes[num.Int64()]
884+
}
885+
886+
return string(ret), nil
887+
}
888+
811889
func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error {
812890
var err error
813891
if s.Email == "" {
@@ -946,7 +1024,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
9461024
return session, nil
9471025
}
9481026

949-
if session == nil {
1027+
if session == nil || !session.Authenticated {
9501028
return nil, ErrNeedsLogin
9511029
}
9521030

oauthproxy_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ func Test_redeemCode(t *testing.T) {
114114
}
115115

116116
req := httptest.NewRequest(http.MethodGet, "/", nil)
117-
_, err = proxy.redeemCode(req)
117+
_, err = proxy.redeemCode(req, "")
118118
assert.Equal(t, providers.ErrMissingCode, err)
119119
}
120120

pkg/apis/options/options.go

+3
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ type Options struct {
5959
SSLInsecureSkipVerify bool `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify"`
6060
SkipAuthPreflight bool `flag:"skip-auth-preflight" cfg:"skip_auth_preflight"`
6161
ForceJSONErrors bool `flag:"force-json-errors" cfg:"force_json_errors"`
62+
// Force PKCE Code Challenges if method is not detected via RFC-8414 metadata document
63+
ForceCodeChallengeMethod string `flag:"force-code-challenge-method" cfg:"force_code_challenge_method"`
6264

6365
SignatureKey string `flag:"signature-key" cfg:"signature_key"`
6466
GCPHealthChecks bool `flag:"gcp-healthchecks" cfg:"gcp_healthchecks"`
@@ -123,6 +125,7 @@ func NewFlagSet() *pflag.FlagSet {
123125
flagSet.Bool("ssl-insecure-skip-verify", false, "skip validation of certificates presented when using HTTPS providers")
124126
flagSet.Bool("skip-jwt-bearer-tokens", false, "will skip requests that have verified JWT bearer tokens (default false)")
125127
flagSet.Bool("force-json-errors", false, "will force JSON errors instead of HTTP error pages or redirects")
128+
flagSet.String("force-code-challenge-method", "", "will force PKCE code challenges with the specified method. Either 'plain' or 'S256'")
126129
flagSet.StringSlice("extra-jwt-issuers", []string{}, "if skip-jwt-bearer-tokens is set, a list of extra JWT issuer=audience pairs (where the issuer URL has a .well-known/openid-configuration or a .well-known/jwks.json)")
127130

128131
flagSet.StringSlice("email-domain", []string{}, "authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email")

pkg/apis/options/providers.go

+2
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ type Provider struct {
7070
ApprovalPrompt string `json:"approvalPrompt,omitempty"`
7171
// AllowedGroups is a list of restrict logins to members of this group
7272
AllowedGroups []string `json:"allowedGroups,omitempty"`
73+
// Code challenge methods supported by the Provider
74+
CodeChallengeMethods []string `json:"code_challenge_methods_supported,omitempty"`
7375

7476
// AcrValues is a string of acr values
7577
AcrValues string `json:"acrValues,omitempty"`

pkg/apis/sessions/session_state.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ import (
1616

1717
// SessionState is used to store information about the currently authenticated user session
1818
type SessionState struct {
19-
CreatedAt *time.Time `msgpack:"ca,omitempty"`
20-
ExpiresOn *time.Time `msgpack:"eo,omitempty"`
19+
Authenticated bool `msgpack:"au,omitempty"`
20+
CreatedAt *time.Time `msgpack:"ca,omitempty"`
21+
ExpiresOn *time.Time `msgpack:"eo,omitempty"`
2122

2223
AccessToken string `msgpack:"at,omitempty"`
2324
IDToken string `msgpack:"it,omitempty"`
@@ -29,6 +30,7 @@ type SessionState struct {
2930
User string `msgpack:"u,omitempty"`
3031
Groups []string `msgpack:"g,omitempty"`
3132
PreferredUsername string `msgpack:"pu,omitempty"`
33+
CodeVerifier string `msgpack:"cv,omitempty"`
3234

3335
// Internal helpers, not serialized
3436
Clock clock.Clock `msgpack:"-"`

pkg/validation/options.go

+38
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ import (
2121
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
2222
)
2323

24+
const (
25+
CodeChallengeMethodPlain = "plain"
26+
CodeChallengeMethodS256 = "S256"
27+
)
28+
2429
// Validate checks that required options are set and validates those that they
2530
// are of the correct format
2631
func Validate(o *options.Options) error {
@@ -132,6 +137,16 @@ func Validate(o *options.Options) error {
132137
SkipIssuerCheck: o.Providers[0].OIDCConfig.InsecureSkipIssuerVerification,
133138
}))
134139

140+
type DiscoveryClaims struct {
141+
// RFC-8414 Authorization Server Metadata
142+
CodeChallengeMethods []string `json:"code_challenge_methods_supported"`
143+
}
144+
var claims DiscoveryClaims
145+
if err := provider.Claims(&claims); err != nil {
146+
logger.Errorf("error: failed to parse additional OIDC discovery claims: %v, PKCE must be force enabled to be used.", err)
147+
}
148+
149+
o.Providers[0].CodeChallengeMethods = claims.CodeChallengeMethods
135150
o.Providers[0].LoginURL = provider.Endpoint().AuthURL
136151
o.Providers[0].RedeemURL = provider.Endpoint().TokenURL
137152
}
@@ -210,6 +225,7 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
210225
p.ProfileURL, msgs = parseURL(o.Providers[0].ProfileURL, "profile", msgs)
211226
p.ValidateURL, msgs = parseURL(o.Providers[0].ValidateURL, "validate", msgs)
212227
p.ProtectedResource, msgs = parseURL(o.Providers[0].ProtectedResource, "resource", msgs)
228+
p.CodeChallengeMethod = parseCodeChallengeMethod(o)
213229

214230
// Make the OIDC options available to all providers that support it
215231
p.AllowUnverifiedEmail = o.Providers[0].OIDCConfig.InsecureAllowUnverifiedEmail
@@ -332,6 +348,28 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
332348
return msgs
333349
}
334350

351+
func stringInSlice(element string, list []string) bool {
352+
for _, x := range list {
353+
if x == element {
354+
return true
355+
}
356+
}
357+
return false
358+
}
359+
360+
// Pick the most appropriate code challenge method for PKCE
361+
func parseCodeChallengeMethod(o *options.Options) string {
362+
switch {
363+
case o.ForceCodeChallengeMethod != "":
364+
return o.ForceCodeChallengeMethod
365+
case o.Providers[0].CodeChallengeMethods == nil:
366+
return ""
367+
case stringInSlice(CodeChallengeMethodS256, o.Providers[0].CodeChallengeMethods):
368+
return CodeChallengeMethodS256
369+
}
370+
return CodeChallengeMethodPlain
371+
}
372+
335373
func parseSignatureKey(o *options.Options, msgs []string) []string {
336374
if o.SignatureKey == "" {
337375
return msgs

pkg/validation/options_test.go

+33
Original file line numberDiff line numberDiff line change
@@ -302,3 +302,36 @@ func TestProviderCAFilesError(t *testing.T) {
302302
assert.Error(t, err)
303303
assert.Contains(t, err.Error(), "unable to load provider CA file(s)")
304304
}
305+
306+
func TestForcedMethodS256(t *testing.T) {
307+
options := testOptions()
308+
options.ForceCodeChallengeMethod = CodeChallengeMethodS256
309+
method := parseCodeChallengeMethod(options)
310+
311+
assert.Equal(t, CodeChallengeMethodS256, method)
312+
}
313+
314+
func TestForcedMethodPlain(t *testing.T) {
315+
options := testOptions()
316+
options.ForceCodeChallengeMethod = CodeChallengeMethodPlain
317+
method := parseCodeChallengeMethod(options)
318+
319+
assert.Equal(t, CodeChallengeMethodPlain, method)
320+
}
321+
322+
func TestPrefersS256(t *testing.T) {
323+
options := testOptions()
324+
options.Providers[0].CodeChallengeMethods = []string{"plain", "S256"}
325+
method := parseCodeChallengeMethod(options)
326+
327+
assert.Equal(t, CodeChallengeMethodS256, method)
328+
}
329+
330+
func TestCanOverwriteS256(t *testing.T) {
331+
options := testOptions()
332+
options.Providers[0].CodeChallengeMethods = []string{"plain", "S256"}
333+
options.ForceCodeChallengeMethod = "plain"
334+
method := parseCodeChallengeMethod(options)
335+
336+
assert.Equal(t, CodeChallengeMethodPlain, method)
337+
}

providers/adfs.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,15 @@ func (p *ADFSProvider) Configure(skipScope bool) {
6666

6767
// GetLoginURL Override to double encode the state parameter. If not query params are lost
6868
// More info here: https://docs.microsoft.com/en-us/powerapps/maker/portals/configure/configure-saml2-settings
69-
func (p *ADFSProvider) GetLoginURL(redirectURI, state, nonce string) string {
69+
func (p *ADFSProvider) GetLoginURL(redirectURI, state, nonce, codeChallenge, codeChallengeMethod string) string {
7070
extraParams := url.Values{}
7171
if !p.SkipNonce {
7272
extraParams.Add("nonce", nonce)
7373
}
74+
if codeChallenge != "" && codeChallengeMethod != "" {
75+
extraParams.Add("code_challenge", codeChallenge)
76+
extraParams.Add("code_challenge_method", codeChallengeMethod)
77+
}
7478
loginURL := makeLoginURL(p.Data(), redirectURI, url.QueryEscape(state), extraParams)
7579
if p.skipScope {
7680
q := loginURL.Query()

providers/adfs_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ var _ = Describe("ADFS Provider Tests", func() {
164164
})
165165
p.skipScope = true
166166

167-
result := p.GetLoginURL("https://example.com/adfs/oauth2/", "", "")
167+
result := p.GetLoginURL("https://example.com/adfs/oauth2/", "", "", "", "")
168168
Expect(result).NotTo(ContainSubstring("scope="))
169169
})
170170
})
@@ -185,7 +185,7 @@ var _ = Describe("ADFS Provider Tests", func() {
185185
})
186186

187187
Expect(p.Data().Scope).To(Equal(in.expectedScope))
188-
result := p.GetLoginURL("https://example.com/adfs/oauth2/", "", "")
188+
result := p.GetLoginURL("https://example.com/adfs/oauth2/", "", "", "", "")
189189
Expect(result).To(ContainSubstring("scope=" + url.QueryEscape(in.expectedScope)))
190190
},
191191
Entry("should add slash", scopeTableInput{

0 commit comments

Comments
 (0)