46
46
import signedjson .types
47
47
import srvlookup
48
48
import yaml
49
+ from requests import PreparedRequest , Response
49
50
from requests .adapters import HTTPAdapter
50
51
from urllib3 import HTTPConnectionPool
51
52
52
53
# uncomment the following to enable debug logging of http requests
53
- # from httplib import HTTPConnection
54
+ # from http.client import HTTPConnection
54
55
# HTTPConnection.debuglevel = 1
55
56
56
57
@@ -103,6 +104,7 @@ def request(
103
104
destination : str ,
104
105
path : str ,
105
106
content : Optional [str ],
107
+ verify_tls : bool ,
106
108
) -> requests .Response :
107
109
if method is None :
108
110
if content is None :
@@ -141,7 +143,6 @@ def request(
141
143
s .mount ("matrix://" , MatrixConnectionAdapter ())
142
144
143
145
headers : Dict [str , str ] = {
144
- "Host" : destination ,
145
146
"Authorization" : authorization_headers [0 ],
146
147
}
147
148
@@ -152,7 +153,7 @@ def request(
152
153
method = method ,
153
154
url = dest ,
154
155
headers = headers ,
155
- verify = False ,
156
+ verify = verify_tls ,
156
157
data = content ,
157
158
stream = True ,
158
159
)
@@ -202,6 +203,12 @@ def main() -> None:
202
203
203
204
parser .add_argument ("--body" , help = "Data to send as the body of the HTTP request" )
204
205
206
+ parser .add_argument (
207
+ "--insecure" ,
208
+ action = "store_true" ,
209
+ help = "Disable TLS certificate verification" ,
210
+ )
211
+
205
212
parser .add_argument (
206
213
"path" , help = "request path, including the '/_matrix/federation/...' prefix."
207
214
)
@@ -227,6 +234,7 @@ def main() -> None:
227
234
args .destination ,
228
235
args .path ,
229
236
content = args .body ,
237
+ verify_tls = not args .insecure ,
230
238
)
231
239
232
240
sys .stderr .write ("Status Code: %d\n " % (result .status_code ,))
@@ -254,36 +262,93 @@ def read_args_from_config(args: argparse.Namespace) -> None:
254
262
255
263
256
264
class MatrixConnectionAdapter (HTTPAdapter ):
265
+ def send (
266
+ self ,
267
+ request : PreparedRequest ,
268
+ * args : Any ,
269
+ ** kwargs : Any ,
270
+ ) -> Response :
271
+ # overrides the send() method in the base class.
272
+
273
+ # We need to look for .well-known redirects before passing the request up to
274
+ # HTTPAdapter.send().
275
+ assert isinstance (request .url , str )
276
+ parsed = urlparse .urlsplit (request .url )
277
+ server_name = parsed .netloc
278
+ well_known = self ._get_well_known (parsed .netloc )
279
+
280
+ if well_known :
281
+ server_name = well_known
282
+
283
+ # replace the scheme in the uri with https, so that cert verification is done
284
+ # also replace the hostname if we got a .well-known result
285
+ request .url = urlparse .urlunsplit (
286
+ ("https" , server_name , parsed .path , parsed .query , parsed .fragment )
287
+ )
288
+
289
+ # at this point we also add the host header (otherwise urllib will add one
290
+ # based on the `host` from the connection returned by `get_connection`,
291
+ # which will be wrong if there is an SRV record).
292
+ request .headers ["Host" ] = server_name
293
+
294
+ return super ().send (request , * args , ** kwargs )
295
+
296
+ def get_connection (
297
+ self , url : str , proxies : Optional [Dict [str , str ]] = None
298
+ ) -> HTTPConnectionPool :
299
+ # overrides the get_connection() method in the base class
300
+ parsed = urlparse .urlsplit (url )
301
+ (host , port , ssl_server_name ) = self ._lookup (parsed .netloc )
302
+ print (
303
+ f"Connecting to { host } :{ port } with SNI { ssl_server_name } " , file = sys .stderr
304
+ )
305
+ return self .poolmanager .connection_from_host (
306
+ host ,
307
+ port = port ,
308
+ scheme = "https" ,
309
+ pool_kwargs = {"server_hostname" : ssl_server_name },
310
+ )
311
+
257
312
@staticmethod
258
- def lookup (s : str , skip_well_known : bool = False ) -> Tuple [str , int ]:
259
- if s [- 1 ] == "]" :
313
+ def _lookup (server_name : str ) -> Tuple [str , int , str ]:
314
+ """
315
+ Do an SRV lookup on a server name and return the host:port to connect to
316
+ Given the server_name (after any .well-known lookup), return the host, port and
317
+ the ssl server name
318
+ """
319
+ if server_name [- 1 ] == "]" :
260
320
# ipv6 literal (with no port)
261
- return s , 8448
321
+ return server_name , 8448 , server_name
262
322
263
- if ":" in s :
264
- out = s .rsplit (":" , 1 )
323
+ if ":" in server_name :
324
+ # explicit port
325
+ out = server_name .rsplit (":" , 1 )
265
326
try :
266
327
port = int (out [1 ])
267
328
except ValueError :
268
- raise ValueError ("Invalid host:port '%s'" % s )
269
- return out [0 ], port
270
-
271
- # try a .well-known lookup
272
- if not skip_well_known :
273
- well_known = MatrixConnectionAdapter .get_well_known (s )
274
- if well_known :
275
- return MatrixConnectionAdapter .lookup (well_known , skip_well_known = True )
329
+ raise ValueError ("Invalid host:port '%s'" % (server_name ,))
330
+ return out [0 ], port , out [0 ]
276
331
277
332
try :
278
- srv = srvlookup .lookup ("matrix" , "tcp" , s )[0 ]
279
- return srv .host , srv .port
333
+ srv = srvlookup .lookup ("matrix" , "tcp" , server_name )[0 ]
334
+ print (
335
+ f"SRV lookup on _matrix._tcp.{ server_name } gave { srv } " ,
336
+ file = sys .stderr ,
337
+ )
338
+ return srv .host , srv .port , server_name
280
339
except Exception :
281
- return s , 8448
340
+ return server_name , 8448 , server_name
282
341
283
342
@staticmethod
284
- def get_well_known (server_name : str ) -> Optional [str ]:
285
- uri = "https://%s/.well-known/matrix/server" % (server_name ,)
286
- print ("fetching %s" % (uri ,), file = sys .stderr )
343
+ def _get_well_known (server_name : str ) -> Optional [str ]:
344
+ if ":" in server_name :
345
+ # explicit port, or ipv6 literal. Either way, no .well-known
346
+ return None
347
+
348
+ # TODO: check for ipv4 literals
349
+
350
+ uri = f"https://{ server_name } /.well-known/matrix/server"
351
+ print (f"fetching { uri } " , file = sys .stderr )
287
352
288
353
try :
289
354
resp = requests .get (uri )
@@ -304,19 +369,6 @@ def get_well_known(server_name: str) -> Optional[str]:
304
369
print ("Invalid response from %s: %s" % (uri , e ), file = sys .stderr )
305
370
return None
306
371
307
- def get_connection (
308
- self , url : str , proxies : Optional [Dict [str , str ]] = None
309
- ) -> HTTPConnectionPool :
310
- parsed = urlparse .urlparse (url )
311
-
312
- (host , port ) = self .lookup (parsed .netloc )
313
- netloc = "%s:%d" % (host , port )
314
- print ("Connecting to %s" % (netloc ,), file = sys .stderr )
315
- url = urlparse .urlunparse (
316
- ("https" , netloc , parsed .path , parsed .params , parsed .query , parsed .fragment )
317
- )
318
- return super ().get_connection (url , proxies )
319
-
320
372
321
373
if __name__ == "__main__" :
322
374
main ()
0 commit comments