Skip to content

Commit c4f1e01

Browse files
turt2liveS7evinK
andauthored
Merge commit from fork
* Plumb allow/deny CIDRs to HTTP client * Make the DNS cache aware of the ControlContext function --------- Co-authored-by: Till Faelligen <[email protected]>
1 parent bf86bc9 commit c4f1e01

File tree

3 files changed

+98
-19
lines changed

3 files changed

+98
-19
lines changed

fclient/client.go

+91-15
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"strings"
2929
"sync"
3030
"sync/atomic"
31+
"syscall"
3132
"time"
3233

3334
"github.com/matrix-org/gomatrix"
@@ -54,13 +55,15 @@ type UserInfo struct {
5455
}
5556

5657
type clientOptions struct {
57-
transport http.RoundTripper
58-
dnsCache *DNSCache
59-
timeout time.Duration
60-
skipVerify bool
61-
keepAlives bool
62-
wellKnownSRV bool
63-
userAgent string
58+
transport http.RoundTripper
59+
dnsCache *DNSCache
60+
timeout time.Duration
61+
skipVerify bool
62+
keepAlives bool
63+
wellKnownSRV bool
64+
userAgent string
65+
allowNetworks []string
66+
denyNetworks []string
6467
}
6568

6669
// ClientOption are supplied to NewClient or NewFederationClient.
@@ -82,6 +85,8 @@ func NewClient(options ...ClientOption) *Client {
8285
clientOpts.dnsCache,
8386
clientOpts.keepAlives,
8487
clientOpts.wellKnownSRV,
88+
clientOpts.allowNetworks,
89+
clientOpts.denyNetworks,
8590
)
8691
}
8792
client := &Client{
@@ -152,6 +157,15 @@ func WithUserAgent(userAgent string) ClientOption {
152157
}
153158
}
154159

160+
// WithAllowDenyNetworks sets the allowed and denied networks for the http client. By default,
161+
// all networks are allowed. The deny list is checked before the allow list.
162+
func WithAllowDenyNetworks(allowCIDRs []string, denyCIDRs []string) ClientOption {
163+
return func(options *clientOptions) {
164+
options.allowNetworks = allowCIDRs
165+
options.denyNetworks = denyCIDRs
166+
}
167+
}
168+
155169
const destinationTripperLifetime = time.Minute * 5 // how long to keep an entry
156170
const destinationTripperReapInterval = time.Minute // how often to check for dead entries
157171

@@ -165,15 +179,17 @@ type destinationTripper struct {
165179
dnsCache *DNSCache
166180
keepAlives bool
167181
wellKnownSRV bool
182+
dialer *net.Dialer
168183
}
169184

