From d9a58f8a5edb86f4a37e9839211ce919490a0b50 Mon Sep 17 00:00:00 2001 From: Dominic Barnes Date: Tue, 20 Oct 2020 22:33:02 -0700 Subject: [PATCH 1/2] feat: improve custom resolver support by allowing port to be overridden --- dialer.go | 50 +++++++++++++++++------- dialer_test.go | 103 +++++++++++++++++++++++++++++++++++++++++++++++++ writer.go | 18 ++------- 3 files changed, 143 insertions(+), 28 deletions(-) diff --git a/dialer.go b/dialer.go index dad26261e..38637ffdf 100644 --- a/dialer.go +++ b/dialer.go @@ -320,20 +320,10 @@ func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error { return nil } -func (d *Dialer) dialContext(ctx context.Context, network string, address string) (net.Conn, error) { - if r := d.Resolver; r != nil { - host, port := splitHostPort(address) - addrs, err := r.LookupHost(ctx, host) - if err != nil { - return nil, err - } - if len(addrs) != 0 { - address = addrs[0] - } - if len(port) != 0 { - address, _ = splitHostPort(address) - address = net.JoinHostPort(address, port) - } +func (d *Dialer) dialContext(ctx context.Context, network string, addr string) (net.Conn, error) { + address, err := lookupHost(ctx, addr, d.Resolver) + if err != nil { + return nil, err } dial := d.DialFunc @@ -441,3 +431,35 @@ func splitHostPort(s string) (host string, port string) { } return } + +func lookupHost(ctx context.Context, address string, resolver Resolver) (string, error) { + host, port := splitHostPort(address) + + if resolver != nil { + resolved, err := resolver.LookupHost(ctx, host) + if err != nil { + return "", err + } + + // if the resolver doesn't return anything, we'll fall back on the provided + // address instead + if len(resolved) > 0 { + resolvedHost, resolvedPort := splitHostPort(resolved[0]) + + // we'll always prefer the resolved host + host = resolvedHost + + // in the case of port though, the provided address takes priority, and we + // only use the resolved address to set the port when not specified + if port == "" { + port = resolvedPort + } + } + } + + if port == "" { + port = "9092" + } + + return net.JoinHostPort(host, port), nil +} diff --git a/dialer_test.go b/dialer_test.go index 57f3c91bc..5aedb777b 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "fmt" "io" "net" "reflect" @@ -298,3 +299,105 @@ func TestDialerConnectTLSHonorsContext(t *testing.T) { t.FailNow() } } + +func TestDialerResolver(t *testing.T) { + ctx := context.TODO() + + tests := []struct { + scenario string + address string + resolver map[string][]string + }{ + { + scenario: "resolve domain to ip", + address: "example.com", + resolver: map[string][]string{ + "example.com": {"127.0.0.1"}, + }, + }, + { + scenario: "resolve domain to ip and port", + address: "example.com", + resolver: map[string][]string{ + "example.com": {"127.0.0.1:9092"}, + }, + }, + { + scenario: "resolve domain with port to ip", + address: "example.com:9092", + resolver: map[string][]string{ + "example.com": {"127.0.0.1:9092"}, + }, + }, + { + scenario: "resolve domain with port to ip with different port", + address: "example.com:9092", + resolver: map[string][]string{ + "example.com": {"127.0.0.1:80"}, + }, + }, + { + scenario: "resolve domain with port to ip", + address: "example.com:9092", + resolver: map[string][]string{ + "example.com": {"127.0.0.1"}, + }, + }, + } + + for _, test := range tests { + t.Run(test.scenario, func(t *testing.T) { + topic := makeTopic() + createTopic(t, topic, 1) + defer deleteTopic(t, topic) + + d := Dialer{ + Resolver: &mockResolver{addrs: test.resolver}, + } + + // Write a message to ensure the partition gets created. + w := NewWriter(WriterConfig{ + Brokers: []string{"localhost:9092"}, + Topic: topic, + Dialer: &d, + }) + w.WriteMessages(context.Background(), Message{}) + w.Close() + + partitions, err := d.LookupPartitions(ctx, "tcp", test.address, topic) + if err != nil { + t.Error(err) + return + } + + sort.Slice(partitions, func(i int, j int) bool { + return partitions[i].ID < partitions[j].ID + }) + + want := []Partition{ + { + Topic: topic, + Leader: Broker{Host: "localhost", Port: 9092, ID: 1}, + Replicas: []Broker{{Host: "localhost", Port: 9092, ID: 1}}, + Isr: []Broker{{Host: "localhost", Port: 9092, ID: 1}}, + ID: 0, + }, + } + if !reflect.DeepEqual(partitions, want) { + t.Errorf("bad partitions:\ngot: %+v\nwant: %+v", partitions, want) + } + }) + } +} + +type mockResolver struct { + addrs map[string][]string +} + +func (mr *mockResolver) LookupHost(ctx context.Context, host string) ([]string, error) { + if addrs, ok := mr.addrs[host]; !ok { + return nil, fmt.Errorf("unrecognized host %s", host) + } else { + return addrs, nil + } +} diff --git a/writer.go b/writer.go index 6ee9c149b..1c7d81cbb 100644 --- a/writer.go +++ b/writer.go @@ -425,25 +425,15 @@ func NewWriter(config WriterConfig) *Writer { stats := new(writerStats) // For backward compatibility with the pre-0.4 APIs, support custom // resolvers by wrapping the dial function. - dial := func(ctx context.Context, network, address string) (net.Conn, error) { + dial := func(ctx context.Context, network, addr string) (net.Conn, error) { start := time.Now() defer func() { stats.dials.observe(1) stats.dialTime.observe(int64(time.Since(start))) }() - if resolver != nil { - host, port := splitHostPort(address) - addrs, err := resolver.LookupHost(ctx, host) - if err != nil { - return nil, err - } - if len(addrs) != 0 { - address = addrs[0] - } - if len(port) != 0 { - address, _ = splitHostPort(address) - address = net.JoinHostPort(address, port) - } + address, err := lookupHost(ctx, addr, resolver) + if err != nil { + return nil, err } return dialer.DialContext(ctx, network, address) } From 7b19ac0561bdeff488ab268112bf29cc794ec265 Mon Sep 17 00:00:00 2001 From: Dominic Barnes Date: Fri, 23 Oct 2020 10:11:41 -0700 Subject: [PATCH 2/2] chore(dialer): add docs for resolver --- dialer.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dialer.go b/dialer.go index 38637ffdf..43e8af194 100644 --- a/dialer.go +++ b/dialer.go @@ -65,7 +65,13 @@ type Dialer struct { // support keep-alives ignore this field. KeepAlive time.Duration - // Resolver optionally specifies an alternate resolver to use. + // Resolver optionally gives a hook to convert the broker address into an + // alternate host or IP address which is useful for custom service discovery. + // If a custom resolver returns any possible hosts, the first one will be + // used and the original discarded. If a port number is included with the + // resolved host, it will only be used if a port number was not previously + // specified. If no port is specified or resolved, the default of 9092 will be + // used. Resolver Resolver // TLS enables Dialer to open secure connections. If nil, standard net.Conn