diff --git a/.circleci/config.yml b/.circleci/config.yml index 5ec8e2bbd..1566a5faf 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -123,6 +123,37 @@ jobs: - run: go get -v -t . ./gzip ./lz4 ./sasl ./snappy - run: go test -v -race -cover -timeout 150s $(go list ./... | grep -v examples) + kafka-221: + working_directory: /go/src/github.com/segmentio/kafka-go + environment: + KAFKA_VERSION: "2.2.1" + docker: + - image: circleci/golang + - image: wurstmeister/zookeeper + ports: ['2181:2181'] + - image: wurstmeister/kafka:2.12-2.2.1 + ports: ['9092:9092','9093:9093'] + environment: + KAFKA_BROKER_ID: '1' + KAFKA_CREATE_TOPICS: 'test-writer-0:3:1,test-writer-1:3:1' + KAFKA_DELETE_TOPIC_ENABLE: 'true' + KAFKA_ADVERTISED_HOST_NAME: 'localhost' + KAFKA_ADVERTISED_PORT: '9092' + KAFKA_ZOOKEEPER_CONNECT: 'localhost:2181' + KAFKA_AUTO_CREATE_TOPICS_ENABLE: 'true' + KAFKA_LISTENERS: 'PLAINTEXT://:9092,SASL_PLAINTEXT://:9093' + KAFKA_ADVERTISED_LISTENERS: 'PLAINTEXT://localhost:9092,SASL_PLAINTEXT://localhost:9093' + KAFKA_SASL_ENABLED_MECHANISMS: SCRAM-SHA-256,SCRAM-SHA-512,PLAIN + KAFKA_OPTS: "-Djava.security.auth.login.config=/opt/kafka/config/kafka_server_jaas.conf" + CUSTOM_INIT_SCRIPT: |- + echo -e 'KafkaServer {\norg.apache.kafka.common.security.scram.ScramLoginModule required\n username="adminscram"\n password="admin-secret";\n org.apache.kafka.common.security.plain.PlainLoginModule required\n username="adminplain"\n password="admin-secret"\n user_adminplain="admin-secret";\n };' > /opt/kafka/config/kafka_server_jaas.conf; + /opt/kafka/bin/kafka-configs.sh --zookeeper localhost:2181 --alter --add-config 'SCRAM-SHA-256=[password=admin-secret-256],SCRAM-SHA-512=[password=admin-secret-512]' --entity-type users --entity-name adminscram + steps: + - checkout + - setup_remote_docker: { reusable: true, docker_layer_caching: true } + - run: go get -v -t . ./gzip ./lz4 ./sasl ./snappy + - run: go test -v -race -cover -timeout 150s $(go list ./... | grep -v examples) + workflows: version: 2 run: @@ -131,3 +162,4 @@ workflows: - kafka-011 - kafka-111 - kafka-210 + - kafka-221 \ No newline at end of file diff --git a/writer.go b/writer.go index 55d3f2895..b221c1893 100644 --- a/writer.go +++ b/writer.go @@ -7,6 +7,7 @@ import ( "io" "math/rand" "sort" + "strings" "sync" "time" ) @@ -133,6 +134,45 @@ type WriterConfig struct { newPartitionWriter func(partition int, config WriterConfig, stats *writerStats) partitionWriter } +type WriterError struct { + Msg Message + Err error +} + +func (e *WriterError) Error() string { + return e.Err.Error() +} + +func (e *WriterError) Temporary() bool { + return isTemporary(e.Err) +} + +func (e *WriterError) Timeout() bool { + return isTimeout(e.Err) +} + +func (e *WriterError) Unwrap() error { + return e.Err +} + +type WriterErrors []WriterError + +func (wes WriterErrors) Error() string { + if len(wes) == 1 { + return fmt.Sprintf("1 WriterError occurred:\n\t* %s\n", wes[0].Err) + } + + points := make([]string, len(wes)) + for i, we := range wes { + points[i] = fmt.Sprintf("* %s", we.Err) + } + + return fmt.Sprintf( + "%d WriterErrors occurred:\n\t%s\n", + len(wes), + strings.Join(points, "\n\t")) +} + // WriterStats is a data structure returned by a call to Writer.Stats that // exposes details about the behavior of the writer. type WriterStats struct { @@ -302,57 +342,93 @@ func (w *Writer) WriteMessages(ctx context.Context, msgs ...Message) error { return nil } - var err error - var res chan error + errs := make(WriterErrors, 0, len(msgs)) + var res chan writerResponse if !w.config.Async { - res = make(chan error, len(msgs)) + res = make(chan writerResponse, len(msgs)) } t0 := time.Now() defer w.stats.writeTime.observeDuration(time.Since(t0)) + handled := make(map[int]bool, len(msgs)) w.mutex.RLock() - closed := w.closed - w.mutex.RUnlock() - if closed { - return io.ErrClosedPipe + if w.closed { + w.mutex.RUnlock() + for _, m := range msgs { + errs = append(errs, WriterError{ + Msg: m, + Err: io.ErrClosedPipe, + }) + } + return errs } for i, msg := range msgs { - if int(msg.size()) > w.config.BatchBytes { - err := MessageTooLargeError{ - Message: msg, - Remaining: msgs[i+1:], + errs = append(errs, WriterError{ + Msg: msg, + Err: MessageTooLargeError{ + Message: msg, + }, + }) + handled[i] = true + } else { + select { + case w.msgs <- writerMessage{ + msg: msg, + res: res, + id: i, + }: + case <-ctx.Done(): + w.mutex.RUnlock() + for j, m := range msgs { + // don't double count MessageTooLargeErrors which may already be present in errs + if _, ok := handled[j]; !ok { + errs = append(errs, WriterError{ + Msg: m, + Err: ctx.Err(), + }) + } + } + return errs } - return err - } - - wm := writerMessage{msg: msg, res: res} - - select { - case w.msgs <- wm: - case <-ctx.Done(): - return ctx.Err() } } - + w.mutex.RUnlock() if w.config.Async { + if len(errs) > 0 { + return errs + } return nil } - for i := 0; i != len(msgs); i++ { + sent := len(msgs) - len(handled) + for i := 0; i != sent; i++ { select { - case e := <-res: - if e != nil { - err = e + case r := <-res: + handled[r.id] = true + if r.err != nil { + errs = append(errs, *r.err) } case <-ctx.Done(): - return ctx.Err() + // all unacked msgs become errors + for x := range msgs { + if _, ok := handled[x]; !ok { + errs = append(errs, WriterError{ + Msg: msgs[x], + Err: ctx.Err(), + }) + } + } + return errs } } - return err + if len(errs) > 0 { + return errs + } + return nil } // Stats returns a snapshot of the writer stats since the last time the method @@ -459,7 +535,13 @@ func (w *Writer) run() { err = fmt.Errorf("failed to find any partitions for topic %s", w.config.Topic) } if wm.res != nil { - wm.res <- &writerError{msg: wm.msg, err: err} + wm.res <- writerResponse{ + id: wm.id, + err: &WriterError{ + Msg: wm.msg, + Err: err, + }, + } } } @@ -599,7 +681,8 @@ func (w *writer) run() { var conn *Conn var done bool var batch = make([]Message, 0, w.batchSize) - var resch = make([](chan<- error), 0, w.batchSize) + var resch = make([](chan<- writerResponse), 0, w.batchSize) + var ids = make([]int, 0, w.batchSize) var lastMsg writerMessage var batchSizeBytes int var idleConnDeadline time.Time @@ -616,9 +699,8 @@ func (w *writer) run() { // If a lstMsg exists we need to add it to the batch so we don't lose it. if len(lastMsg.msg.Value) != 0 { batch = append(batch, lastMsg.msg) - if lastMsg.res != nil { - resch = append(resch, lastMsg.res) - } + resch = append(resch, lastMsg.res) + ids = append(ids, lastMsg.id) batchSizeBytes += int(lastMsg.msg.size()) lastMsg = writerMessage{} if !batchTimerRunning { @@ -639,9 +721,8 @@ func (w *writer) run() { break } batch = append(batch, wm.msg) - if wm.res != nil { - resch = append(resch, wm.res) - } + resch = append(resch, wm.res) + ids = append(ids, wm.id) batchSizeBytes += int(wm.msg.size()) mustFlush = len(batch) >= w.batchSize || batchSizeBytes >= w.maxMessageBytes } @@ -672,7 +753,7 @@ func (w *writer) run() { } var err error - if conn, err = w.writeWithRetries(conn, batch, resch); err != nil { + if conn, err = w.writeWithRetries(conn, batch, resch, ids); err != nil { if conn != nil { conn.Close() conn = nil @@ -687,8 +768,13 @@ func (w *writer) run() { for i := range resch { resch[i] = nil } + + for i := range ids { + ids[i] = -1 + } batch = batch[:0] resch = resch[:0] + ids = ids[:0] batchSizeBytes = 0 } } @@ -708,21 +794,37 @@ func (w *writer) dial() (conn *Conn, err error) { return } -func (w *writer) writeWithRetries(conn *Conn, batch []Message, resch [](chan<- error)) (*Conn, error) { +func (w *writer) writeWithRetries(conn *Conn, batch []Message, resch [](chan<- writerResponse), ids []int) (*Conn, error) { var err error - for attempt := 0; attempt < w.maxAttempts; attempt++ { - conn, err = w.write(conn, batch, resch) + conn, err = w.write(conn, batch, resch, ids) if err == nil { break } w.stats.retries.observe(1) time.Sleep(backoff(attempt+1, 100*time.Millisecond, 1*time.Second)) } + + for i, res := range resch { + if res != nil { + var we *WriterError + if err != nil { + we = &WriterError{ + Msg: batch[i], + Err: err, + } + } + res <- writerResponse{ + id: ids[i], + err: we, + } + } + } + return conn, err } -func (w *writer) write(conn *Conn, batch []Message, resch [](chan<- error)) (ret *Conn, err error) { +func (w *writer) write(conn *Conn, batch []Message, resch [](chan<- writerResponse), ids []int) (ret *Conn, err error) { w.stats.writes.observe(1) if conn == nil { if conn, err = w.dial(); err != nil { @@ -730,9 +832,6 @@ func (w *writer) write(conn *Conn, batch []Message, resch [](chan<- error)) (ret w.withErrorLogger(func(logger Logger) { logger.Printf("error dialing kafka brokers for topic %s (partition %d): %s", w.topic, w.partition, err) }) - for i, res := range resch { - res <- &writerError{msg: batch[i], err: err} - } return } } @@ -744,17 +843,11 @@ func (w *writer) write(conn *Conn, batch []Message, resch [](chan<- error)) (ret w.withErrorLogger(func(logger Logger) { logger.Printf("error writing messages to %s (partition %d): %s", w.topic, w.partition, err) }) - for i, res := range resch { - res <- &writerError{msg: batch[i], err: err} - } } else { for _, m := range batch { w.stats.messages.observe(1) w.stats.bytes.observe(int64(len(m.Key) + len(m.Value))) } - for _, res := range resch { - res <- nil - } } t1 := time.Now() w.stats.waitTime.observeDuration(t1.Sub(t0)) @@ -766,28 +859,13 @@ func (w *writer) write(conn *Conn, batch []Message, resch [](chan<- error)) (ret type writerMessage struct { msg Message - res chan<- error -} - -type writerError struct { - msg Message - err error -} - -func (e *writerError) Cause() error { - return e.err -} - -func (e *writerError) Error() string { - return e.err.Error() -} - -func (e *writerError) Temporary() bool { - return isTemporary(e.err) + res chan<- writerResponse + id int } -func (e *writerError) Timeout() bool { - return isTimeout(e.err) +type writerResponse struct { + id int + err *WriterError } func shuffledStrings(list []string) []string { diff --git a/writer_test.go b/writer_test.go index fa419b98f..cc10396a4 100644 --- a/writer_test.go +++ b/writer_test.go @@ -5,7 +5,6 @@ import ( "errors" "io" "math" - "strings" "testing" "time" ) @@ -21,15 +20,25 @@ func TestWriter(t *testing.T) { scenario: "closing a writer right after creating it returns promptly with no error", function: testWriterClose, }, - + { + scenario: "writing messages on closed writer should return error", + function: testClosedWriterErr, + }, + { + scenario: "writing empty Message slice returns promptly with no error", + function: testEmptyWrite, + }, + { + scenario: "writing messages after context is done should return an error", + function: testContextDoneErr, + }, { scenario: "writing 1 message through a writer using round-robin balancing produces 1 message to the first partition", function: testWriterRoundRobin1, }, - { - scenario: "running out of max attempts should return an error", - function: testWriterMaxAttemptsErr, + scenario: "errors returned when writing messages should be WriterErrors", + function: testWriterErrors, }, { scenario: "writing a message larger then the max bytes should return an error", @@ -58,6 +67,69 @@ func TestWriter(t *testing.T) { } } +type writerTestCase WriterErrors + +func (wt writerTestCase) errorsEqual(wes WriterErrors) bool { + exp := make(map[string]int) + numExp := 0 + for _, t := range wt { + if t.Err != nil { + numExp += 1 + k := string(t.Msg.Value) + t.Err.Error() + if _, ok := exp[k]; ok { + exp[k] += 1 + } else { + exp[k] = 1 + } + } + } + + if len(wes) != numExp { + return false + } + + for _, e := range wes { + k := string(e.Msg.Value) + e.Err.Error() + if _, ok := exp[k]; ok { + exp[k] -= 1 + } else { + return false + } + } + + for _, e := range exp { + if e != 0 { + return false + } + } + + return true +} + +func (wt writerTestCase) msgs() []Message { + msgs := make([]Message, len(wt)) + for i, m := range wt { + msgs[i] = m.Msg + } + + return msgs +} + +func (wt writerTestCase) expected() WriterErrors { + exp := make(WriterErrors, 0, len(wt)) + for _, v := range wt { + if v.Err != nil { + exp = append(exp, v) + } + } + + if len(exp) > 0 { + return exp + } + + return nil +} + func newTestWriter(config WriterConfig) *Writer { if len(config.Brokers) == 0 { config.Brokers = []string{"localhost:9092"} @@ -78,6 +150,120 @@ func testWriterClose(t *testing.T) { } } +func testClosedWriterErr(t *testing.T) { + tcs := []writerTestCase{ + { + { + Msg: Message{Value: []byte("Hello World!")}, + Err: io.ErrClosedPipe, + }, + }, + { + { + Msg: Message{Value: []byte("Hello")}, + Err: io.ErrClosedPipe, + }, + { + Msg: Message{Value: []byte("World!")}, + Err: io.ErrClosedPipe, + }, + }, + } + + const topic = "test-writer-0" + w := newTestWriter(WriterConfig{ + Topic: topic, + }) + + if err := w.Close(); err != nil { + t.Fatal(err) + } + + for i, tc := range tcs { + err := w.WriteMessages(context.Background(), tc.msgs()...) + if err == nil { + t.Errorf("test %d: expected error", i) + continue + } + + wes, ok := err.(WriterErrors) + if !ok { + t.Errorf("test %d: expected WriterErrors", i) + continue + } + + if !tc.errorsEqual(wes) { + t.Errorf("test %d: unexpected errors occurred.\nExpected:\n%sFound:\n%s", i, tc.expected(), wes) + } + } +} + +func testEmptyWrite(t *testing.T) { + const topic = "test-writer-0" + w := newTestWriter(WriterConfig{ + Topic: topic, + }) + + defer func() { + _ = w.Close() + }() + + if err := w.WriteMessages(context.Background(), []Message{}...); err != nil { + t.Error("unexpected error occurred", err) + } +} + +func testContextDoneErr(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + tcs := []writerTestCase{ + { + { + Msg: Message{Value: []byte("Hello World!")}, + Err: ctx.Err(), + }, + }, + { + { + Msg: Message{Value: []byte("Hello")}, + Err: ctx.Err(), + }, + { + Msg: Message{Value: []byte("World")}, + Err: ctx.Err(), + }, + }, + } + + const topic = "test-writer-0" + w := newTestWriter(WriterConfig{ + Topic: topic, + }) + + defer func() { + _ = w.Close() + }() + + for i, tc := range tcs { + err := w.WriteMessages(ctx, tc.msgs()...) + if err == nil { + t.Errorf("test %d: expected error", i) + continue + } + + wes, ok := err.(WriterErrors) + if !ok { + t.Errorf("test %d: expected WriterErrors", i) + continue + } + + if !tc.errorsEqual(wes) { + t.Errorf("test %d: unexpected errors occurred.\nExpected:\n%sFound:\n%s", i, tc.expected(), wes) + } + } +} + func testWriterRoundRobin1(t *testing.T) { const topic = "test-writer-1" @@ -140,16 +326,20 @@ func TestValidateWriter(t *testing.T) { } } -type fakeWriter struct{} +type errorWriter struct{} -func (f *fakeWriter) messages() chan<- writerMessage { +func (w *errorWriter) messages() chan<- writerMessage { ch := make(chan writerMessage, 1) go func() { for { msg := <-ch - msg.res <- &writerError{ - err: errors.New("bad attempt"), + msg.res <- writerResponse{ + id: msg.id, + err: &WriterError{ + Err: errors.New("bad attempt"), + Msg: msg.msg, + }, } } }() @@ -157,11 +347,30 @@ func (f *fakeWriter) messages() chan<- writerMessage { return ch } -func (f *fakeWriter) close() { +func (w *errorWriter) close() { } -func testWriterMaxAttemptsErr(t *testing.T) { +func testWriterErrors(t *testing.T) { + tcs := []writerTestCase{ + { + { + Msg: Message{Value: []byte("test 1 error")}, + Err: errors.New("bad attempt"), + }, + }, + { + { + Msg: Message{Value: []byte("test multi error")}, + Err: errors.New("bad attempt"), + }, + { + Msg: Message{Value: []byte("test multi error")}, + Err: errors.New("bad attempt"), + }, + }, + } + const topic = "test-writer-2" createTopic(t, topic, 1) @@ -169,34 +378,68 @@ func testWriterMaxAttemptsErr(t *testing.T) { Topic: topic, MaxAttempts: 1, Balancer: &RoundRobin{}, - newPartitionWriter: func(p int, config WriterConfig, stats *writerStats) partitionWriter { - return &fakeWriter{} + newPartitionWriter: func(_ int, _ WriterConfig, _ *writerStats) partitionWriter { + return &errorWriter{} }, }) - defer w.Close() + defer func() { + _ = w.Close() + }() - if err := w.WriteMessages(context.Background(), Message{ - Value: []byte("Hello World!"), - }); err == nil { - t.Error("expected error") - return - } else if err != nil { - if !strings.Contains(err.Error(), "bad attempt") { - t.Errorf("unexpected error: %s", err) - return + for i, tc := range tcs { + err := w.WriteMessages(context.Background(), tc.msgs()...) + if err == nil { + t.Errorf("test %d: expected error", i) + continue + } + + wes, ok := err.(WriterErrors) + if !ok { + t.Errorf("test %d: expected WriterErrors", i) + continue + } + + if !tc.errorsEqual(wes) { + t.Errorf("test %d: unexpected errors occurred.\nExpected:\n%sFound:\n%s", i, tc.expected(), wes) } } } func testWriterMaxBytes(t *testing.T) { - topic := makeTopic() + tcs := []writerTestCase{ + { + { + Msg: Message{Value: []byte("Hello World!")}, + Err: MessageTooLargeError{}, + }, + { + Msg: Message{Value: []byte("Hi")}, + Err: nil, + }, + }, + { + { + Msg: Message{Value: []byte("Too large!")}, + Err: MessageTooLargeError{}, + }, + { + Msg: Message{Value: []byte("Also too long!")}, + Err: MessageTooLargeError{}, + }, + }, + } + topic := makeTopic() + maxBytes := 25 createTopic(t, topic, 1) w := newTestWriter(WriterConfig{ Topic: topic, - BatchBytes: 25, + BatchBytes: maxBytes, }) - defer w.Close() + + defer func() { + _ = w.Close() + }() if err := w.WriteMessages(context.Background(), Message{ Value: []byte("Hi"), @@ -205,37 +448,21 @@ func testWriterMaxBytes(t *testing.T) { return } - firstMsg := []byte("Hello World!") - secondMsg := []byte("LeftOver!") - msgs := []Message{ - { - Value: firstMsg, - }, - { - Value: secondMsg, - }, - } - if err := w.WriteMessages(context.Background(), msgs...); err == nil { - t.Error("expected error") - return - } else if err != nil { - switch e := err.(type) { - case MessageTooLargeError: - if string(e.Message.Value) != string(firstMsg) { - t.Errorf("unxpected returned message. Expected: %s, Got %s", firstMsg, e.Message.Value) - return - } - if len(e.Remaining) != 1 { - t.Error("expected remaining errors; found none") - return - } - if string(e.Remaining[0].Value) != string(secondMsg) { - t.Errorf("unxpected returned message. Expected: %s, Got %s", secondMsg, e.Message.Value) - return - } - default: - t.Errorf("unexpected error: %s", err) - return + for i, tc := range tcs { + err := w.WriteMessages(context.Background(), tc.msgs()...) + if err == nil { + t.Errorf("test %d: expected error", i) + continue + } + + wes, ok := err.(WriterErrors) + if !ok { + t.Errorf("test %d: expected WriterErrors", i) + continue + } + + if !tc.errorsEqual(wes) { + t.Errorf("test %d: unexpected errors occurred.\nExpected:\n%sFound:\n%s", i, tc.expected(), wes) } } }