16
16
import logging
17
17
import re
18
18
import time
19
- from typing import TYPE_CHECKING , Any , Awaitable , Callable , Optional , Tuple , cast
19
+ from http import HTTPStatus
20
+ from typing import TYPE_CHECKING , Any , Awaitable , Callable , Dict , Optional , Tuple , cast
20
21
21
22
from synapse .api .errors import Codes , FederationDeniedError , SynapseError
22
23
from synapse .api .urls import FEDERATION_V1_PREFIX
@@ -86,15 +87,24 @@ async def authenticate_request(
86
87
87
88
if not auth_headers :
88
89
raise NoAuthenticationError (
89
- 401 , "Missing Authorization headers" , Codes .UNAUTHORIZED
90
+ HTTPStatus .UNAUTHORIZED ,
91
+ "Missing Authorization headers" ,
92
+ Codes .UNAUTHORIZED ,
90
93
)
91
94
92
95
for auth in auth_headers :
93
96
if auth .startswith (b"X-Matrix" ):
94
- (origin , key , sig ) = _parse_auth_header (auth )
97
+ (origin , key , sig , destination ) = _parse_auth_header (auth )
95
98
json_request ["origin" ] = origin
96
99
json_request ["signatures" ].setdefault (origin , {})[key ] = sig
97
100
101
+ # if the origin_server sent a destination along it needs to match our own server_name
102
+ if destination is not None and destination != self .server_name :
103
+ raise AuthenticationError (
104
+ HTTPStatus .UNAUTHORIZED ,
105
+ "Destination mismatch in auth header" ,
106
+ Codes .UNAUTHORIZED ,
107
+ )
98
108
if (
99
109
self .federation_domain_whitelist is not None
100
110
and origin not in self .federation_domain_whitelist
@@ -103,7 +113,9 @@ async def authenticate_request(
103
113
104
114
if origin is None or not json_request ["signatures" ]:
105
115
raise NoAuthenticationError (
106
- 401 , "Missing Authorization headers" , Codes .UNAUTHORIZED
116
+ HTTPStatus .UNAUTHORIZED ,
117
+ "Missing Authorization headers" ,
118
+ Codes .UNAUTHORIZED ,
107
119
)
108
120
109
121
await self .keyring .verify_json_for_server (
@@ -142,13 +154,14 @@ async def reset_retry_timings(self, origin: str) -> None:
142
154
logger .exception ("Error resetting retry timings on %s" , origin )
143
155
144
156
145
- def _parse_auth_header (header_bytes : bytes ) -> Tuple [str , str , str ]:
157
+ def _parse_auth_header (header_bytes : bytes ) -> Tuple [str , str , str , Optional [ str ] ]:
146
158
"""Parse an X-Matrix auth header
147
159
148
160
Args:
149
161
header_bytes: header value
150
162
151
163
Returns:
164
+ origin, key id, signature, destination.
152
165
origin, key id, signature.
153
166
154
167
Raises:
@@ -157,7 +170,9 @@ def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str]:
157
170
try :
158
171
header_str = header_bytes .decode ("utf-8" )
159
172
params = header_str .split (" " )[1 ].split ("," )
160
- param_dict = {k : v for k , v in (kv .split ("=" , maxsplit = 1 ) for kv in params )}
173
+ param_dict : Dict [str , str ] = {
174
+ k : v for k , v in [param .split ("=" , maxsplit = 1 ) for param in params ]
175
+ }
161
176
162
177
def strip_quotes (value : str ) -> str :
163
178
if value .startswith ('"' ):
@@ -172,15 +187,23 @@ def strip_quotes(value: str) -> str:
172
187
173
188
key = strip_quotes (param_dict ["key" ])
174
189
sig = strip_quotes (param_dict ["sig" ])
175
- return origin , key , sig
190
+
191
+ # get the destination server_name from the auth header if it exists
192
+ destination = param_dict .get ("destination" )
193
+ if destination is not None :
194
+ destination = strip_quotes (destination )
195
+ else :
196
+ destination = None
197
+
198
+ return origin , key , sig , destination
176
199
except Exception as e :
177
200
logger .warning (
178
201
"Error parsing auth header '%s': %s" ,
179
202
header_bytes .decode ("ascii" , "replace" ),
180
203
e ,
181
204
)
182
205
raise AuthenticationError (
183
- 400 , "Malformed Authorization header" , Codes .UNAUTHORIZED
206
+ HTTPStatus . BAD_REQUEST , "Malformed Authorization header" , Codes .UNAUTHORIZED
184
207
)
185
208
186
209
0 commit comments