Skip to content

Commit 52c3d81

Browse files
committed
support saslAuthenticate v0 in protocol package
1 parent a9325a2 commit 52c3d81

File tree

4 files changed

+113
-1
lines changed

4 files changed

+113
-1
lines changed

protocol/conn.go

+4
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ func (c *Conn) RoundTrip(msg Message) (Message, error) {
8787
p.Prepare(apiVersion)
8888
}
8989

90+
if raw, ok := msg.(RawExchanger); ok && raw.Required(versions) {
91+
return raw.RawExchange(c)
92+
}
93+
9094
return RoundTrip(c, apiVersion, correlationID, c.clientID, msg)
9195
}
9296

protocol/protocol.go

+16
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,22 @@ type Partition struct {
422422
Offline []int32
423423
}
424424

425+
// RawExchanger is an extention to the Message interface to allow messages
426+
// to control the request response cycle for the message. This is currently
427+
// only used to facilitate v0 SASL Authenticate requests being written in
428+
// a non-standard fashion when the SASL Handshake was done at v0 but not
429+
// when done at v1.
430+
type RawExchanger interface {
431+
// Required should return true when a RawExchange is needed.
432+
// The passed in versions are the negotiated versions for the connection
433+
// performing the request.
434+
Required(versions map[ApiKey]int16) bool
435+
// RawExchange is given the raw connection to the broker and the Message
436+
// is responsible for writing itself to the connection as well as reading
437+
// the response.
438+
RawExchange(rw io.ReadWriter) (Message, error)
439+
}
440+
425441
// BrokerMessage is an extension of the Message interface implemented by some
426442
// request types to customize the broker assignment logic.
427443
type BrokerMessage interface {

protocol/saslauthenticate/saslauthenticate.go

+45-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
package saslauthenticate
22

3-
import "github.com/segmentio/kafka-go/protocol"
3+
import (
4+
"encoding/binary"
5+
"io"
6+
7+
"github.com/segmentio/kafka-go/protocol"
8+
)
49

510
func init() {
611
protocol.Register(&Request{}, &Response{})
@@ -10,6 +15,43 @@ type Request struct {
1015
AuthBytes []byte `kafka:"min=v0,max=v1"`
1116
}
1217

18+
func (r *Request) RawExchange(rw io.ReadWriter) (protocol.Message, error) {
19+
if err := r.writeTo(rw); err != nil {
20+
return nil, err
21+
}
22+
return r.readResp(rw)
23+
}
24+
25+
func (*Request) Required(versions map[protocol.ApiKey]int16) bool {
26+
const v0 = 0
27+
return versions[protocol.SaslHandshake] == v0
28+
}
29+
30+
func (r *Request) writeTo(w io.Writer) error {
31+
size := len(r.AuthBytes) + 4
32+
buf := make([]byte, size, size)
33+
binary.BigEndian.PutUint32(buf[:4], uint32(len(r.AuthBytes)))
34+
copy(buf[4:], r.AuthBytes)
35+
_, err := w.Write(buf)
36+
return err
37+
}
38+
39+
func (r *Request) readResp(read io.Reader) (protocol.Message, error) {
40+
var lenBuf [4]byte
41+
if _, err := io.ReadFull(read, lenBuf[:]); err != nil {
42+
return nil, err
43+
}
44+
respLen := int32(binary.BigEndian.Uint32(lenBuf[:]))
45+
data := make([]byte, respLen)
46+
47+
if _, err := io.ReadFull(read, data[:]); err != nil {
48+
return nil, err
49+
}
50+
return &Response{
51+
AuthBytes: data,
52+
}, nil
53+
}
54+
1355
func (r *Request) ApiKey() protocol.ApiKey { return protocol.SaslAuthenticate }
1456

1557
type Response struct {
@@ -20,3 +62,5 @@ type Response struct {
2062
}
2163

2264
func (r *Response) ApiKey() protocol.ApiKey { return protocol.SaslAuthenticate }
65+
66+
var _ protocol.RawExchanger = (*Request)(nil)

writer_test.go

+48
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"sync"
1111
"testing"
1212
"time"
13+
14+
"github.com/segmentio/kafka-go/sasl/plain"
1315
)
1416

1517
func TestBatchQueue(t *testing.T) {
@@ -160,6 +162,10 @@ func TestWriter(t *testing.T) {
160162
scenario: "writing a message to a non-existant topic creates the topic",
161163
function: testWriterAutoCreateTopic,
162164
},
165+
{
166+
scenario: "writing a message with SASL Plain authentication",
167+
function: testWriterSasl,
168+
},
163169
}
164170

165171
for _, test := range tests {
@@ -737,6 +743,48 @@ func testWriterAutoCreateTopic(t *testing.T) {
737743
}
738744
}
739745

746+
func testWriterSasl(t *testing.T) {
747+
topic := makeTopic()
748+
defer deleteTopic(t, topic)
749+
dialer := &Dialer{
750+
Timeout: 10 * time.Second,
751+
SASLMechanism: plain.Mechanism{
752+
Username: "adminplain",
753+
Password: "admin-secret",
754+
},
755+
}
756+
757+
w := newTestWriter(WriterConfig{
758+
Dialer: dialer,
759+
Topic: topic,
760+
Brokers: []string{"localhost:9093"},
761+
})
762+
763+
defer w.Close()
764+
765+
msg := Message{Key: []byte("key"), Value: []byte("Hello World")}
766+
767+
var err error
768+
const retries = 5
769+
for i := 0; i < retries; i++ {
770+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
771+
defer cancel()
772+
err = w.WriteMessages(ctx, msg)
773+
if errors.Is(err, LeaderNotAvailable) || errors.Is(err, context.DeadlineExceeded) {
774+
time.Sleep(time.Millisecond * 250)
775+
continue
776+
}
777+
778+
if err != nil {
779+
t.Errorf("unexpected error %v", err)
780+
return
781+
}
782+
}
783+
if err != nil {
784+
t.Errorf("unable to create topic %v", err)
785+
}
786+
}
787+
740788
type staticBalancer struct {
741789
partition int
742790
}

0 commit comments

Comments
 (0)