Skip to content

Commit f6feedf

Browse files
author
Igor Drozdov
committed
Simplify 2FA Push auth processing
Use a single channel to handle both Push Auth and OTP results
1 parent fe5feee commit f6feedf

11 files changed

+573
-870
lines changed

internal/command/twofactorverify/twofactorverify.go

+37-107
Original file line numberDiff line numberDiff line change
@@ -14,147 +14,77 @@ import (
1414
"gitlab.com/gitlab-org/gitlab-shell/v14/internal/gitlabnet/twofactorverify"
1515
)
1616

17+
const (
18+
timeout = 30 * time.Second
19+
prompt = "OTP: "
20+
)
21+
1722
type Command struct {
1823
Config *config.Config
19-
Client *twofactorverify.Client
2024
Args *commandargs.Shell
2125
ReadWriter *readwriter.ReadWriter
2226
}
2327

24-
type Result struct {
25-
Error error
26-
Status string
27-
Success bool
28-
}
29-
3028
func (c *Command) Execute(ctx context.Context) error {
31-
ctxlog := log.ContextLogger(ctx)
32-
33-
// config.GetHTTPClient isn't thread-safe so save Client in struct for concurrency
34-
// workaround until #518 is fixed
35-
var err error
36-
c.Client, err = twofactorverify.NewClient(c.Config)
37-
29+
client, err := twofactorverify.NewClient(c.Config)
3830
if err != nil {
39-
ctxlog.WithError(err).Error("twofactorverify: execute: OTP verification failed")
4031
return err
4132
}
4233

43-
// Create timeout context
44-
// TODO: make timeout configurable
45-
const ctxTimeout = 30
46-
timeoutCtx, cancelTimeout := context.WithTimeout(ctx, ctxTimeout*time.Second)
47-
defer cancelTimeout()
34+
ctx, cancel := context.WithTimeout(ctx, timeout)
35+
defer cancel()
36+
37+
fmt.Fprint(c.ReadWriter.Out, prompt)
4838

49-
// Background push notification with timeout
50-
pushauth := make(chan Result)
39+
resultCh := make(chan string)
5140
go func() {
52-
defer close(pushauth)
53-
ctxlog.Info("twofactorverify: execute: waiting for push auth")
54-
status, success, err := c.pushAuth(timeoutCtx)
55-
ctxlog.WithError(err).Info("twofactorverify: execute: push auth verified")
56-
57-
select {
58-
case <-timeoutCtx.Done(): // push cancelled by manual OTP
59-
// skip writing to channel
60-
ctxlog.Info("twofactorverify: execute: push auth cancelled")
61-
default:
62-
pushauth <- Result{Error: err, Status: status, Success: success}
41+
err := client.PushAuth(ctx, c.Args)
42+
if err == nil {
43+
resultCh <- "OTP has been validated by Push Authentication. Git operations are now allowed."
6344
}
6445
}()
6546

66-
// Also allow manual OTP entry while waiting for push, with same timeout as push
67-
verify := make(chan Result)
6847
go func() {
69-
defer close(verify)
70-
ctxlog.Info("twofactorverify: execute: waiting for user input")
71-
answer := ""
72-
answer = c.getOTP(timeoutCtx)
73-
ctxlog.Info("twofactorverify: execute: user input received")
74-
75-
select {
76-
case <-timeoutCtx.Done(): // manual OTP cancelled by push
77-
// skip writing to channel
78-
ctxlog.Info("twofactorverify: execute: verify cancelled")
79-
default:
80-
ctxlog.Info("twofactorverify: execute: verifying entered OTP")
81-
status, success, err := c.verifyOTP(timeoutCtx, answer)
82-
ctxlog.WithError(err).Info("twofactorverify: execute: OTP verified")
83-
verify <- Result{Error: err, Status: status, Success: success}
48+
answer, err := c.getOTP(ctx)
49+
if err != nil {
50+
resultCh <- formatErr(err)
8451
}
85-
}()
8652

87-
for {
88-
select {
89-
case res := <-verify:
90-
if res.Status == "" {
91-
// channel closed; don't print anything
92-
} else {
93-
fmt.Fprint(c.ReadWriter.Out, res.Status)
94-
return nil
95-
}
96-
case res := <-pushauth:
97-
if res.Status == "" {
98-
// channel closed; don't print anything
99-
} else {
100-
fmt.Fprint(c.ReadWriter.Out, res.Status)
101-
return nil
102-
}
103-
case <-timeoutCtx.Done(): // push timed out
104-
fmt.Fprint(c.ReadWriter.Out, "\nOTP verification timed out\n")
105-
return nil
53+
if err := client.VerifyOTP(ctx, c.Args, answer); err != nil {
54+
resultCh <- formatErr(err)
55+
} else {
56+
resultCh <- "OTP validation successful. Git operations are now allowed."
10657
}
58+
}()
59+
60+
var message string
61+
select {
62+
case message = <-resultCh:
63+
case <-ctx.Done():
64+
message = formatErr(ctx.Err())
10765
}
10866

67+
log.WithContextFields(ctx, log.Fields{"message": message}).Info("Two factor verify command finished")
68+
fmt.Fprintf(c.ReadWriter.Out, "\n%v\n", message)
69+
10970
return nil
11071
}
11172

112-
func (c *Command) getOTP(ctx context.Context) string {
113-
prompt := "OTP: "
114-
fmt.Fprint(c.ReadWriter.Out, prompt)
115-
73+
func (c *Command) getOTP(ctx context.Context) (string, error) {
11674
var answer string
11775
otpLength := int64(64)
11876
reader := io.LimitReader(c.ReadWriter.In, otpLength)
11977
if _, err := fmt.Fscanln(reader, &answer); err != nil {
12078
log.ContextLogger(ctx).WithError(err).Debug("twofactorverify: getOTP: Failed to get user input")
12179
}
12280

123-
return answer
124-
}
125-
126-
func (c *Command) verifyOTP(ctx context.Context, otp string) (status string, success bool, err error) {
127-
reason := ""
128-
129-
success, reason, err = c.Client.VerifyOTP(ctx, c.Args, otp)
130-
if success {
131-
status = fmt.Sprintf("\nOTP validation successful. Git operations are now allowed.\n")
132-
} else {
133-
if err != nil {
134-
status = fmt.Sprintf("\nOTP validation failed.\n%v\n", err)
135-
} else {
136-
status = fmt.Sprintf("\nOTP validation failed.\n%v\n", reason)
137-
}
81+
if answer == "" {
82+
return "", fmt.Errorf("OTP cannot be blank.")
13883
}
13984

140-
err = nil
141-
142-
return
85+
return answer, nil
14386
}
14487

145-
func (c *Command) pushAuth(ctx context.Context) (status string, success bool, err error) {
146-
reason := ""
147-
148-
success, reason, err = c.Client.PushAuth(ctx, c.Args)
149-
if success {
150-
status = fmt.Sprintf("\nPush OTP validation successful. Git operations are now allowed.\n")
151-
} else {
152-
if err != nil {
153-
status = fmt.Sprintf("\nPush OTP validation failed.\n%v\n", err)
154-
} else {
155-
status = fmt.Sprintf("\nPush OTP validation failed.\n%v\n", reason)
156-
}
157-
}
158-
159-
return
88+
func formatErr(err error) string {
89+
return fmt.Sprintf("OTP validation failed: %v", err)
16090
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
package twofactorverify
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"io"
8+
"net/http"
9+
"testing"
10+
11+
"github.com/stretchr/testify/require"
12+
13+
"gitlab.com/gitlab-org/gitlab-shell/v14/client/testserver"
14+
"gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs"
15+
"gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter"
16+
"gitlab.com/gitlab-org/gitlab-shell/v14/internal/config"
17+
"gitlab.com/gitlab-org/gitlab-shell/v14/internal/gitlabnet/twofactorverify"
18+
)
19+
20+
type blockingReader struct{}
21+
22+
func (*blockingReader) Read([]byte) (int, error) {
23+
waitInfinitely := make(chan struct{})
24+
<-waitInfinitely
25+
26+
return 0, nil
27+
}
28+
29+
func setup(t *testing.T) []testserver.TestRequestHandler {
30+
waitInfinitely := make(chan struct{})
31+
requests := []testserver.TestRequestHandler{
32+
{
33+
Path: "/api/v4/internal/two_factor_manual_otp_check",
34+
Handler: func(w http.ResponseWriter, r *http.Request) {
35+
b, err := io.ReadAll(r.Body)
36+
defer r.Body.Close()
37+
38+
require.NoError(t, err)
39+
40+
var requestBody *twofactorverify.RequestBody
41+
require.NoError(t, json.Unmarshal(b, &requestBody))
42+
43+
switch requestBody.KeyId {
44+
case "verify_via_otp", "verify_via_otp_with_push_error":
45+
body := map[string]interface{}{
46+
"success": true,
47+
}
48+
json.NewEncoder(w).Encode(body)
49+
case "wait_infinitely":
50+
<-waitInfinitely
51+
case "error":
52+
body := map[string]interface{}{
53+
"success": false,
54+
"message": "error message",
55+
}
56+
require.NoError(t, json.NewEncoder(w).Encode(body))
57+
case "broken":
58+
w.WriteHeader(http.StatusInternalServerError)
59+
}
60+
},
61+
},
62+
{
63+
Path: "/api/v4/internal/two_factor_push_otp_check",
64+
Handler: func(w http.ResponseWriter, r *http.Request) {
65+
b, err := io.ReadAll(r.Body)
66+
defer r.Body.Close()
67+
68+
require.NoError(t, err)
69+
70+
var requestBody *twofactorverify.RequestBody
71+
require.NoError(t, json.Unmarshal(b, &requestBody))
72+
73+
switch requestBody.KeyId {
74+
case "verify_via_push":
75+
body := map[string]interface{}{
76+
"success": true,
77+
}
78+
json.NewEncoder(w).Encode(body)
79+
case "verify_via_otp_with_push_error":
80+
w.WriteHeader(http.StatusInternalServerError)
81+
default:
82+
<-waitInfinitely
83+
}
84+
},
85+
},
86+
}
87+
88+
return requests
89+
}
90+
91+
const errorHeader = "OTP validation failed: "
92+
93+
func TestExecute(t *testing.T) {
94+
requests := setup(t)
95+
96+
url := testserver.StartSocketHttpServer(t, requests)
97+
98+
testCases := []struct {
99+
desc string
100+
arguments *commandargs.Shell
101+
input io.Reader
102+
expectedOutput string
103+
}{
104+
{
105+
desc: "Verify via OTP",
106+
arguments: &commandargs.Shell{GitlabKeyId: "verify_via_otp"},
107+
expectedOutput: "OTP validation successful. Git operations are now allowed.\n",
108+
},
109+
{
110+
desc: "Verify via OTP",
111+
arguments: &commandargs.Shell{GitlabKeyId: "verify_via_otp_with_push_error"},
112+
expectedOutput: "OTP validation successful. Git operations are now allowed.\n",
113+
},
114+
{
115+
desc: "Verify via push authentication",
116+
arguments: &commandargs.Shell{GitlabKeyId: "verify_via_push"},
117+
input: &blockingReader{},
118+
expectedOutput: "OTP has been validated by Push Authentication. Git operations are now allowed.\n",
119+
},
120+
{
121+
desc: "With an empty OTP",
122+
arguments: &commandargs.Shell{GitlabKeyId: "verify_via_otp"},
123+
input: bytes.NewBufferString("\n"),
124+
expectedOutput: errorHeader + "OTP cannot be blank.\n",
125+
},
126+
{
127+
desc: "With bad response",
128+
arguments: &commandargs.Shell{GitlabKeyId: "-1"},
129+
expectedOutput: errorHeader + "Parsing failed\n",
130+
},
131+
{
132+
desc: "With API returns an error",
133+
arguments: &commandargs.Shell{GitlabKeyId: "error"},
134+
expectedOutput: errorHeader + "error message\n",
135+
},
136+
{
137+
desc: "With API fails",
138+
arguments: &commandargs.Shell{GitlabKeyId: "broken"},
139+
expectedOutput: errorHeader + "Internal API error (500)\n",
140+
},
141+
{
142+
desc: "With missing arguments",
143+
arguments: &commandargs.Shell{},
144+
expectedOutput: errorHeader + "who='' is invalid\n",
145+
},
146+
}
147+
148+
for _, tc := range testCases {
149+
t.Run(tc.desc, func(t *testing.T) {
150+
output := &bytes.Buffer{}
151+
152+
input := tc.input
153+
if input == nil {
154+
input = bytes.NewBufferString("123456\n")
155+
}
156+
157+
cmd := &Command{
158+
Config: &config.Config{GitlabUrl: url},
159+
Args: tc.arguments,
160+
ReadWriter: &readwriter.ReadWriter{Out: output, In: input},
161+
}
162+
163+
err := cmd.Execute(context.Background())
164+
165+
require.NoError(t, err)
166+
require.Equal(t, prompt+"\n"+tc.expectedOutput, output.String())
167+
})
168+
}
169+
}
170+
171+
func TestCanceledContext(t *testing.T) {
172+
requests := setup(t)
173+
174+
output := &bytes.Buffer{}
175+
176+
url := testserver.StartSocketHttpServer(t, requests)
177+
cmd := &Command{
178+
Config: &config.Config{GitlabUrl: url},
179+
Args: &commandargs.Shell{GitlabKeyId: "wait_infinitely"},
180+
ReadWriter: &readwriter.ReadWriter{Out: output, In: &bytes.Buffer{}},
181+
}
182+
183+
ctx, cancel := context.WithCancel(context.Background())
184+
185+
errCh := make(chan error)
186+
go func() { errCh <- cmd.Execute(ctx) }()
187+
cancel()
188+
189+
require.NoError(t, <-errCh)
190+
require.Equal(t, prompt+"\n"+errorHeader+"context canceled\n", output.String())
191+
}

0 commit comments

Comments
 (0)