From 7cf8880eb1e002ce6fb90ffadae90b41452462e3 Mon Sep 17 00:00:00 2001 From: cmP Date: Mon, 7 Jan 2019 21:56:42 +0800 Subject: [PATCH] oauth2: add device flow support --- deviceauth.go | 79 +++++++++++++++++++++++++++++++++++++++++++++++++++ oauth2.go | 63 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 140 insertions(+), 2 deletions(-) create mode 100644 deviceauth.go diff --git a/deviceauth.go b/deviceauth.go new file mode 100644 index 000000000..edca0b6d5 --- /dev/null +++ b/deviceauth.go @@ -0,0 +1,79 @@ +package oauth2 + +import ( + "context" + "encoding/json" + "fmt" + "golang.org/x/net/context/ctxhttp" + "io" + "io/ioutil" + "net/http" + "net/url" + "strings" +) + +const ( + errAuthorizationPending = "authorization_pending" + errSlowDown = "slow_down" + errAccessDenied = "access_denied" + errExpiredToken = "expired_token" +) + +type DeviceAuth struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri,verification_url"` + VerificationURIComplete string `json:"verification_uri_complete,omitempty"` + ExpiresIn int `json:"expires_in,string"` + Interval int `json:"interval,string,omitempty"` + raw map[string]interface{} +} + +func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuth, error) { + req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + r, err := ctxhttp.Do(ctx, nil, req) + if err != nil { + return nil, err + } + + body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("oauth2: cannot auth device: %v", err) + } + if code := r.StatusCode; code < 200 || code > 299 { + return nil, &RetrieveError{ + Response: r, + Body: body, + } + } + + var da = &DeviceAuth{} + err = json.Unmarshal(body, &da) + if err != nil { + return nil, err + } + + _ = json.Unmarshal(body, &da.raw) + + // Azure AD supplies verification_url instead of verification_uri + if da.VerificationURI == "" { + da.VerificationURI, _ = da.raw["verification_url"].(string) + } + + return da, nil +} + +func parseError(err error) string { + e, ok := err.(*RetrieveError) + if ok { + eResp := make(map[string]string) + _ = json.Unmarshal(e.Body, &eResp) + return eResp["error"] + } + return "" +} diff --git a/oauth2.go b/oauth2.go index 1e8e1b741..14a9ba833 100644 --- a/oauth2.go +++ b/oauth2.go @@ -16,6 +16,7 @@ import ( "net/url" "strings" "sync" + "time" "golang.org/x/oauth2/internal" ) @@ -74,8 +75,9 @@ type TokenSource interface { // Endpoint contains the OAuth 2.0 provider's authorization and token // endpoint URLs. type Endpoint struct { - AuthURL string - TokenURL string + AuthURL string + DeviceAuthURL string + TokenURL string } var ( @@ -203,6 +205,63 @@ func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOpti return retrieveToken(ctx, c, v) } +// AuthDevice returns a device auth struct which contains a device code +// and authorization information provided for users to enter on another device. +func (c *Config) AuthDevice(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuth, error) { + v := url.Values{ + "client_id": {c.ClientID}, + } + if len(c.Scopes) > 0 { + v.Set("scope", strings.Join(c.Scopes, " ")) + } + for _, opt := range opts { + opt.setValue(v) + } + return retrieveDeviceAuth(ctx, c, v) +} + +// Poll does a polling to exchange an device code for a token. +func (c *Config) Poll(ctx context.Context, da *DeviceAuth, opts ...AuthCodeOption) (*Token, error) { + v := url.Values{ + "client_id": {c.ClientID}, + "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, + "device_code": {da.DeviceCode}, + "code": {da.DeviceCode}, + } + if len(c.Scopes) > 0 { + v.Set("scope", strings.Join(c.Scopes, " ")) + } + for _, opt := range opts { + opt.setValue(v) + } + + // If no interval was provided, the client MUST use a reasonable default polling interval. + // See https://tools.ietf.org/html/draft-ietf-oauth-device-flow-07#section-3.5 + interval := da.Interval + if interval == 0 { + interval = 5 + } + + for { + time.Sleep(time.Duration(interval) * time.Second) + + tok, err := retrieveToken(ctx, c, v) + if err == nil { + return tok, nil + } + + errTyp := parseError(err) + switch errTyp { + case errAccessDenied, errExpiredToken: + return tok, errors.New("oauth2: " + errTyp) + case errSlowDown: + interval += 5 + fallthrough + case errAuthorizationPending: + } + } +} + // Client returns an HTTP client using the provided token. // The token will auto-refresh as necessary. The underlying // HTTP transport will be obtained using the provided context.