Skip to content

Commit a5fdce2

Browse files
Merge pull request #12 from thibmeu/feature/maximum-ttl-cache
Add a max TTL for cached entries
2 parents fadcce3 + 909421f commit a5fdce2

File tree

2 files changed

+101
-15
lines changed

2 files changed

+101
-15
lines changed

resolver.go

+51-13
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package doh
22

33
import (
44
"context"
5+
"math"
56
"net"
67
"strings"
78
"sync"
@@ -17,8 +18,9 @@ type Resolver struct {
1718
url string
1819

1920
// RR cache
20-
ipCache map[string]ipAddrEntry
21-
txtCache map[string]txtEntry
21+
ipCache map[string]ipAddrEntry
22+
txtCache map[string]txtEntry
23+
maxCacheTTL time.Duration
2224
}
2325

2426
type ipAddrEntry struct {
@@ -31,16 +33,43 @@ type txtEntry struct {
3133
expire time.Time
3234
}
3335

34-
func NewResolver(url string) *Resolver {
36+
type Option func(*Resolver) error
37+
38+
// Specifies the maximum time entries are valid in the cache
39+
// A maxCacheTTL of zero is equivalent to `WithCacheDisabled`
40+
func WithMaxCacheTTL(maxCacheTTL time.Duration) Option {
41+
return func(tr *Resolver) error {
42+
tr.maxCacheTTL = maxCacheTTL
43+
return nil
44+
}
45+
}
46+
47+
func WithCacheDisabled() Option {
48+
return func(tr *Resolver) error {
49+
tr.maxCacheTTL = 0
50+
return nil
51+
}
52+
}
53+
54+
func NewResolver(url string, opts ...Option) (*Resolver, error) {
3555
if !strings.HasPrefix(url, "https:") {
3656
url = "https://" + url
3757
}
3858

39-
return &Resolver{
40-
url: url,
41-
ipCache: make(map[string]ipAddrEntry),
42-
txtCache: make(map[string]txtEntry),
59+
r := &Resolver{
60+
url: url,
61+
ipCache: make(map[string]ipAddrEntry),
62+
txtCache: make(map[string]txtEntry),
63+
maxCacheTTL: time.Duration(math.MaxUint32) * time.Second,
4364
}
65+
66+
for _, o := range opts {
67+
if err := o(r); err != nil {
68+
return nil, err
69+
}
70+
}
71+
72+
return r, nil
4473
}
4574

4675
var _ madns.BasicResolver = (*Resolver)(nil)
@@ -81,7 +110,8 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, domain string) (result []ne
81110
}
82111
}
83112

84-
r.cacheIPAddr(domain, result, ttl)
113+
cacheTTL := minTTL(time.Duration(ttl)*time.Second, r.maxCacheTTL)
114+
r.cacheIPAddr(domain, result, cacheTTL)
85115
return result, nil
86116
}
87117

@@ -96,7 +126,8 @@ func (r *Resolver) LookupTXT(ctx context.Context, domain string) ([]string, erro
96126
return nil, err
97127
}
98128

99-
r.cacheTXT(domain, result, ttl)
129+
cacheTTL := minTTL(time.Duration(ttl)*time.Second, r.maxCacheTTL)
130+
r.cacheTXT(domain, result, cacheTTL)
100131
return result, nil
101132
}
102133

@@ -118,7 +149,7 @@ func (r *Resolver) getCachedIPAddr(domain string) ([]net.IPAddr, bool) {
118149
return entry.ips, true
119150
}
120151

