Skip to content

Commit 79f7325

Browse files
committed
device flow
1 parent ac6658e commit 79f7325

File tree

6 files changed

+281
-9
lines changed

6 files changed

+281
-9
lines changed

deviceauth.go

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
package oauth2
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"net/url"
10+
"strings"
11+
"time"
12+
13+
"golang.org/x/oauth2/internal"
14+
)
15+
16+
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
17+
const (
18+
errAuthorizationPending = "authorization_pending"
19+
errSlowDown = "slow_down"
20+
errAccessDenied = "access_denied"
21+
errExpiredToken = "expired_token"
22+
)
23+
24+
// DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
25+
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
26+
type DeviceAuthResponse struct {
27+
// DeviceCode
28+
DeviceCode string `json:"device_code"`
29+
// UserCode is the code the user should enter at the verification uri
30+
UserCode string `json:"user_code"`
31+
// VerificationURI is where user should enter the user code
32+
VerificationURI string `json:"verification_uri"`
33+
// VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code.
34+
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
35+
// Expiry is when the device code and user code expire
36+
Expiry time.Time `json:"expires_in,omitempty"`
37+
// Interval is the duration in seconds that Poll should wait between requests
38+
Interval int64 `json:"interval,omitempty"`
39+
}
40+
41+
func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
42+
type Alias DeviceAuthResponse
43+
var expiresIn int64
44+
if !d.Expiry.IsZero() {
45+
expiresIn = int64(time.Until(d.Expiry).Seconds())
46+
}
47+
return json.Marshal(&struct {
48+
ExpiresIn int64 `json:"expires_in,omitempty"`
49+
*Alias
50+
}{
51+
ExpiresIn: expiresIn,
52+
Alias: (*Alias)(&d),
53+
})
54+
55+
}
56+
57+
func (c *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
58+
type Alias DeviceAuthResponse
59+
aux := &struct {
60+
ExpiresIn int64 `json:"expires_in"`
61+
*Alias
62+
}{
63+
Alias: (*Alias)(c),
64+
}
65+
if err := json.Unmarshal(data, &aux); err != nil {
66+
return err
67+
}
68+
if aux.ExpiresIn != 0 {
69+
c.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn))
70+
}
71+
return nil
72+
}
73+
74+
// DeviceAuth returns a device auth struct which contains a device code
75+
// and authorization information provided for users to enter on another device.
76+
func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
77+
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
78+
v := url.Values{
79+
"client_id": {c.ClientID},
80+
}
81+
if len(c.Scopes) > 0 {
82+
v.Set("scope", strings.Join(c.Scopes, " "))
83+
}
84+
for _, opt := range opts {
85+
opt.setValue(v)
86+
}
87+
return retrieveDeviceAuth(ctx, c, v)
88+
}
89+
90+
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
91+
req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
92+
if err != nil {
93+
return nil, err
94+
}
95+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
96+
req.Header.Set("Accept", "application/json")
97+
98+
r, err := internal.ContextClient(ctx).Do(req)
99+
if err != nil {
100+
return nil, err
101+
}
102+
103+
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
104+
if err != nil {
105+
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
106+
}
107+
if code := r.StatusCode; code < 200 || code > 299 {
108+
return nil, &RetrieveError{
109+
Response: r,
110+
Body: body,
111+
}
112+
}
113+
114+
da := &DeviceAuthResponse{}
115+
err = json.Unmarshal(body, &da)
116+
if err != nil {
117+
return nil, fmt.Errorf("unmarshal %s", err)
118+
}
119+
return da, nil
120+
}
121+
122+
// DeviceAccessToken polls the server to exchange an device code for a token.
123+
func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
124+
if !da.Expiry.IsZero() {
125+
var cancel context.CancelFunc
126+
ctx, cancel = context.WithDeadline(ctx, da.Expiry)
127+
defer cancel()
128+
}
129+
130+
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
131+
v := url.Values{
132+
"client_id": {c.ClientID},
133+
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
134+
"device_code": {da.DeviceCode},
135+
}
136+
if len(c.Scopes) > 0 {
137+
v.Set("scope", strings.Join(c.Scopes, " "))
138+
}
139+
for _, opt := range opts {
140+
opt.setValue(v)
141+
}
142+
143+
// "If no value is provided, clients MUST use 5 as the default."
144+
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
145+
interval := da.Interval
146+
if interval == 0 {
147+
interval = 5
148+
}
149+
150+
ticker := time.NewTicker(time.Duration(interval) * time.Second)
151+
for {
152+
select {
153+
case <-ctx.Done():
154+
return nil, ctx.Err()
155+
case <-ticker.C:
156+
tok, err := retrieveToken(ctx, c, v)
157+
if err == nil {
158+
return tok, nil
159+
}
160+
161+
e, ok := err.(*RetrieveError)
162+
if !ok {
163+
return nil, err
164+
}
165+
switch e.ErrorCode {
166+
case errSlowDown:
167+
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
168+
// "the interval MUST be increased by 5 seconds for this and all subsequent requests"
169+
interval += 5
170+
ticker.Reset(time.Duration(interval) * time.Second)
171+
case errAuthorizationPending:
172+
// Do nothing.
173+
case errAccessDenied, errExpiredToken:
174+
fallthrough
175+
default:
176+
return tok, err
177+
}
178+
}
179+
}
180+
}

