Skip to content

Commit 4296f73

Browse files
authored
support saslAuthenticate v0 in protocol package (#869)
1 parent ae86f55 commit 4296f73

File tree

4 files changed

+115
-1
lines changed

4 files changed

+115
-1
lines changed

protocol/conn.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ func (c *Conn) RoundTrip(msg Message) (Message, error) {
8787
p.Prepare(apiVersion)
8888
}
8989

90+
if raw, ok := msg.(RawExchanger); ok && raw.Required(versions) {
91+
return raw.RawExchange(c)
92+
}
93+
9094
return RoundTrip(c, apiVersion, correlationID, c.clientID, msg)
9195
}
9296

protocol/protocol.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,22 @@ type Partition struct {
422422
Offline []int32
423423
}
424424

425+
// RawExchanger is an extention to the Message interface to allow messages
426+
// to control the request response cycle for the message. This is currently
427+
// only used to facilitate v0 SASL Authenticate requests being written in
428+
// a non-standard fashion when the SASL Handshake was done at v0 but not
429+
// when done at v1.
430+
type RawExchanger interface {
431+
// Required should return true when a RawExchange is needed.
432+
// The passed in versions are the negotiated versions for the connection
433+
// performing the request.
434+
Required(versions map[ApiKey]int16) bool
435+
// RawExchange is given the raw connection to the broker and the Message
436+
// is responsible for writing itself to the connection as well as reading
437+
// the response.
438+
RawExchange(rw io.ReadWriter) (Message, error)
439+
}
440+
425441
// BrokerMessage is an extension of the Message interface implemented by some
426442
// request types to customize the broker assignment logic.
427443
type BrokerMessage interface {

protocol/saslauthenticate/saslauthenticate.go

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
package saslauthenticate
22

3-
import "github.com/segmentio/kafka-go/protocol"
3+
import (
4+
"encoding/binary"
5+
"io"
6+
7+
"github.com/segmentio/kafka-go/protocol"
8+
)
49

510
func init() {
611
protocol.Register(&Request{}, &Response{})
@@ -10,6 +15,43 @@ type Request struct {
1015
AuthBytes []byte `kafka:"min=v0,max=v1"`
1116
}
1217

18+
func (r *Request) RawExchange(rw io.ReadWriter) (protocol.Message, error) {
19+
if err := r.writeTo(rw); err != nil {
20+
return nil, err
21+
}
22+
return r.readResp(rw)
23+
}
24+
25+
func (*Request) Required(versions map[protocol.ApiKey]int16) bool {
26+
const v0 = 0
27+
return versions[protocol.SaslHandshake] == v0
28+
}
29+
30+
func (r *Request) writeTo(w io.Writer) error {
31+
size := len(r.AuthBytes) + 4
32+
buf := make([]byte, size, size)
33+
binary.BigEndian.PutUint32(buf[:4], uint32(len(r.AuthBytes)))
34+
copy(buf[4:], r.AuthBytes)
35+
_, err := w.Write(buf)
36+
return err
37+
}
38+
39+
func (r *Request) readResp(read io.Reader) (protocol.Message, error) {
40+
var lenBuf [4]byte
41+
if _, err := io.ReadFull(read, lenBuf[:]); err != nil {
42+
return nil, err
43+
}
44+
respLen := int32(binary.BigEndian.Uint32(lenBuf[:]))
45+
data := make([]byte, respLen)
46+
47+
if _, err := io.ReadFull(read, data[:]); err != nil {
48+
return nil, err
49+
}
50+
return &Response{
51+
AuthBytes: data,
52+
}, nil
53+
}
54+
1355
func (r *Request) ApiKey() protocol.ApiKey { return protocol.SaslAuthenticate }
1456

1557
type Response struct {
@@ -20,3 +62,5 @@ type Response struct {
2062
}
2163

2264
func (r *Response) ApiKey() protocol.ApiKey { return protocol.SaslAuthenticate }
65+
66+
var _ protocol.RawExchanger = (*Request)(nil)

writer_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"sync"
1111
"testing"
1212
"time"
13+
14+
"github.com/segmentio/kafka-go/sasl/plain"
1315
)
1416

1517
func TestBatchQueue(t *testing.T) {
@@ -164,6 +166,10 @@ func TestWriter(t *testing.T) {
164166
scenario: "terminates on an attempt to write a message to a nonexistent topic",
165167
function: testWriterTerminateMissingTopic,
166168
},
169+
{
170+
scenario: "writing a message with SASL Plain authentication",
171+
function: testWriterSasl,
172+
},
167173
}
168174

169175
for _, test := range tests {
@@ -766,6 +772,50 @@ func testWriterTerminateMissingTopic(t *testing.T) {
766772
}
767773
}
768774

775+
func testWriterSasl(t *testing.T) {
776+
topic := makeTopic()
777+
defer deleteTopic(t, topic)
778+
dialer := &Dialer{
779+
Timeout: 10 * time.Second,
780+
SASLMechanism: plain.Mechanism{
781+
Username: "adminplain",
782+
Password: "admin-secret",
783+
},
784+
}
785+
786+
w := newTestWriter(WriterConfig{
787+
Dialer: dialer,
788+
Topic: topic,
789+
Brokers: []string{"localhost:9093"},
790+
})
791+
792+
w.AllowAutoTopicCreation = true
793+
794+
defer w.Close()
795+
796+
msg := Message{Key: []byte("key"), Value: []byte("Hello World")}
797+
798+
var err error
799+
const retries = 5
800+
for i := 0; i < retries; i++ {
801+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
802+
defer cancel()
803+
err = w.WriteMessages(ctx, msg)
804+
if errors.Is(err, LeaderNotAvailable) || errors.Is(err, context.DeadlineExceeded) {
805+
time.Sleep(time.Millisecond * 250)
806+
continue
807+
}
808+
809+
if err != nil {
810+
t.Errorf("unexpected error %v", err)
811+
return
812+
}
813+
}
814+
if err != nil {
815+
t.Errorf("unable to create topic %v", err)
816+
}
817+
}
818+
769819
type staticBalancer struct {
770820
partition int
771821
}

0 commit comments

Comments
 (0)