@@ -28,6 +28,7 @@ import (
28
28
"strings"
29
29
"sync"
30
30
"sync/atomic"
31
+ "syscall"
31
32
"time"
32
33
33
34
"github.com/matrix-org/gomatrix"
@@ -54,13 +55,15 @@ type UserInfo struct {
54
55
}
55
56
56
57
type clientOptions struct {
57
- transport http.RoundTripper
58
- dnsCache * DNSCache
59
- timeout time.Duration
60
- skipVerify bool
61
- keepAlives bool
62
- wellKnownSRV bool
63
- userAgent string
58
+ transport http.RoundTripper
59
+ dnsCache * DNSCache
60
+ timeout time.Duration
61
+ skipVerify bool
62
+ keepAlives bool
63
+ wellKnownSRV bool
64
+ userAgent string
65
+ allowNetworks []string
66
+ denyNetworks []string
64
67
}
65
68
66
69
// ClientOption are supplied to NewClient or NewFederationClient.
@@ -82,6 +85,8 @@ func NewClient(options ...ClientOption) *Client {
82
85
clientOpts .dnsCache ,
83
86
clientOpts .keepAlives ,
84
87
clientOpts .wellKnownSRV ,
88
+ clientOpts .allowNetworks ,
89
+ clientOpts .denyNetworks ,
85
90
)
86
91
}
87
92
client := & Client {
@@ -152,6 +157,15 @@ func WithUserAgent(userAgent string) ClientOption {
152
157
}
153
158
}
154
159
160
+ // WithAllowDenyNetworks sets the allowed and denied networks for the http client. By default,
161
+ // all networks are allowed. The deny list is checked before the allow list.
162
+ func WithAllowDenyNetworks (allowCIDRs []string , denyCIDRs []string ) ClientOption {
163
+ return func (options * clientOptions ) {
164
+ options .allowNetworks = allowCIDRs
165
+ options .denyNetworks = denyCIDRs
166
+ }
167
+ }
168
+
155
169
const destinationTripperLifetime = time .Minute * 5 // how long to keep an entry
156
170
const destinationTripperReapInterval = time .Minute // how often to check for dead entries
157
171
@@ -165,15 +179,17 @@ type destinationTripper struct {
165
179
dnsCache * DNSCache
166
180
keepAlives bool
167
181
wellKnownSRV bool
182
+ dialer * net.Dialer
168
183
}
169
184
170
- func newDestinationTripper (skipVerify bool , dnsCache * DNSCache , keepAlives , wellKnownSRV bool ) * destinationTripper {
185
+ func newDestinationTripper (skipVerify bool , dnsCache * DNSCache , keepAlives , wellKnownSRV bool , allowCIDRs [] string , denyCIDRs [] string ) * destinationTripper {
171
186
tripper := & destinationTripper {
172
187
transports : make (map [string ]* destinationTripperTransport ),
173
188
skipVerify : skipVerify ,
174
189
dnsCache : dnsCache ,
175
190
keepAlives : keepAlives ,
176
191
wellKnownSRV : wellKnownSRV ,
192
+ dialer : newDestinationTripperDialer (allowCIDRs , denyCIDRs ),
177
193
}
178
194
time .AfterFunc (destinationTripperReapInterval , tripper .reaper )
179
195
return tripper
@@ -195,11 +211,71 @@ func (f *destinationTripper) reaper() {
195
211
time .AfterFunc (destinationTripperReapInterval , f .reaper )
196
212
}
197
213
198
- // destinationTripperDialer enforces dial timeouts on the federation requests. If
214
+ // newDestinationTripperDialer creates a dialer which enforces dial timeouts on the federation requests. If
199
215
// the TCP connection doesn't complete within 5 seconds, it's probably just not
200
216
// going to.
201
- var destinationTripperDialer = & net.Dialer {
202
- Timeout : time .Second * 5 ,
217
+ // The dialer can also be limited to CIDR ranges, if allow or deny networks is non-empty.
218
+ func newDestinationTripperDialer (allowNetworks []string , denyNetworks []string ) * net.Dialer {
219
+ if len (allowNetworks ) == 0 && len (denyNetworks ) == 0 {
220
+ return & net.Dialer {
221
+ Timeout : time .Second * 5 ,
222
+ }
223
+ }
224
+
225
+ return & net.Dialer {
226
+ Timeout : time .Second * 5 ,
227
+ ControlContext : allowDenyNetworksControl (allowNetworks , denyNetworks ),
228
+ }
229
+ }
230
+
231
+ // allowDenyNetworksControl is used to allow/deny access to certain networks
232
+ func allowDenyNetworksControl (allowNetworks , denyNetworks []string ) func (_ context.Context , network string , address string , conn syscall.RawConn ) error {
233
+ return func (_ context.Context , network string , address string , conn syscall.RawConn ) error {
234
+ if network != "tcp4" && network != "tcp6" {
235
+ return fmt .Errorf ("%s is not a safe network type" , network )
236
+ }
237
+
238
+ host , _ , err := net .SplitHostPort (address )
239
+ if err != nil {
240
+ return fmt .Errorf ("%s is not a valid host/port pair: %s" , address , err )
241
+ }
242
+
243
+ ipaddress := net .ParseIP (host )
244
+ if ipaddress == nil {
245
+ return fmt .Errorf ("%s is not a valid IP address" , host )
246
+ }
247
+
248
+ if ! isAllowed (ipaddress , allowNetworks , denyNetworks ) {
249
+ return fmt .Errorf ("%s is denied" , address )
250
+ }
251
+
252
+ return nil // allow connection
253
+ }
254
+ }
255
+
256
+ func isAllowed (ip net.IP , allowCIDRs []string , denyCIDRs []string ) bool {
257
+ if inRange (ip , denyCIDRs ) {
258
+ return false
259
+ }
260
+ if inRange (ip , allowCIDRs ) {
261
+ return true
262
+ }
263
+ return false // "should never happen"
264
+ }
265
+
266
+ func inRange (ip net.IP , CIDRs []string ) bool {
267
+ for i := 0 ; i < len (CIDRs ); i ++ {
268
+ cidr := CIDRs [i ]
269
+ _ , network , err := net .ParseCIDR (cidr )
270
+ if err != nil {
271
+ return false
272
+ }
273
+ if network .Contains (ip ) {
274
+ return true
275
+ }
276
+ }
277
+
278
+ return false
203
279
}
204
280
205
281
type destinationTripperTransport struct {
@@ -213,7 +289,7 @@ type destinationTripperTransport struct {
213
289
// We need to use one transport per TLS server name (instead of giving our round
214
290
// tripper a single transport) because there is no way to specify the TLS
215
291
// ServerName on a per-connection basis.
216
- func (f * destinationTripper ) getTransport (tlsServerName string ) http.RoundTripper {
292
+ func (f * destinationTripper ) getTransport (tlsServerName string , dialer * net. Dialer ) http.RoundTripper {
217
293
f .transportsMutex .Lock ()
218
294
defer f .transportsMutex .Unlock ()
219
295
@@ -230,8 +306,8 @@ func (f *destinationTripper) getTransport(tlsServerName string) http.RoundTrippe
230
306
InsecureSkipVerify : f .skipVerify ,
231
307
ClientSessionCache : tls .NewLRUClientSessionCache (0 ), // 0 = use default
232
308
},
233
- Dial : destinationTripperDialer .Dial , // nolint: staticcheck
234
- DialContext : destinationTripperDialer .DialContext ,
309
+ Dial : dialer .Dial , // nolint: staticcheck
310
+ DialContext : dialer .DialContext ,
235
311
Proxy : http .ProxyFromEnvironment ,
236
312
ForceAttemptHTTP2 : true , // if we can multiplex requests over HTTP/2, we should
237
313
},
@@ -296,7 +372,7 @@ retryResolution:
296
372
u := makeHTTPSURL (r .URL , result .Destination )
297
373
r .URL = & u
298
374
r .Host = string (result .Host )
299
- resp , err = f .getTransport (result .TLSServerName ).RoundTrip (r )
375
+ resp , err = f .getTransport (result .TLSServerName , f . dialer ).RoundTrip (r )
300
376
if err == nil {
301
377
return resp , nil
302
378
}
0 commit comments