Skip to content

Commit 3d2413f

Browse files
author
Steve van Loben Sels
authored
Make sasl.Mechanism safe for concurrent use (#323)
While is a breaking change for the sasl.Mechanism interface, it's not expected that library clients are implementing said interface as Kafka only accepts a very specific set of mechanisms. Fixes #317
1 parent 1248320 commit 3d2413f

File tree

5 files changed

+66
-37
lines changed

5 files changed

+66
-37
lines changed

dialer.go

+5-6
Original file line numberDiff line numberDiff line change
@@ -283,17 +283,16 @@ func (d *Dialer) connect(ctx context.Context, network, address string, connCfg C
283283
// In case of error, this function *does not* close the connection. That is the
284284
// responsibility of the caller.
285285
func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error {
286-
mech, state, err := d.SASLMechanism.Start(ctx)
287-
if err != nil {
286+
if err := conn.saslHandshake(d.SASLMechanism.Name()); err != nil {
288287
return err
289288
}
290-
err = conn.saslHandshake(mech)
289+
290+
sess, state, err := d.SASLMechanism.Start(ctx)
291291
if err != nil {
292292
return err
293293
}
294294

295-
var completed bool
296-
for !completed {
295+
for completed := false; !completed; {
297296
challenge, err := conn.saslAuthenticate(state)
298297
switch err {
299298
case nil:
@@ -306,7 +305,7 @@ func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error {
306305
return err
307306
}
308307

309-
completed, state, err = d.SASLMechanism.Next(ctx, challenge)
308+
completed, state, err = sess.Next(ctx, challenge)
310309
if err != nil {
311310
return err
312311
}

sasl/plain/plain.go

+9-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package plain
33
import (
44
"context"
55
"fmt"
6+
7+
"github.com/segmentio/kafka-go/sasl"
68
)
79

810
// Mechanism implements the PLAIN mechanism and passes the credentials in clear
@@ -12,8 +14,13 @@ type Mechanism struct {
1214
Password string
1315
}
1416

15-
func (m Mechanism) Start(ctx context.Context) (string, []byte, error) {
16-
return "PLAIN", []byte(fmt.Sprintf("\x00%s\x00%s", m.Username, m.Password)), nil
17+
func (Mechanism) Name() string {
18+
return "PLAIN"
19+
}
20+
21+
func (m Mechanism) Start(ctx context.Context) (sasl.StateMachine, []byte, error) {
22+
// Mechanism is stateless, so it can also implement sasl.Session
23+
return m, []byte(fmt.Sprintf("\x00%s\x00%s", m.Username, m.Password)), nil
1724
}
1825

1926
func (m Mechanism) Next(ctx context.Context, challenge []byte) (bool, []byte, error) {

sasl/sasl.go

+30-17
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,43 @@ package sasl
22

33
import "context"
44

5-
// Mechanism implements the SASL state machine. It is initialized by calling
6-
// Start at which point the initial bytes should be sent to the server. The
7-
// caller then loops by passing the server's response into Next and then sending
8-
// Next's returned bytes to the server. Eventually either Next will indicate
9-
// that the authentication has been successfully completed or an error will
10-
// cause the state machine to exit prematurely.
5+
// Mechanism implements the SASL state machine for a particular mode of
6+
// authentication. It is used by the kafka.Dialer to perform the SASL
7+
// handshake.
118
//
12-
// A Mechanism must be re-usable, but it does not need to be safe for concurrent
13-
// access by multiple go routines.
9+
// A Mechanism must be re-usable and safe for concurrent access by multiple
10+
// goroutines.
1411
type Mechanism interface {
15-
// Start begins SASL authentication. It returns the authentication mechanism
16-
// name and "initial response" data (if required by the selected mechanism).
17-
// A non-nil error causes the client to abort the authentication attempt.
12+
// Name returns the identifier for this SASL mechanism. This string will be
13+
// passed to the SASL handshake request and much match one of the mechanisms
14+
// supported by Kafka.
15+
Name() string
16+
17+
// Start begins SASL authentication. It returns an authentication state
18+
// machine and "initial response" data (if required by the selected
19+
// mechanism). A non-nil error causes the client to abort the authentication
20+
// attempt.
1821
//
1922
// A nil ir value is different from a zero-length value. The nil value
2023
// indicates that the selected mechanism does not use an initial response,
2124
// while a zero-length value indicates an empty initial response, which must
2225
// be sent to the server.
23-
//
24-
// In order to ensure that the Mechanism is reusable, calling Start must
25-
// reset any internal state.
26-
Start(ctx context.Context) (mech string, ir []byte, err error)
26+
Start(ctx context.Context) (sess StateMachine, ir []byte, err error)
27+
}
2728

28-
// Next continues challenge-response authentication. A non-nil error causes
29-
// the client to abort the authentication attempt.
29+
// StateMachine implements the SASL challenge/response flow for a single SASL
30+
// handshake. A StateMachine will be created by the Mechanism per connection,
31+
// so it does not need to be safe for concurrent access by multiple goroutines.
32+
//
33+
// Once the StateMachine is created by the Mechanism, the caller loops by
34+
// passing the server's response into Next and then sending Next's returned
35+
// bytes to the server. Eventually either Next will indicate that the
36+
// authentication has been successfully completed via the done return value, or
37+
// it will indicate that the authentication failed by returning a non-nil error.
38+
type StateMachine interface {
39+
// Next continues challenge-response authentication. A non-nil error
40+
// indicates that the client should abort the authentication attempt. If
41+
// the client has been successfully authenticated, then the done return
42+
// value will be true.
3043
Next(ctx context.Context, challenge []byte) (done bool, response []byte, err error)
3144
}

sasl/sasl_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -65,18 +65,18 @@ func TestSASL(t *testing.T) {
6565
}
6666

6767
for _, tt := range tests {
68-
name, _, _ := tt.valid().Start(context.Background())
68+
mech := tt.valid()
6969
if !ktesting.KafkaIsAtLeast(tt.minKafka) {
7070
t.Skip("requires min kafka version " + tt.minKafka)
7171
}
7272

73-
t.Run(name+" success", func(t *testing.T) {
73+
t.Run(mech.Name()+" success", func(t *testing.T) {
7474
testConnect(t, tt.valid(), true)
7575
})
76-
t.Run(name+" failure", func(t *testing.T) {
76+
t.Run(mech.Name()+" failure", func(t *testing.T) {
7777
testConnect(t, tt.invalid(), false)
7878
})
79-
t.Run(name+" is reusable", func(t *testing.T) {
79+
t.Run(mech.Name()+" is reusable", func(t *testing.T) {
8080
mech := tt.valid()
8181
testConnect(t, mech, true)
8282
testConnect(t, mech, true)

sasl/scram/scram.go

+18-8
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,10 @@ var (
4848
type mechanism struct {
4949
algo Algorithm
5050
client *scram.Client
51-
convo *scram.ClientConversation
51+
}
52+
53+
type session struct {
54+
convo *scram.ClientConversation
5255
}
5356

5457
// Mechanism returns a new sasl.Mechanism that will use SCRAM with the provided
@@ -69,13 +72,20 @@ func Mechanism(algo Algorithm, username, password string) (sasl.Mechanism, error
6972
}, nil
7073
}
7174

72-
func (m *mechanism) Start(ctx context.Context) (string, []byte, error) {
73-
m.convo = m.client.NewConversation()
74-
str, err := m.convo.Step("")
75-
return m.algo.Name(), []byte(str), err
75+
func (m *mechanism) Name() string {
76+
return m.algo.Name()
77+
}
78+
79+
func (m *mechanism) Start(ctx context.Context) (sasl.StateMachine, []byte, error) {
80+
convo := m.client.NewConversation()
81+
str, err := convo.Step("")
82+
if err != nil {
83+
return nil, nil, err
84+
}
85+
return &session{convo: convo}, []byte(str), nil
7686
}
7787

78-
func (m *mechanism) Next(ctx context.Context, challenge []byte) (bool, []byte, error) {
79-
str, err := m.convo.Step(string(challenge))
80-
return m.convo.Done(), []byte(str), err
88+
func (s *session) Next(ctx context.Context, challenge []byte) (bool, []byte, error) {
89+
str, err := s.convo.Step(string(challenge))
90+
return s.convo.Done(), []byte(str), err
8191
}

0 commit comments

Comments
 (0)