Skip to content

Commit c0aa320

Browse files
committed
device flow
1 parent cbc7e73 commit c0aa320

File tree

5 files changed

+244
-9
lines changed

5 files changed

+244
-9
lines changed

deviceauth.go

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
package oauth2
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"net/url"
10+
"strings"
11+
"time"
12+
13+
"golang.org/x/oauth2/internal"
14+
)
15+
16+
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
17+
const (
18+
errAuthorizationPending = "authorization_pending"
19+
errSlowDown = "slow_down"
20+
errAccessDenied = "access_denied"
21+
errExpiredToken = "expired_token"
22+
)
23+
24+
// DeviceAuthResponse https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
25+
type DeviceAuthResponse struct {
26+
DeviceCode string `json:"device_code"`
27+
UserCode string `json:"user_code"`
28+
VerificationURI string `json:"verification_uri"`
29+
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
30+
Expiry time.Time `json:"expires_in,omitempty"`
31+
Interval int64 `json:"interval,omitempty"`
32+
}
33+
34+
func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
35+
type Alias DeviceAuthResponse
36+
var expiresIn int64
37+
if !d.Expiry.IsZero() {
38+
expiresIn = int64(time.Until(d.Expiry).Seconds())
39+
}
40+
return json.Marshal(&struct {
41+
ExpiresIn int64 `json:"expires_in,omitempty"`
42+
*Alias
43+
}{
44+
ExpiresIn: expiresIn,
45+
Alias: (*Alias)(&d),
46+
})
47+
48+
}
49+
50+
func (c *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
51+
type Alias DeviceAuthResponse
52+
aux := &struct {
53+
ExpiresIn int64 `json:"expires_in"`
54+
*Alias
55+
}{
56+
Alias: (*Alias)(c),
57+
}
58+
if err := json.Unmarshal(data, &aux); err != nil {
59+
return err
60+
}
61+
if aux.ExpiresIn != 0 {
62+
c.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn))
63+
}
64+
return nil
65+
}
66+
67+
// DeviceAuth returns a device auth struct which contains a device code
68+
// and authorization information provided for users to enter on another device.
69+
func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
70+
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
71+
v := url.Values{
72+
"client_id": {c.ClientID},
73+
}
74+
if len(c.Scopes) > 0 {
75+
v.Set("scope", strings.Join(c.Scopes, " "))
76+
}
77+
for _, opt := range opts {
78+
opt.setValue(v)
79+
}
80+
return retrieveDeviceAuth(ctx, c, v)
81+
}
82+
83+
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
84+
req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
85+
if err != nil {
86+
return nil, err
87+
}
88+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
89+
req.Header.Set("Accept", "application/json")
90+
91+
r, err := internal.ContextClient(ctx).Do(req)
92+
if err != nil {
93+
return nil, err
94+
}
95+
96+
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
97+
if err != nil {
98+
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
99+
}
100+
if code := r.StatusCode; code < 200 || code > 299 {
101+
return nil, &RetrieveError{
102+
Response: r,
103+
Body: body,
104+
}
105+
}
106+
107+
da := &DeviceAuthResponse{}
108+
err = json.Unmarshal(body, &da)
109+
if err != nil {
110+
return nil, fmt.Errorf("unmarshal %s", err)
111+
}
112+
return da, nil
113+
}
114+
115+
// Poll tries to exchange an device code for a token.
116+
func (c *Config) Poll(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
117+
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
118+
v := url.Values{
119+
"client_id": {c.ClientID},
120+
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
121+
"device_code": {da.DeviceCode},
122+
}
123+
if len(c.Scopes) > 0 {
124+
v.Set("scope", strings.Join(c.Scopes, " "))
125+
}
126+
for _, opt := range opts {
127+
opt.setValue(v)
128+
}
129+
130+
// If no interval was provided, the client MUST use a reasonable default polling interval.
131+
// See https://tools.ietf.org/html/draft-ietf-oauth-device-flow-07#section-3.5
132+
interval := da.Interval
133+
if interval == 0 {
134+
interval = 5
135+
}
136+
137+
for {
138+
time.Sleep(time.Duration(interval) * time.Second)
139+
140+
tok, err := retrieveToken(ctx, c, v)
141+
if err == nil {
142+
return tok, nil
143+
}
144+
145+
e, ok := err.(*RetrieveError)
146+
if !ok {
147+
return nil, err
148+
}
149+
switch e.ErrorCode {
150+
case errSlowDown:
151+
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
152+
// "the interval MUST be increased by 5 seconds for this and all subsequent requests"
153+
interval += 5
154+
case errAuthorizationPending:
155+
// Do nothing.
156+
case errAccessDenied, errExpiredToken:
157+
fallthrough
158+
default:
159+
return tok, err
160+
}
161+
}
162+
}

