Skip to content

improve custom resolver support by allowing port to be overridden #545

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 23, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 36 additions & 14 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,20 +320,10 @@ func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error {
return nil
}

func (d *Dialer) dialContext(ctx context.Context, network string, address string) (net.Conn, error) {
if r := d.Resolver; r != nil {
host, port := splitHostPort(address)
addrs, err := r.LookupHost(ctx, host)
if err != nil {
return nil, err
}
if len(addrs) != 0 {
address = addrs[0]
}
if len(port) != 0 {
address, _ = splitHostPort(address)
address = net.JoinHostPort(address, port)
}
func (d *Dialer) dialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
address, err := lookupHost(ctx, addr, d.Resolver)
if err != nil {
return nil, err
}

dial := d.DialFunc
Expand Down Expand Up @@ -441,3 +431,35 @@ func splitHostPort(s string) (host string, port string) {
}
return
}

func lookupHost(ctx context.Context, address string, resolver Resolver) (string, error) {
host, port := splitHostPort(address)

if resolver != nil {
resolved, err := resolver.LookupHost(ctx, host)
if err != nil {
return "", err
}

// if the resolver doesn't return anything, we'll fall back on the provided
// address instead
if len(resolved) > 0 {
resolvedHost, resolvedPort := splitHostPort(resolved[0])

// we'll always prefer the resolved host
host = resolvedHost

// in the case of port though, the provided address takes priority, and we
// only use the resolved address to set the port when not specified
if port == "" {
port = resolvedPort
}
}
}

if port == "" {
port = "9092"
}

return net.JoinHostPort(host, port), nil
}
103 changes: 103 additions & 0 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"net"
"reflect"
Expand Down Expand Up @@ -298,3 +299,105 @@ func TestDialerConnectTLSHonorsContext(t *testing.T) {
t.FailNow()
}
}

func TestDialerResolver(t *testing.T) {
ctx := context.TODO()

tests := []struct {
scenario string
address string
resolver map[string][]string
}{
{
scenario: "resolve domain to ip",
address: "example.com",
resolver: map[string][]string{
"example.com": {"127.0.0.1"},
},
},
{
scenario: "resolve domain to ip and port",
address: "example.com",
resolver: map[string][]string{
"example.com": {"127.0.0.1:9092"},
},
},
{
scenario: "resolve domain with port to ip",
address: "example.com:9092",
resolver: map[string][]string{
"example.com": {"127.0.0.1:9092"},
},
},
{
scenario: "resolve domain with port to ip with different port",
address: "example.com:9092",
resolver: map[string][]string{
"example.com": {"127.0.0.1:80"},
},
},
{
scenario: "resolve domain with port to ip",
address: "example.com:9092",
resolver: map[string][]string{
"example.com": {"127.0.0.1"},
},
},
}

for _, test := range tests {
t.Run(test.scenario, func(t *testing.T) {
topic := makeTopic()
createTopic(t, topic, 1)
defer deleteTopic(t, topic)

d := Dialer{
Resolver: &mockResolver{addrs: test.resolver},
}

// Write a message to ensure the partition gets created.
w := NewWriter(WriterConfig{
Brokers: []string{"localhost:9092"},
Topic: topic,
Dialer: &d,
})
w.WriteMessages(context.Background(), Message{})
w.Close()

partitions, err := d.LookupPartitions(ctx, "tcp", test.address, topic)
if err != nil {
t.Error(err)
return
}

sort.Slice(partitions, func(i int, j int) bool {
return partitions[i].ID < partitions[j].ID
})

want := []Partition{
{
Topic: topic,
Leader: Broker{Host: "localhost", Port: 9092, ID: 1},
Replicas: []Broker{{Host: "localhost", Port: 9092, ID: 1}},
Isr: []Broker{{Host: "localhost", Port: 9092, ID: 1}},
ID: 0,
},
}
if !reflect.DeepEqual(partitions, want) {
t.Errorf("bad partitions:\ngot: %+v\nwant: %+v", partitions, want)
}
})
}
}

type mockResolver struct {
addrs map[string][]string
}

func (mr *mockResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
if addrs, ok := mr.addrs[host]; !ok {
return nil, fmt.Errorf("unrecognized host %s", host)
} else {
return addrs, nil
}
}
18 changes: 4 additions & 14 deletions writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,25 +425,15 @@ func NewWriter(config WriterConfig) *Writer {
stats := new(writerStats)
// For backward compatibility with the pre-0.4 APIs, support custom
// resolvers by wrapping the dial function.
dial := func(ctx context.Context, network, address string) (net.Conn, error) {
dial := func(ctx context.Context, network, addr string) (net.Conn, error) {
start := time.Now()
defer func() {
stats.dials.observe(1)
stats.dialTime.observe(int64(time.Since(start)))
}()
if resolver != nil {
host, port := splitHostPort(address)
addrs, err := resolver.LookupHost(ctx, host)
if err != nil {
return nil, err
}
if len(addrs) != 0 {
address = addrs[0]
}
if len(port) != 0 {
address, _ = splitHostPort(address)
address = net.JoinHostPort(address, port)
}
address, err := lookupHost(ctx, addr, resolver)
if err != nil {
return nil, err
}
return dialer.DialContext(ctx, network, address)
}
Expand Down