Skip to content

Commit e6b8599

Browse files
Achillenlsun
Achille
andauthored
fix zstd decoder leak (#543)
* fix zstd decoder leak * fix tests * fix panic * fix tests (2) * fix tests (3) * fix tests (4) * move ConnWaitGroup to testing package * fix zstd codec * Update compress/zstd/zstd.go Co-authored-by: Nicholas Sun <[email protected]> * PR feedback Co-authored-by: Nicholas Sun <[email protected]>
1 parent b7a001a commit e6b8599

File tree

7 files changed

+237
-137
lines changed

7 files changed

+237
-137
lines changed

client_test.go

+34-44
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,43 @@ import (
66
"io"
77
"math/rand"
88
"net"
9-
"sync"
109
"testing"
1110
"time"
1211

1312
"github.com/segmentio/kafka-go/compress"
13+
ktesting "github.com/segmentio/kafka-go/testing"
1414
)
1515

1616
func newLocalClientAndTopic() (*Client, string, func()) {
1717
topic := makeTopic()
18-
client, shutdown := newClient(TCP("localhost"))
18+
client, shutdown := newLocalClientWithTopic(topic, 1)
19+
return client, topic, shutdown
20+
}
21+
22+
func newLocalClientWithTopic(topic string, partitions int) (*Client, func()) {
23+
client, shutdown := newLocalClient()
24+
if err := clientCreateTopic(client, topic, partitions); err != nil {
25+
shutdown()
26+
panic(err)
27+
}
28+
return client, func() {
29+
client.DeleteTopics(context.Background(), &DeleteTopicsRequest{
30+
Topics: []string{topic},
31+
})
32+
shutdown()
33+
}
34+
}
1935

36+
func clientCreateTopic(client *Client, topic string, partitions int) error {
2037
_, err := client.CreateTopics(context.Background(), &CreateTopicsRequest{
2138
Topics: []TopicConfig{{
2239
Topic: topic,
23-
NumPartitions: 1,
40+
NumPartitions: partitions,
2441
ReplicationFactor: 1,
2542
}},
2643
})
2744
if err != nil {
28-
shutdown()
29-
panic(err)
45+
return err
3046
}
3147

3248
// Topic creation seems to be asynchronous. Metadata for the topic partition
@@ -48,21 +64,16 @@ func newLocalClientAndTopic() (*Client, string, func()) {
4864
time.Sleep(100 * time.Millisecond)
4965
}
5066

51-
return client, topic, func() {
52-
client.DeleteTopics(context.Background(), &DeleteTopicsRequest{
53-
Topics: []string{topic},
54-
})
55-
shutdown()
56-
}
67+
return nil
5768
}
5869

5970
func newLocalClient() (*Client, func()) {
6071
return newClient(TCP("localhost"))
6172
}
6273

6374
func newClient(addr net.Addr) (*Client, func()) {
64-
conns := &connWaitGroup{
65-
dial: (&net.Dialer{}).DialContext,
75+
conns := &ktesting.ConnWaitGroup{
76+
DialFunc: (&net.Dialer{}).DialContext,
6677
}
6778

6879
transport := &Transport{
@@ -79,31 +90,6 @@ func newClient(addr net.Addr) (*Client, func()) {
7990
return client, func() { transport.CloseIdleConnections(); conns.Wait() }
8091
}
8192

82-
type connWaitGroup struct {
83-
dial func(context.Context, string, string) (net.Conn, error)
84-
sync.WaitGroup
85-
}
86-
87-
func (g *connWaitGroup) Dial(ctx context.Context, network, address string) (net.Conn, error) {
88-
c, err := g.dial(ctx, network, address)
89-
if err != nil {
90-
return nil, err
91-
}
92-
g.Add(1)
93-
return &groupConn{Conn: c, group: g}, nil
94-
}
95-
96-
type groupConn struct {
97-
net.Conn
98-
group *connWaitGroup
99-
once sync.Once
100-
}
101-
102-
func (c *groupConn) Close() error {
103-
defer c.once.Do(c.group.Done)
104-
return c.Conn.Close()
105-
}
106-
10793
func TestClient(t *testing.T) {
10894
tests := []struct {
10995
scenario string
@@ -121,20 +107,23 @@ func TestClient(t *testing.T) {
121107
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
122108
defer cancel()
123109

124-
c := &Client{Addr: TCP("localhost:9092")}
125-
testFunc(t, ctx, c)
110+
client, shutdown := newLocalClient()
111+
defer shutdown()
112+
113+
testFunc(t, ctx, client)
126114
})
127115
}
128116
}
129117

130-
func testConsumerGroupFetchOffsets(t *testing.T, ctx context.Context, c *Client) {
118+
func testConsumerGroupFetchOffsets(t *testing.T, ctx context.Context, client *Client) {
131119
const totalMessages = 144
132120
const partitions = 12
133121
const msgPerPartition = totalMessages / partitions
134122

135123
topic := makeTopic()
136-
createTopic(t, topic, partitions)
137-
defer deleteTopic(t, topic)
124+
if err := clientCreateTopic(client, topic, partitions); err != nil {
125+
t.Fatal(err)
126+
}
138127

139128
groupId := makeGroupID()
140129
brokers := []string{"localhost:9092"}
@@ -144,6 +133,7 @@ func testConsumerGroupFetchOffsets(t *testing.T, ctx context.Context, c *Client)
144133
Topic: topic,
145134
Balancer: &RoundRobin{},
146135
BatchSize: 1,
136+
Transport: client.Transport,
147137
}
148138
if err := writer.WriteMessages(ctx, makeTestSequence(totalMessages)...); err != nil {
149139
t.Fatalf("bad write messages: %v", err)
@@ -172,7 +162,7 @@ func testConsumerGroupFetchOffsets(t *testing.T, ctx context.Context, c *Client)
172162
}
173163
}
174164

175-
offsets, err := c.ConsumerOffsets(ctx, TopicAndGroup{GroupId: groupId, Topic: topic})
165+
offsets, err := client.ConsumerOffsets(ctx, TopicAndGroup{GroupId: groupId, Topic: topic})
176166
if err != nil {
177167
t.Fatal(err)
178168
}

compress/compress_test.go

+71-47
Original file line numberDiff line numberDiff line change
@@ -88,19 +88,24 @@ func testEncodeDecode(t *testing.T, m kafka.Message, codec pkg.Codec) {
8888
t.Run("encode with "+codec.Name(), func(t *testing.T) {
8989
r1, err = compress(codec, m.Value)
9090
if err != nil {
91-
t.Error(err)
91+
t.Fatal(err)
9292
}
9393
})
9494

9595
t.Run("decode with "+codec.Name(), func(t *testing.T) {
96+
if r1 == nil {
97+
if r1, err = compress(codec, m.Value); err != nil {
98+
t.Fatal(err)
99+
}
100+
}
96101
r2, err = decompress(codec, r1)
97102
if err != nil {
98-
t.Error(err)
103+
t.Fatal(err)
99104
}
100105
if string(r2) != "message" {
101106
t.Error("bad message")
102-
t.Log("got: ", string(r2))
103-
t.Log("expected: ", string(m.Value))
107+
t.Logf("expected: %q", string(m.Value))
108+
t.Logf("got: %q", string(r2))
104109
}
105110
})
106111
}
@@ -116,15 +121,16 @@ func TestCompressedMessages(t *testing.T) {
116121
}
117122

118123
func testCompressedMessages(t *testing.T, codec pkg.Codec) {
119-
t.Run("produce/consume with"+codec.Name(), func(t *testing.T) {
120-
topic := createTopic(t, 1)
121-
defer deleteTopic(t, topic)
124+
t.Run(codec.Name(), func(t *testing.T) {
125+
client, topic, shutdown := newLocalClientAndTopic()
126+
defer shutdown()
122127

123128
w := &kafka.Writer{
124129
Addr: kafka.TCP("127.0.0.1:9092"),
125130
Topic: topic,
126131
Compression: kafka.Compression(codec.Code()),
127132
BatchTimeout: 10 * time.Millisecond,
133+
Transport: client.Transport,
128134
}
129135
defer w.Close()
130136

@@ -185,19 +191,23 @@ func testCompressedMessages(t *testing.T, codec pkg.Codec) {
185191
}
186192

187193
func TestMixedCompressedMessages(t *testing.T) {
188-
topic := createTopic(t, 1)
189-
defer deleteTopic(t, topic)
194+
client, topic, shutdown := newLocalClientAndTopic()
195+
defer shutdown()
190196

191197
offset := 0
192198
var values []string
193199
produce := func(n int, codec pkg.Codec) {
194200
w := &kafka.Writer{
195-
Addr: kafka.TCP("127.0.0.1:9092"),
196-
Topic: topic,
197-
Compression: kafka.Compression(codec.Code()),
201+
Addr: kafka.TCP("127.0.0.1:9092"),
202+
Topic: topic,
203+
Transport: client.Transport,
198204
}
199205
defer w.Close()
200206

207+
if codec != nil {
208+
w.Compression = kafka.Compression(codec.Code())
209+
}
210+
201211
msgs := make([]kafka.Message, n)
202212
for i := range msgs {
203213
value := fmt.Sprintf("Hello World %d!", offset)
@@ -407,58 +417,72 @@ func benchmarkCompression(b *testing.B, codec pkg.Codec, buf *bytes.Buffer, payl
407417
return 1 - (float64(buf.Len()) / float64(len(payload)))
408418
}
409419

420+
func init() {
421+
rand.Seed(time.Now().UnixNano())
422+
}
423+
410424
func makeTopic() string {
411425
return fmt.Sprintf("kafka-go-%016x", rand.Int63())
412426
}
413427

414-
func createTopic(t *testing.T, partitions int) string {
428+
func newLocalClientAndTopic() (*kafka.Client, string, func()) {
415429
topic := makeTopic()
416-
417-
conn, err := kafka.Dial("tcp", "localhost:9092")
430+
client, shutdown := newLocalClient()
431+
432+
_, err := client.CreateTopics(context.Background(), &kafka.CreateTopicsRequest{
433+
Topics: []kafka.TopicConfig{{
434+
Topic: topic,
435+
NumPartitions: 1,
436+
ReplicationFactor: 1,
437+
}},
438+
})
418439
if err != nil {
419-
t.Fatal(err)
440+
shutdown()
441+
panic(err)
420442
}
421-
defer conn.Close()
422443

423-
err = conn.CreateTopics(kafka.TopicConfig{
424-
Topic: topic,
425-
NumPartitions: partitions,
426-
ReplicationFactor: 1,
427-
})
444+
// Topic creation seems to be asynchronous. Metadata for the topic partition
445+
// layout in the cluster is available in the controller before being synced
446+
// with the other brokers, which causes "Error:[3] Unknown Topic Or Partition"
447+
// when sending requests to the partition leaders.
448+
for i := 0; i < 20; i++ {
449+
r, err := client.Fetch(context.Background(), &kafka.FetchRequest{
450+
Topic: topic,
451+
Partition: 0,
452+
Offset: 0,
453+
})
454+
if err == nil && r.Error == nil {
455+
break
456+
}
457+
time.Sleep(100 * time.Millisecond)
458+
}
428459

429-
switch err {
430-
case nil:
431-
// ok
432-
case kafka.TopicAlreadyExists:
433-
// ok
434-
default:
435-
t.Error("bad createTopics", err)
436-
t.FailNow()
460+
return client, topic, func() {
461+
client.DeleteTopics(context.Background(), &kafka.DeleteTopicsRequest{
462+
Topics: []string{topic},
463+
})
464+
shutdown()
437465
}
466+
}
438467

439-
return topic
468+
func newLocalClient() (*kafka.Client, func()) {
469+
return newClient(kafka.TCP("127.0.0.1:9092"))
440470
}
441471

442-
func deleteTopic(t *testing.T, topic ...string) {
443-
conn, err := kafka.Dial("tcp", "localhost:9092")
444-
if err != nil {
445-
t.Fatal(err)
472+
func newClient(addr net.Addr) (*kafka.Client, func()) {
473+
conns := &ktesting.ConnWaitGroup{
474+
DialFunc: (&net.Dialer{}).DialContext,
446475
}
447-
defer conn.Close()
448476

449-
controller, err := conn.Controller()
450-
if err != nil {
451-
t.Fatal(err)
477+
transport := &kafka.Transport{
478+
Dial: conns.Dial,
452479
}
453480

454-
conn, err = kafka.Dial("tcp", net.JoinHostPort(controller.Host, strconv.Itoa(controller.Port)))
455-
if err != nil {
456-
t.Fatal(err)
481+
client := &kafka.Client{
482+
Addr: addr,
483+
Timeout: 5 * time.Second,
484+
Transport: transport,
457485
}
458486

459-
conn.SetDeadline(time.Now().Add(2 * time.Second))
460-
461-
if err := conn.DeleteTopics(topic...); err != nil {
462-
t.Fatal(err)
463-
}
487+
return client, func() { transport.CloseIdleConnections(); conns.Wait() }
464488
}

0 commit comments

Comments
 (0)