Skip to content

Commit fe5feee

Browse files
kmcknightIgor Drozdov
kmcknight
authored and
Igor Drozdov
committed
Implement Push Auth support for 2FA verification
When `2fa_verify` command is executed: - A user is asked to enter OTP - A blocking call for push auth is performed Then: - If the push auth request fails, the user is still able to enter OTP - If OTP is invalid, the `2fa_verify` command ends the execution - If OTP is valid or push auth request succeeded, then the user is successfully authenticated - If 30 seconds passed while no OTP or Push have been provided, then the `2fa_verify` command ends the execution
1 parent 9cacca5 commit fe5feee

8 files changed

+568
-61
lines changed

internal/command/twofactorverify/twofactorverify.go

+108-14
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"io"
7+
"time"
78

89
"gitlab.com/gitlab-org/labkit/log"
910

@@ -15,23 +16,96 @@ import (
1516

1617
type Command struct {
1718
Config *config.Config
19+
Client *twofactorverify.Client
1820
Args *commandargs.Shell
1921
ReadWriter *readwriter.ReadWriter
2022
}
2123

24+
type Result struct {
25+
Error error
26+
Status string
27+
Success bool
28+
}
29+
2230
func (c *Command) Execute(ctx context.Context) error {
2331
ctxlog := log.ContextLogger(ctx)
24-
ctxlog.Info("twofactorverify: execute: waiting for user input")
25-
otp := c.getOTP(ctx)
2632

27-
ctxlog.Info("twofactorverify: execute: verifying entered OTP")
28-
err := c.verifyOTP(ctx, otp)
33+
// config.GetHTTPClient isn't thread-safe so save Client in struct for concurrency
34+
// workaround until #518 is fixed
35+
var err error
36+
c.Client, err = twofactorverify.NewClient(c.Config)
37+
2938
if err != nil {
3039
ctxlog.WithError(err).Error("twofactorverify: execute: OTP verification failed")
3140
return err
3241
}
3342

34-
ctxlog.WithError(err).Info("twofactorverify: execute: OTP verified")
43+
// Create timeout context
44+
// TODO: make timeout configurable
45+
const ctxTimeout = 30
46+
timeoutCtx, cancelTimeout := context.WithTimeout(ctx, ctxTimeout*time.Second)
47+
defer cancelTimeout()
48+
49+
// Background push notification with timeout
50+
pushauth := make(chan Result)
51+
go func() {
52+
defer close(pushauth)
53+
ctxlog.Info("twofactorverify: execute: waiting for push auth")
54+
status, success, err := c.pushAuth(timeoutCtx)
55+
ctxlog.WithError(err).Info("twofactorverify: execute: push auth verified")
56+
57+
select {
58+
case <-timeoutCtx.Done(): // push cancelled by manual OTP
59+
// skip writing to channel
60+
ctxlog.Info("twofactorverify: execute: push auth cancelled")
61+
default:
62+
pushauth <- Result{Error: err, Status: status, Success: success}
63+
}
64+
}()
65+
66+
// Also allow manual OTP entry while waiting for push, with same timeout as push
67+
verify := make(chan Result)
68+
go func() {
69+
defer close(verify)
70+
ctxlog.Info("twofactorverify: execute: waiting for user input")
71+
answer := ""
72+
answer = c.getOTP(timeoutCtx)
73+
ctxlog.Info("twofactorverify: execute: user input received")
74+
75+
select {
76+
case <-timeoutCtx.Done(): // manual OTP cancelled by push
77+
// skip writing to channel
78+
ctxlog.Info("twofactorverify: execute: verify cancelled")
79+
default:
80+
ctxlog.Info("twofactorverify: execute: verifying entered OTP")
81+
status, success, err := c.verifyOTP(timeoutCtx, answer)
82+
ctxlog.WithError(err).Info("twofactorverify: execute: OTP verified")
83+
verify <- Result{Error: err, Status: status, Success: success}
84+
}
85+
}()
86+
87+
for {
88+
select {
89+
case res := <-verify:
90+
if res.Status == "" {
91+
// channel closed; don't print anything
92+
} else {
93+
fmt.Fprint(c.ReadWriter.Out, res.Status)
94+
return nil
95+
}
96+
case res := <-pushauth:
97+
if res.Status == "" {
98+
// channel closed; don't print anything
99+
} else {
100+
fmt.Fprint(c.ReadWriter.Out, res.Status)
101+
return nil
102+
}
103+
case <-timeoutCtx.Done(): // push timed out
104+
fmt.Fprint(c.ReadWriter.Out, "\nOTP verification timed out\n")
105+
return nil
106+
}
107+
}
108+
35109
return nil
36110
}
37111

@@ -49,18 +123,38 @@ func (c *Command) getOTP(ctx context.Context) string {
49123
return answer
50124
}
51125

52-
func (c *Command) verifyOTP(ctx context.Context, otp string) error {
53-
client, err := twofactorverify.NewClient(c.Config)
54-
if err != nil {
55-
return err
126+
func (c *Command) verifyOTP(ctx context.Context, otp string) (status string, success bool, err error) {
127+
reason := ""
128+
129+
success, reason, err = c.Client.VerifyOTP(ctx, c.Args, otp)
130+
if success {
131+
status = fmt.Sprintf("\nOTP validation successful. Git operations are now allowed.\n")
132+
} else {
133+
if err != nil {
134+
status = fmt.Sprintf("\nOTP validation failed.\n%v\n", err)
135+
} else {
136+
status = fmt.Sprintf("\nOTP validation failed.\n%v\n", reason)
137+
}
56138
}
57139

58-
err = client.VerifyOTP(ctx, c.Args, otp)
59-
if err == nil {
60-
fmt.Fprint(c.ReadWriter.Out, "\nOTP validation successful. Git operations are now allowed.\n")
140+
err = nil
141+
142+
return
143+
}
144+
145+
func (c *Command) pushAuth(ctx context.Context) (status string, success bool, err error) {
146+
reason := ""
147+
148+
success, reason, err = c.Client.PushAuth(ctx, c.Args)
149+
if success {
150+
status = fmt.Sprintf("\nPush OTP validation successful. Git operations are now allowed.\n")
61151
} else {
62-
fmt.Fprintf(c.ReadWriter.Out, "\nOTP validation failed.\n%v\n", err)
152+
if err != nil {
153+
status = fmt.Sprintf("\nPush OTP validation failed.\n%v\n", err)
154+
} else {
155+
status = fmt.Sprintf("\nPush OTP validation failed.\n%v\n", reason)
156+
}
63157
}
64158

65-
return nil
159+
return
66160
}

internal/command/twofactorverify/twofactorverify_test.go renamed to internal/command/twofactorverify/twofactorverifymanual_test.go

+55-22
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"io"
88
"net/http"
99
"testing"
10+
"time"
1011

1112
"github.com/stretchr/testify/require"
1213

@@ -17,10 +18,10 @@ import (
1718
"gitlab.com/gitlab-org/gitlab-shell/v14/internal/gitlabnet/twofactorverify"
1819
)
1920

20-
func setup(t *testing.T) []testserver.TestRequestHandler {
21+
func setupManual(t *testing.T) []testserver.TestRequestHandler {
2122
requests := []testserver.TestRequestHandler{
2223
{
23-
Path: "/api/v4/internal/two_factor_otp_check",
24+
Path: "/api/v4/internal/two_factor_manual_otp_check",
2425
Handler: func(w http.ResponseWriter, r *http.Request) {
2526
b, err := io.ReadAll(r.Body)
2627
defer r.Body.Close()
@@ -30,14 +31,15 @@ func setup(t *testing.T) []testserver.TestRequestHandler {
3031
var requestBody *twofactorverify.RequestBody
3132
require.NoError(t, json.Unmarshal(b, &requestBody))
3233

34+
var body map[string]interface{}
3335
switch requestBody.KeyId {
3436
case "1":
35-
body := map[string]interface{}{
37+
body = map[string]interface{}{
3638
"success": true,
3739
}
3840
json.NewEncoder(w).Encode(body)
3941
case "error":
40-
body := map[string]interface{}{
42+
body = map[string]interface{}{
4143
"success": false,
4244
"message": "error message",
4345
}
@@ -47,18 +49,55 @@ func setup(t *testing.T) []testserver.TestRequestHandler {
4749
}
4850
},
4951
},
52+
{
53+
Path: "/api/v4/internal/two_factor_push_otp_check",
54+
Handler: func(w http.ResponseWriter, r *http.Request) {
55+
b, err := io.ReadAll(r.Body)
56+
defer r.Body.Close()
57+
58+
require.NoError(t, err)
59+
60+
var requestBody *twofactorverify.RequestBody
61+
require.NoError(t, json.Unmarshal(b, &requestBody))
62+
63+
time.Sleep(5 * time.Second)
64+
65+
var body map[string]interface{}
66+
switch requestBody.KeyId {
67+
case "1":
68+
body = map[string]interface{}{
69+
"success": true,
70+
}
71+
json.NewEncoder(w).Encode(body)
72+
case "error":
73+
body = map[string]interface{}{
74+
"success": false,
75+
"message": "error message",
76+
}
77+
require.NoError(t, json.NewEncoder(w).Encode(body))
78+
case "broken":
79+
w.WriteHeader(http.StatusInternalServerError)
80+
default:
81+
body = map[string]interface{}{
82+
"success": true,
83+
"message": "default message",
84+
}
85+
json.NewEncoder(w).Encode(body)
86+
}
87+
},
88+
},
5089
}
5190

5291
return requests
5392
}
5493

5594
const (
56-
question = "OTP: \n"
57-
errorHeader = "OTP validation failed.\n"
95+
manualQuestion = "OTP: \n"
96+
manualErrorHeader = "OTP validation failed.\n"
5897
)
5998

60-
func TestExecute(t *testing.T) {
61-
requests := setup(t)
99+
func TestExecuteManual(t *testing.T) {
100+
requests := setupManual(t)
62101

63102
url := testserver.StartSocketHttpServer(t, requests)
64103

@@ -69,41 +108,35 @@ func TestExecute(t *testing.T) {
69108
expectedOutput string
70109
}{
71110
{
72-
desc: "With a known key id",
73-
arguments: &commandargs.Shell{GitlabKeyId: "1"},
74-
answer: "123456\n",
75-
expectedOutput: question +
76-
"OTP validation successful. Git operations are now allowed.\n",
111+
desc: "With a known key id",
112+
arguments: &commandargs.Shell{GitlabKeyId: "1"},
113+
answer: "123456\n",
114+
expectedOutput: manualQuestion + "OTP validation successful. Git operations are now allowed.\n",
77115
},
78116
{
79117
desc: "With bad response",
80118
arguments: &commandargs.Shell{GitlabKeyId: "-1"},
81119
answer: "123456\n",
82-
expectedOutput: question + errorHeader + "Parsing failed\n",
120+
expectedOutput: manualQuestion + manualErrorHeader + "Parsing failed\n",
83121
},
84122
{
85123
desc: "With API returns an error",
86124
arguments: &commandargs.Shell{GitlabKeyId: "error"},
87125
answer: "yes\n",
88-
expectedOutput: question + errorHeader + "error message\n",
126+
expectedOutput: manualQuestion + manualErrorHeader + "error message\n",
89127
},
90128
{
91129
desc: "With API fails",
92130
arguments: &commandargs.Shell{GitlabKeyId: "broken"},
93131
answer: "yes\n",
94-
expectedOutput: question + errorHeader + "Internal API error (500)\n",
95-
},
96-
{
97-
desc: "With missing arguments",
98-
arguments: &commandargs.Shell{},
99-
answer: "yes\n",
100-
expectedOutput: question + errorHeader + "who='' is invalid\n",
132+
expectedOutput: manualQuestion + manualErrorHeader + "Internal API error (500)\n",
101133
},
102134
}
103135

104136
for _, tc := range testCases {
105137
t.Run(tc.desc, func(t *testing.T) {
106138
output := &bytes.Buffer{}
139+
107140
input := bytes.NewBufferString(tc.answer)
108141

109142
cmd := &Command{

0 commit comments

Comments
 (0)