diff --git a/dialer.go b/dialer.go index 9131d4933..f019a421b 100644 --- a/dialer.go +++ b/dialer.go @@ -283,17 +283,16 @@ func (d *Dialer) connect(ctx context.Context, network, address string, connCfg C // In case of error, this function *does not* close the connection. That is the // responsibility of the caller. func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error { - mech, state, err := d.SASLMechanism.Start(ctx) - if err != nil { + if err := conn.saslHandshake(d.SASLMechanism.Name()); err != nil { return err } - err = conn.saslHandshake(mech) + + sess, state, err := d.SASLMechanism.Start(ctx) if err != nil { return err } - var completed bool - for !completed { + for completed := false; !completed; { challenge, err := conn.saslAuthenticate(state) switch err { case nil: @@ -306,7 +305,7 @@ func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error { return err } - completed, state, err = d.SASLMechanism.Next(ctx, challenge) + completed, state, err = sess.Next(ctx, challenge) if err != nil { return err } diff --git a/sasl/plain/plain.go b/sasl/plain/plain.go index 15341d081..10c7632d2 100644 --- a/sasl/plain/plain.go +++ b/sasl/plain/plain.go @@ -3,6 +3,8 @@ package plain import ( "context" "fmt" + + "github.com/segmentio/kafka-go/sasl" ) // Mechanism implements the PLAIN mechanism and passes the credentials in clear @@ -12,8 +14,13 @@ type Mechanism struct { Password string } -func (m Mechanism) Start(ctx context.Context) (string, []byte, error) { - return "PLAIN", []byte(fmt.Sprintf("\x00%s\x00%s", m.Username, m.Password)), nil +func (Mechanism) Name() string { + return "PLAIN" +} + +func (m Mechanism) Start(ctx context.Context) (sasl.StateMachine, []byte, error) { + // Mechanism is stateless, so it can also implement sasl.Session + return m, []byte(fmt.Sprintf("\x00%s\x00%s", m.Username, m.Password)), nil } func (m Mechanism) Next(ctx context.Context, challenge []byte) (bool, []byte, error) { diff --git a/sasl/sasl.go b/sasl/sasl.go index ae7121c30..eb07f64fb 100644 --- a/sasl/sasl.go +++ b/sasl/sasl.go @@ -2,30 +2,43 @@ package sasl import "context" -// Mechanism implements the SASL state machine. It is initialized by calling -// Start at which point the initial bytes should be sent to the server. The -// caller then loops by passing the server's response into Next and then sending -// Next's returned bytes to the server. Eventually either Next will indicate -// that the authentication has been successfully completed or an error will -// cause the state machine to exit prematurely. +// Mechanism implements the SASL state machine for a particular mode of +// authentication. It is used by the kafka.Dialer to perform the SASL +// handshake. // -// A Mechanism must be re-usable, but it does not need to be safe for concurrent -// access by multiple go routines. +// A Mechanism must be re-usable and safe for concurrent access by multiple +// goroutines. type Mechanism interface { - // Start begins SASL authentication. It returns the authentication mechanism - // name and "initial response" data (if required by the selected mechanism). - // A non-nil error causes the client to abort the authentication attempt. + // Name returns the identifier for this SASL mechanism. This string will be + // passed to the SASL handshake request and much match one of the mechanisms + // supported by Kafka. + Name() string + + // Start begins SASL authentication. It returns an authentication state + // machine and "initial response" data (if required by the selected + // mechanism). A non-nil error causes the client to abort the authentication + // attempt. // // A nil ir value is different from a zero-length value. The nil value // indicates that the selected mechanism does not use an initial response, // while a zero-length value indicates an empty initial response, which must // be sent to the server. - // - // In order to ensure that the Mechanism is reusable, calling Start must - // reset any internal state. - Start(ctx context.Context) (mech string, ir []byte, err error) + Start(ctx context.Context) (sess StateMachine, ir []byte, err error) +} - // Next continues challenge-response authentication. A non-nil error causes - // the client to abort the authentication attempt. +// StateMachine implements the SASL challenge/response flow for a single SASL +// handshake. A StateMachine will be created by the Mechanism per connection, +// so it does not need to be safe for concurrent access by multiple goroutines. +// +// Once the StateMachine is created by the Mechanism, the caller loops by +// passing the server's response into Next and then sending Next's returned +// bytes to the server. Eventually either Next will indicate that the +// authentication has been successfully completed via the done return value, or +// it will indicate that the authentication failed by returning a non-nil error. +type StateMachine interface { + // Next continues challenge-response authentication. A non-nil error + // indicates that the client should abort the authentication attempt. If + // the client has been successfully authenticated, then the done return + // value will be true. Next(ctx context.Context, challenge []byte) (done bool, response []byte, err error) } diff --git a/sasl/sasl_test.go b/sasl/sasl_test.go index 4ed5c214a..7a6307c45 100644 --- a/sasl/sasl_test.go +++ b/sasl/sasl_test.go @@ -65,18 +65,18 @@ func TestSASL(t *testing.T) { } for _, tt := range tests { - name, _, _ := tt.valid().Start(context.Background()) + mech := tt.valid() if !ktesting.KafkaIsAtLeast(tt.minKafka) { t.Skip("requires min kafka version " + tt.minKafka) } - t.Run(name+" success", func(t *testing.T) { + t.Run(mech.Name()+" success", func(t *testing.T) { testConnect(t, tt.valid(), true) }) - t.Run(name+" failure", func(t *testing.T) { + t.Run(mech.Name()+" failure", func(t *testing.T) { testConnect(t, tt.invalid(), false) }) - t.Run(name+" is reusable", func(t *testing.T) { + t.Run(mech.Name()+" is reusable", func(t *testing.T) { mech := tt.valid() testConnect(t, mech, true) testConnect(t, mech, true) diff --git a/sasl/scram/scram.go b/sasl/scram/scram.go index 495630383..bc2b28ed2 100644 --- a/sasl/scram/scram.go +++ b/sasl/scram/scram.go @@ -48,7 +48,10 @@ var ( type mechanism struct { algo Algorithm client *scram.Client - convo *scram.ClientConversation +} + +type session struct { + convo *scram.ClientConversation } // Mechanism returns a new sasl.Mechanism that will use SCRAM with the provided @@ -69,13 +72,20 @@ func Mechanism(algo Algorithm, username, password string) (sasl.Mechanism, error }, nil } -func (m *mechanism) Start(ctx context.Context) (string, []byte, error) { - m.convo = m.client.NewConversation() - str, err := m.convo.Step("") - return m.algo.Name(), []byte(str), err +func (m *mechanism) Name() string { + return m.algo.Name() +} + +func (m *mechanism) Start(ctx context.Context) (sasl.StateMachine, []byte, error) { + convo := m.client.NewConversation() + str, err := convo.Step("") + if err != nil { + return nil, nil, err + } + return &session{convo: convo}, []byte(str), nil } -func (m *mechanism) Next(ctx context.Context, challenge []byte) (bool, []byte, error) { - str, err := m.convo.Step(string(challenge)) - return m.convo.Done(), []byte(str), err +func (s *session) Next(ctx context.Context, challenge []byte) (bool, []byte, error) { + str, err := s.convo.Step(string(challenge)) + return s.convo.Done(), []byte(str), err }