deviceauth_test.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package oauth2
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"testing"
8+
"time"
9+
10+
"github.com/google/go-cmp/cmp"
11+
"github.com/google/go-cmp/cmp/cmpopts"
12+
)
13+
14+
func TestDeviceAuthResponseMarshalJson(t *testing.T) {
15+
tests := []struct {
16+
name string
17+
response DeviceAuthResponse
18+
want string
19+
}{
20+
{
21+
name: "empty",
22+
response: DeviceAuthResponse{},
23+
want: `{"device_code":"","user_code":"","verification_uri":""}`,
24+
},
25+
{
26+
name: "soon",
27+
response: DeviceAuthResponse{
28+
Expiry: time.Now().Add(100*time.Second + 999*time.Millisecond),
29+
},
30+
want: `{"expires_in":100,"device_code":"","user_code":"","verification_uri":""}`,
31+
},
32+
}
33+
for _, tc := range tests {
34+
t.Run(tc.name, func(t *testing.T) {
35+
gotBytes, err := json.Marshal(tc.response)
36+
if err != nil {
37+
t.Fatal(err)
38+
}
39+
got := string(gotBytes)
40+
if got != tc.want {
41+
t.Errorf("want=%s, got=%s", tc.want, got)
42+
}
43+
})
44+
}
45+
}
46+
47+
func TestDeviceAuthResponseUnmarshalJson(t *testing.T) {
48+
tests := []struct {
49+
name string
50+
data string
51+
want DeviceAuthResponse
52+
}{
53+
{
54+
name: "empty",
55+
data: `{}`,
56+
want: DeviceAuthResponse{},
57+
},
58+
{
59+
name: "soon",
60+
data: `{"expires_in":100}`,
61+
want: DeviceAuthResponse{Expiry: time.Now().UTC().Add(100 * time.Second)},
62+
},
63+
}
64+
for _, tc := range tests {
65+
t.Run(tc.name, func(t *testing.T) {
66+
got := DeviceAuthResponse{}
67+
err := json.Unmarshal([]byte(tc.data), &got)
68+
if err != nil {
69+
t.Fatal(err)
70+
}
71+
if !cmp.Equal(got, tc.want, cmpopts.IgnoreUnexported(DeviceAuthResponse{}), cmpopts.EquateApproxTime(time.Second)) {
72+
t.Errorf("want=%#v, got=%#v", tc.want, got)
73+
}
74+
})
75+
}
76+
}
77+
78+
func ExampleConfig_DeviceAuth() {
79+
var config Config
80+
ctx := context.Background()
81+
response, err := config.DeviceAuth(ctx)
82+
if err != nil {
83+
panic(err)
84+
}
85+
fmt.Printf("please enter code %s at %s\n", response.UserCode, response.VerificationURI)
86+
token, err := config.DeviceAccessToken(ctx, response)
87+
if err != nil {
88+
panic(err)
89+
}
90+
fmt.Println(token)
91+
}

endpoints/endpoints.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,9 @@ var Fitbit = oauth2.Endpoint{
5555

5656
// GitHub is the endpoint for Github.
5757
var GitHub = oauth2.Endpoint{
58-
AuthURL: "https://github.com/login/oauth/authorize",
59-
TokenURL: "https://github.com/login/oauth/access_token",
58+
AuthURL: "https://github.com/login/oauth/authorize",
59+
TokenURL: "https://github.com/login/oauth/access_token",
60+
DeviceAuthURL: "https://github.com/login/device/code",
6061
}
6162

6263
// GitLab is the endpoint for GitLab.
@@ -69,6 +70,7 @@ var GitLab = oauth2.Endpoint{
6970
var Google = oauth2.Endpoint{
7071
AuthURL: "https://accounts.google.com/o/oauth2/auth",
7172
TokenURL: "https://oauth2.googleapis.com/token",
73+
DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
7274
}
7375

7476
// Heroku is the endpoint for Heroku.

github/github.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,8 @@
66
package github // import "golang.org/x/oauth2/github"
77

88
import (
9-
"golang.org/x/oauth2"
9+
"golang.org/x/oauth2/endpoints"
1010
)
1111

1212
// Endpoint is Github's OAuth 2.0 endpoint.
13-
var Endpoint = oauth2.Endpoint{
14-
AuthURL: "https://github.com/login/oauth/authorize",
15-
TokenURL: "https://github.com/login/oauth/access_token",
16-
}
13+
var Endpoint = endpoints.GitHub

google/google.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
var Endpoint = oauth2.Endpoint{
2424
AuthURL: "https://accounts.google.com/o/oauth2/auth",
2525
TokenURL: "https://oauth2.googleapis.com/token",
26+
DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
2627
AuthStyle: oauth2.AuthStyleInParams,
2728
}
2829

oauth2.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,9 @@ type TokenSource interface {
7171
// Endpoint represents an OAuth 2.0 provider's authorization and token
7272
// endpoint URLs.
7373
type Endpoint struct {
74-
AuthURL string
75-
TokenURL string
74+
AuthURL string
75+
DeviceAuthURL string
76+
TokenURL string
7677

7778
// AuthStyle optionally specifies how the endpoint wants the
7879
// client ID & client secret sent. The zero value means to

0 commit comments

Comments
 (0)