121-
func (r *Resolver) cacheIPAddr(domain string, ips []net.IPAddr, ttl uint32) {
152+
func (r *Resolver) cacheIPAddr(domain string, ips []net.IPAddr, ttl time.Duration) {
122153
if ttl == 0 {
123154
return
124155
}
@@ -127,7 +158,7 @@ func (r *Resolver) cacheIPAddr(domain string, ips []net.IPAddr, ttl uint32) {
127158
defer r.mx.Unlock()
128159

129160
fqdn := dns.Fqdn(domain)
130-
r.ipCache[fqdn] = ipAddrEntry{ips, time.Now().Add(time.Duration(ttl) * time.Second)}
161+
r.ipCache[fqdn] = ipAddrEntry{ips, time.Now().Add(ttl)}
131162
}
132163

133164
func (r *Resolver) getCachedTXT(domain string) ([]string, bool) {
@@ -148,7 +179,7 @@ func (r *Resolver) getCachedTXT(domain string) ([]string, bool) {
148179
return entry.txt, true
149180
}
150181

151-
func (r *Resolver) cacheTXT(domain string, txt []string, ttl uint32) {
182+
func (r *Resolver) cacheTXT(domain string, txt []string, ttl time.Duration) {
152183
if ttl == 0 {
153184
return
154185
}
@@ -157,5 +188,12 @@ func (r *Resolver) cacheTXT(domain string, txt []string, ttl uint32) {
157188
defer r.mx.Unlock()
158189

159190
fqdn := dns.Fqdn(domain)
160-
r.txtCache[fqdn] = txtEntry{txt, time.Now().Add(time.Duration(ttl) * time.Second)}
191+
r.txtCache[fqdn] = txtEntry{txt, time.Now().Add(ttl)}
192+
}
193+
194+
func minTTL(a, b time.Duration) time.Duration {
195+
if a < b {
196+
return a
197+
}
198+
return b
161199
}

resolver_test.go

+50-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"net/http"
88
"net/http/httptest"
99
"testing"
10+
"time"
1011

1112
"github.com/miekg/dns"
1213
)
@@ -76,7 +77,10 @@ func TestLookupIPAddr(t *testing.T) {
7677
})
7778
defer resolver.Close()
7879

79-
r := NewResolver("")
80+
r, err := NewResolver("https://cloudflare-dns.com/dns-query")
81+
if err != nil {
82+
t.Fatal("resolver cannot be initialised")
83+
}
8084
r.url = resolver.URL
8185

8286
ips, err := r.LookupIPAddr(context.Background(), domain)
@@ -120,7 +124,42 @@ func TestLookupTXT(t *testing.T) {
120124
})
121125
defer resolver.Close()
122126

123-
r := NewResolver("")
127+
r, err := NewResolver("")
128+
if err != nil {
129+
t.Fatal("resolver cannot be initialised")
130+
}
131+
r.url = resolver.URL
132+
133+
txt, err := r.LookupTXT(context.Background(), domain)
134+
if err != nil {
135+
t.Fatal(err)
136+
}
137+
if len(txt) == 0 {
138+
t.Fatal("got no TXT entries")
139+
}
140+
141+
// check the cache
142+
txt2, ok := r.getCachedTXT(domain)
143+
if !ok {
144+
t.Fatal("expected cache to be populated")
145+
}
146+
if !sameTXT(txt, txt2) {
147+
t.Fatal("expected cache to contain the same txt entries")
148+
}
149+
}
150+
151+
func TestLookupCache(t *testing.T) {
152+
domain := "example.com"
153+
resolver := mockDoHResolver(t, map[uint16]*dns.Msg{
154+
dns.TypeTXT: mockDNSAnswerTXT(dns.Fqdn(domain), []string{"dnslink=/ipns/example.com"}),
155+
})
156+
defer resolver.Close()
157+
158+
const cacheTTL = time.Second
159+
r, err := NewResolver("", WithMaxCacheTTL(cacheTTL))
160+
if err != nil {
161+
t.Fatal("resolver cannot be initialised")
162+
}
124163
r.url = resolver.URL
125164

126165
txt, err := r.LookupTXT(context.Background(), domain)
@@ -140,6 +179,15 @@ func TestLookupTXT(t *testing.T) {
140179
t.Fatal("expected cache to contain the same txt entries")
141180
}
142181

182+
// check cache is empty after its maxTTL
183+
time.Sleep(cacheTTL)
184+
txt2, ok = r.getCachedTXT(domain)
185+
if ok {
186+
t.Fatal("expected cache to be empty")
187+
}
188+
if txt2 != nil {
189+
t.Fatal("expected cache to not contain a txt entry")
190+
}
143191
}
144192

145193
func sameIPs(a, b []net.IPAddr) bool {

0 commit comments

Comments
 (0)