170-
func newDestinationTripper(skipVerify bool, dnsCache *DNSCache, keepAlives, wellKnownSRV bool) *destinationTripper {
185+
func newDestinationTripper(skipVerify bool, dnsCache *DNSCache, keepAlives, wellKnownSRV bool, allowCIDRs []string, denyCIDRs []string) *destinationTripper {
171186
tripper := &destinationTripper{
172187
transports: make(map[string]*destinationTripperTransport),
173188
skipVerify: skipVerify,
174189
dnsCache: dnsCache,
175190
keepAlives: keepAlives,
176191
wellKnownSRV: wellKnownSRV,
192+
dialer: newDestinationTripperDialer(allowCIDRs, denyCIDRs),
177193
}
178194
time.AfterFunc(destinationTripperReapInterval, tripper.reaper)
179195
return tripper
@@ -195,11 +211,71 @@ func (f *destinationTripper) reaper() {
195211
time.AfterFunc(destinationTripperReapInterval, f.reaper)
196212
}
197213

198-
// destinationTripperDialer enforces dial timeouts on the federation requests. If
214+
// newDestinationTripperDialer creates a dialer which enforces dial timeouts on the federation requests. If
199215
// the TCP connection doesn't complete within 5 seconds, it's probably just not
200216
// going to.
201-
var destinationTripperDialer = &net.Dialer{
202-
Timeout: time.Second * 5,
217+
// The dialer can also be limited to CIDR ranges, if allow or deny networks is non-empty.
218+
func newDestinationTripperDialer(allowNetworks []string, denyNetworks []string) *net.Dialer {
219+
if len(allowNetworks) == 0 && len(denyNetworks) == 0 {
220+
return &net.Dialer{
221+
Timeout: time.Second * 5,
222+
}
223+
}
224+
225+
return &net.Dialer{
226+
Timeout: time.Second * 5,
227+
ControlContext: allowDenyNetworksControl(allowNetworks, denyNetworks),
228+
}
229+
}
230+
231+
// allowDenyNetworksControl is used to allow/deny access to certain networks
232+
func allowDenyNetworksControl(allowNetworks, denyNetworks []string) func(_ context.Context, network string, address string, conn syscall.RawConn) error {
233+
return func(_ context.Context, network string, address string, conn syscall.RawConn) error {
234+
if network != "tcp4" && network != "tcp6" {
235+
return fmt.Errorf("%s is not a safe network type", network)
236+
}
237+
238+
host, _, err := net.SplitHostPort(address)
239+
if err != nil {
240+
return fmt.Errorf("%s is not a valid host/port pair: %s", address, err)
241+
}
242+
243+
ipaddress := net.ParseIP(host)
244+
if ipaddress == nil {
245+
return fmt.Errorf("%s is not a valid IP address", host)
246+
}
247+
248+
if !isAllowed(ipaddress, allowNetworks, denyNetworks) {
249+
return fmt.Errorf("%s is denied", address)
250+
}
251+
252+
return nil // allow connection
253+
}
254+
}
255+
256+
func isAllowed(ip net.IP, allowCIDRs []string, denyCIDRs []string) bool {
257+
if inRange(ip, denyCIDRs) {
258+
return false
259+
}
260+
if inRange(ip, allowCIDRs) {
261+
return true
262+
}
263+
return false // "should never happen"
264+
}
265+
266+
func inRange(ip net.IP, CIDRs []string) bool {
267+
for i := 0; i < len(CIDRs); i++ {
268+
cidr := CIDRs[i]
269+
_, network, err := net.ParseCIDR(cidr)
270+
if err != nil {
271+
return false
272+
}
273+
if network.Contains(ip) {
274+
return true
275+
}
276+
}
277+
278+
return false
203279
}
204280

205281
type destinationTripperTransport struct {
@@ -213,7 +289,7 @@ type destinationTripperTransport struct {
213289
// We need to use one transport per TLS server name (instead of giving our round
214290
// tripper a single transport) because there is no way to specify the TLS
215291
// ServerName on a per-connection basis.
216-
func (f *destinationTripper) getTransport(tlsServerName string) http.RoundTripper {
292+
func (f *destinationTripper) getTransport(tlsServerName string, dialer *net.Dialer) http.RoundTripper {
217293
f.transportsMutex.Lock()
218294
defer f.transportsMutex.Unlock()
219295

@@ -230,8 +306,8 @@ func (f *destinationTripper) getTransport(tlsServerName string) http.RoundTrippe
230306
InsecureSkipVerify: f.skipVerify,
231307
ClientSessionCache: tls.NewLRUClientSessionCache(0), // 0 = use default
232308
},
233-
Dial: destinationTripperDialer.Dial, // nolint: staticcheck
234-
DialContext: destinationTripperDialer.DialContext,
309+
Dial: dialer.Dial, // nolint: staticcheck
310+
DialContext: dialer.DialContext,
235311
Proxy: http.ProxyFromEnvironment,
236312
ForceAttemptHTTP2: true, // if we can multiplex requests over HTTP/2, we should
237313
},
@@ -296,7 +372,7 @@ retryResolution:
296372
u := makeHTTPSURL(r.URL, result.Destination)
297373
r.URL = &u
298374
r.Host = string(result.Host)
299-
resp, err = f.getTransport(result.TLSServerName).RoundTrip(r)
375+
resp, err = f.getTransport(result.TLSServerName, f.dialer).RoundTrip(r)
300376
if err == nil {
301377
return resp, nil
302378
}

fclient/dnscache.go

+6-3
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,18 @@ type DNSCache struct {
1414
size int
1515
duration time.Duration
1616
entries map[string]*dnsCacheEntry
17+
dialer net.Dialer
1718
}
1819

19-
func NewDNSCache(size int, duration time.Duration) *DNSCache {
20+
func NewDNSCache(size int, duration time.Duration, allowNetworks, denyNetworks []string) *DNSCache {
2021
return &DNSCache{
2122
resolver: net.DefaultResolver,
2223
size: size,
2324
duration: duration,
2425
entries: make(map[string]*dnsCacheEntry),
26+
dialer: net.Dialer{
27+
ControlContext: allowDenyNetworksControl(allowNetworks, denyNetworks),
28+
},
2529
}
2630
}
2731

@@ -100,7 +104,6 @@ func (c *DNSCache) DialContext(ctx context.Context, network, address string) (ne
100104
// retried set to true. This stops us from recursing more than
101105
// once.
102106
retried := false
103-
dialer := net.Dialer{}
104107

105108
retryLookup:
106109
// Consult the cache for the hostname. This will cause the OS to
@@ -113,7 +116,7 @@ retryLookup:
113116
// Try each address in the cached entry. If we successfully connect
114117
// to one of those addresses then return the conn and stop there.
115118
for _, addr := range entry.addrs {
116-
conn, err := dialer.DialContext(ctx, "tcp", addr.String()+":"+port)
119+
conn, err := c.dialer.DialContext(ctx, "tcp", addr.String()+":"+port)
117120
if err != nil {
118121
continue
119122
}

fclient/dnscache_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func (r *dummyNetResolver) LookupIPAddr(_ context.Context, hostname string) ([]n
2525
}
2626

2727
func mustCreateCache(size int, lifetime time.Duration) *DNSCache {
28-
cache := NewDNSCache(size, lifetime)
28+
cache := NewDNSCache(size, lifetime, []string{}, []string{})
2929
cache.resolver = &dummyNetResolver{}
3030
return cache
3131
}

0 commit comments

Comments
 (0)