deviceauth_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package oauth2
2+
3+
import (
4+
"encoding/json"
5+
"testing"
6+
"time"
7+
8+
"github.com/google/go-cmp/cmp"
9+
"github.com/google/go-cmp/cmp/cmpopts"
10+
)
11+
12+
func TestDeviceAuthResponseMarshalJson(t *testing.T) {
13+
tests := []struct {
14+
name string
15+
response DeviceAuthResponse
16+
want string
17+
}{
18+
{
19+
name: "empty",
20+
response: DeviceAuthResponse{},
21+
want: `{"device_code":"","user_code":"","verification_uri":""}`,
22+
},
23+
{
24+
name: "soon",
25+
response: DeviceAuthResponse{
26+
Expiry: time.Now().Add(100*time.Second + 999*time.Millisecond),
27+
},
28+
want: `{"expires_in":100,"device_code":"","user_code":"","verification_uri":""}`,
29+
},
30+
}
31+
for _, tc := range tests {
32+
t.Run(tc.name, func(t *testing.T) {
33+
gotBytes, err := json.Marshal(tc.response)
34+
if err != nil {
35+
t.Fatal(err)
36+
}
37+
got := string(gotBytes)
38+
if got != tc.want {
39+
t.Errorf("want=%s, got=%s", tc.want, got)
40+
}
41+
})
42+
}
43+
}
44+
45+
func TestDeviceAuthResponseUnmarshalJson(t *testing.T) {
46+
tests := []struct {
47+
name string
48+
data string
49+
want DeviceAuthResponse
50+
}{
51+
{
52+
name: "empty",
53+
data: `{}`,
54+
want: DeviceAuthResponse{},
55+
},
56+
{
57+
name: "soon",
58+
data: `{"expires_in":100}`,
59+
want: DeviceAuthResponse{Expiry: time.Now().UTC().Add(100 * time.Second)},
60+
},
61+
}
62+
for _, tc := range tests {
63+
t.Run(tc.name, func(t *testing.T) {
64+
got := DeviceAuthResponse{}
65+
err := json.Unmarshal([]byte(tc.data), &got)
66+
if err != nil {
67+
t.Fatal(err)
68+
}
69+
if !cmp.Equal(got, tc.want, cmpopts.IgnoreUnexported(DeviceAuthResponse{}), cmpopts.EquateApproxTime(time.Second)) {
70+
t.Errorf("want=%#v, got=%#v", tc.want, got)
71+
}
72+
})
73+
}
74+
}

endpoints/endpoints.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,9 @@ var Fitbit = oauth2.Endpoint{
5555

5656
// GitHub is the endpoint for Github.
5757
var GitHub = oauth2.Endpoint{
58-
AuthURL: "https://github.com/login/oauth/authorize",
59-
TokenURL: "https://github.com/login/oauth/access_token",
58+
AuthURL: "https://github.com/login/oauth/authorize",
59+
TokenURL: "https://github.com/login/oauth/access_token",
60+
DeviceAuthURL: "https://github.com/login/device/code",
6061
}
6162

6263
// GitLab is the endpoint for GitLab.

github/github.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,8 @@
66
package github // import "golang.org/x/oauth2/github"
77

88
import (
9-
"golang.org/x/oauth2"
9+
"golang.org/x/oauth2/endpoints"
1010
)
1111

1212
// Endpoint is Github's OAuth 2.0 endpoint.
13-
var Endpoint = oauth2.Endpoint{
14-
AuthURL: "https://github.com/login/oauth/authorize",
15-
TokenURL: "https://github.com/login/oauth/access_token",
16-
}
13+
var Endpoint = endpoints.GitHub

oauth2.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,9 @@ type TokenSource interface {
7070
// Endpoint represents an OAuth 2.0 provider's authorization and token
7171
// endpoint URLs.
7272
type Endpoint struct {
73-
AuthURL string
74-
TokenURL string
73+
AuthURL string
74+
DeviceAuthURL string
75+
TokenURL string
7576

7677
// AuthStyle optionally specifies how the endpoint wants the
7778
// client ID & client secret sent. The zero value means to

0 commit comments

Comments
 (0)