Skip to content

Commit 87949af

Browse files
improve custom resolver support by allowing port to be overridden (#545)
1 parent e6b8599 commit 87949af

File tree

3 files changed

+150
-29
lines changed

3 files changed

+150
-29
lines changed

dialer.go

+43-15
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,13 @@ type Dialer struct {
6565
// support keep-alives ignore this field.
6666
KeepAlive time.Duration
6767

68-
// Resolver optionally specifies an alternate resolver to use.
68+
// Resolver optionally gives a hook to convert the broker address into an
69+
// alternate host or IP address which is useful for custom service discovery.
70+
// If a custom resolver returns any possible hosts, the first one will be
71+
// used and the original discarded. If a port number is included with the
72+
// resolved host, it will only be used if a port number was not previously
73+
// specified. If no port is specified or resolved, the default of 9092 will be
74+
// used.
6975
Resolver Resolver
7076

7177
// TLS enables Dialer to open secure connections. If nil, standard net.Conn
@@ -320,20 +326,10 @@ func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error {
320326
return nil
321327
}
322328

323-
func (d *Dialer) dialContext(ctx context.Context, network string, address string) (net.Conn, error) {
324-
if r := d.Resolver; r != nil {
325-
host, port := splitHostPort(address)
326-
addrs, err := r.LookupHost(ctx, host)
327-
if err != nil {
328-
return nil, err
329-
}
330-
if len(addrs) != 0 {
331-
address = addrs[0]
332-
}
333-
if len(port) != 0 {
334-
address, _ = splitHostPort(address)
335-
address = net.JoinHostPort(address, port)
336-
}
329+
func (d *Dialer) dialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
330+
address, err := lookupHost(ctx, addr, d.Resolver)
331+
if err != nil {
332+
return nil, err
337333
}
338334

339335
dial := d.DialFunc
@@ -441,3 +437,35 @@ func splitHostPort(s string) (host string, port string) {
441437
}
442438
return
443439
}
440+
441+
func lookupHost(ctx context.Context, address string, resolver Resolver) (string, error) {
442+
host, port := splitHostPort(address)
443+
444+
if resolver != nil {
445+
resolved, err := resolver.LookupHost(ctx, host)
446+
if err != nil {
447+
return "", err
448+
}
449+
450+
// if the resolver doesn't return anything, we'll fall back on the provided
451+
// address instead
452+
if len(resolved) > 0 {
453+
resolvedHost, resolvedPort := splitHostPort(resolved[0])
454+
455+
// we'll always prefer the resolved host
456+
host = resolvedHost
457+
458+
// in the case of port though, the provided address takes priority, and we
459+
// only use the resolved address to set the port when not specified
460+
if port == "" {
461+
port = resolvedPort
462+
}
463+
}
464+
}
465+
466+
if port == "" {
467+
port = "9092"
468+
}
469+
470+
return net.JoinHostPort(host, port), nil
471+
}

dialer_test.go

+103
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"crypto/tls"
66
"crypto/x509"
7+
"fmt"
78
"io"
89
"net"
910
"reflect"
@@ -298,3 +299,105 @@ func TestDialerConnectTLSHonorsContext(t *testing.T) {
298299
t.FailNow()
299300
}
300301
}
302+
303+
func TestDialerResolver(t *testing.T) {
304+
ctx := context.TODO()
305+
306+
tests := []struct {
307+
scenario string
308+
address string
309+
resolver map[string][]string
310+
}{
311+
{
312+
scenario: "resolve domain to ip",
313+
address: "example.com",
314+
resolver: map[string][]string{
315+
"example.com": {"127.0.0.1"},
316+
},
317+
},
318+
{
319+
scenario: "resolve domain to ip and port",
320+
address: "example.com",
321+
resolver: map[string][]string{
322+
"example.com": {"127.0.0.1:9092"},
323+
},
324+
},
325+
{
326+
scenario: "resolve domain with port to ip",
327+
address: "example.com:9092",
328+
resolver: map[string][]string{
329+
"example.com": {"127.0.0.1:9092"},
330+
},
331+
},
332+
{
333+
scenario: "resolve domain with port to ip with different port",
334+
address: "example.com:9092",
335+
resolver: map[string][]string{
336+
"example.com": {"127.0.0.1:80"},
337+
},
338+
},
339+
{
340+
scenario: "resolve domain with port to ip",
341+
address: "example.com:9092",
342+
resolver: map[string][]string{
343+
"example.com": {"127.0.0.1"},
344+
},
345+
},
346+
}
347+
348+
for _, test := range tests {
349+
t.Run(test.scenario, func(t *testing.T) {
350+
topic := makeTopic()
351+
createTopic(t, topic, 1)
352+
defer deleteTopic(t, topic)
353+
354+
d := Dialer{
355+
Resolver: &mockResolver{addrs: test.resolver},
356+
}
357+
358+
// Write a message to ensure the partition gets created.
359+
w := NewWriter(WriterConfig{
360+
Brokers: []string{"localhost:9092"},
361+
Topic: topic,
362+
Dialer: &d,
363+
})
364+
w.WriteMessages(context.Background(), Message{})
365+
w.Close()
366+
367+
partitions, err := d.LookupPartitions(ctx, "tcp", test.address, topic)
368+
if err != nil {
369+
t.Error(err)
370+
return
371+
}
372+
373+
sort.Slice(partitions, func(i int, j int) bool {
374+
return partitions[i].ID < partitions[j].ID
375+
})
376+
377+
want := []Partition{
378+
{
379+
Topic: topic,
380+
Leader: Broker{Host: "localhost", Port: 9092, ID: 1},
381+
Replicas: []Broker{{Host: "localhost", Port: 9092, ID: 1}},
382+
Isr: []Broker{{Host: "localhost", Port: 9092, ID: 1}},
383+
ID: 0,
384+
},
385+
}
386+
if !reflect.DeepEqual(partitions, want) {
387+
t.Errorf("bad partitions:\ngot: %+v\nwant: %+v", partitions, want)
388+
}
389+
})
390+
}
391+
}
392+
393+
type mockResolver struct {
394+
addrs map[string][]string
395+
}
396+
397+
func (mr *mockResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
398+
if addrs, ok := mr.addrs[host]; !ok {
399+
return nil, fmt.Errorf("unrecognized host %s", host)
400+
} else {
401+
return addrs, nil
402+
}
403+
}

writer.go

+4-14
Original file line numberDiff line numberDiff line change
@@ -425,25 +425,15 @@ func NewWriter(config WriterConfig) *Writer {
425425
stats := new(writerStats)
426426
// For backward compatibility with the pre-0.4 APIs, support custom
427427
// resolvers by wrapping the dial function.
428-
dial := func(ctx context.Context, network, address string) (net.Conn, error) {
428+
dial := func(ctx context.Context, network, addr string) (net.Conn, error) {
429429
start := time.Now()
430430
defer func() {
431431
stats.dials.observe(1)
432432
stats.dialTime.observe(int64(time.Since(start)))
433433
}()
434-
if resolver != nil {
435-
host, port := splitHostPort(address)
436-
addrs, err := resolver.LookupHost(ctx, host)
437-
if err != nil {
438-
return nil, err
439-
}
440-
if len(addrs) != 0 {
441-
address = addrs[0]
442-
}
443-
if len(port) != 0 {
444-
address, _ = splitHostPort(address)
445-
address = net.JoinHostPort(address, port)
446-
}
434+
address, err := lookupHost(ctx, addr, resolver)
435+
if err != nil {
436+
return nil, err
447437
}
448438
return dialer.DialContext(ctx, network, address)
449439
}

0 commit comments

Comments
 (0)