Skip to content

Commit 694914f

Browse files
authored
Add support for authenticated media (#69)
This PR adds support for authenticated media by passing the Authorization header information through from the client to the homeserver. Once MAS supports scoped access tokens this code should be changed over to use that. Huge shoutout to @S7evinK for doing the bulk of the implementation on this.
1 parent 2056293 commit 694914f

File tree

7 files changed

+165
-34
lines changed

7 files changed

+165
-34
lines changed

src/matrix_content_scanner/scanner/file_downloader.py

+55-22
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,24 @@ class _PathNotFoundException(Exception):
3333
class FileDownloader:
3434
MEDIA_DOWNLOAD_PREFIX = "_matrix/media/%s/download"
3535
MEDIA_THUMBNAIL_PREFIX = "_matrix/media/%s/thumbnail"
36+
MEDIA_DOWNLOAD_AUTHENTICATED_PREFIX = "_matrix/client/%s/media/download"
37+
MEDIA_THUMBNAIL_AUTHENTICATED_PREFIX = "_matrix/client/%s/media/thumbnail"
3638

3739
def __init__(self, mcs: "MatrixContentScanner"):
3840
self._base_url = mcs.config.download.base_homeserver_url
3941
self._well_known_cache: Dict[str, Optional[str]] = {}
4042
self._proxy_url = mcs.config.download.proxy
41-
self._headers = mcs.config.download.additional_headers
43+
self._headers = (
44+
mcs.config.download.additional_headers
45+
if mcs.config.download.additional_headers is not None
46+
else {}
47+
)
4248

4349
async def download_file(
4450
self,
4551
media_path: str,
4652
thumbnail_params: Optional[MultiMapping[str]] = None,
53+
auth_header: Optional[str] = None,
4754
) -> MediaDescription:
4855
"""Retrieve the file with the given `server_name/media_id` path, and stores it on
4956
disk.
@@ -52,6 +59,8 @@ async def download_file(
5259
media_path: The path identifying the media to retrieve.
5360
thumbnail_params: If present, then we want to request and scan a thumbnail
5461
generated with the provided parameters instead of the full media.
62+
auth_header: If present, we forward the given Authorization header, this is
63+
required for authenticated media endpoints.
5564
5665
Returns:
5766
A description of the file (including its full content).
@@ -60,27 +69,45 @@ async def download_file(
6069
ContentScannerRestError: The file was not found or could not be downloaded due
6170
to an error on the remote homeserver's side.
6271
"""
72+
73+
auth_media = True if auth_header is not None else False
74+
75+
prefix = (
76+
self.MEDIA_DOWNLOAD_AUTHENTICATED_PREFIX
77+
if auth_media
78+
else self.MEDIA_DOWNLOAD_PREFIX
79+
)
80+
if thumbnail_params is not None:
81+
prefix = (
82+
self.MEDIA_THUMBNAIL_AUTHENTICATED_PREFIX
83+
if auth_media
84+
else self.MEDIA_THUMBNAIL_PREFIX
85+
)
86+
6387
url = await self._build_https_url(
64-
media_path, for_thumbnail=thumbnail_params is not None
88+
media_path, prefix, "v1" if auth_media else "v3"
6589
)
6690

6791
# Attempt to retrieve the file at the generated URL.
6892
try:
69-
file = await self._get_file_content(url, thumbnail_params)
93+
file = await self._get_file_content(url, thumbnail_params, auth_header)
7094
except _PathNotFoundException:
95+
if auth_media:
96+
raise ContentScannerRestError(
97+
http_status=HTTPStatus.NOT_FOUND,
98+
reason=ErrCode.NOT_FOUND,
99+
info="File not found",
100+
)
101+
71102
# If the file could not be found, it might be because the homeserver hasn't
72103
# been upgraded to a version that supports Matrix v1.1 endpoints yet, so try
73104
# again with an r0 endpoint.
74105
logger.info("File not found, trying legacy r0 path")
75106

76-
url = await self._build_https_url(
77-
media_path,
78-
endpoint_version="r0",
79-
for_thumbnail=thumbnail_params is not None,
80-
)
107+
url = await self._build_https_url(media_path, prefix, "r0")
81108

82109
try:
83-
file = await self._get_file_content(url, thumbnail_params)
110+
file = await self._get_file_content(url, thumbnail_params, auth_header)
84111
except _PathNotFoundException:
85112
# If that still failed, raise an error.
86113
raise ContentScannerRestError(
@@ -94,9 +121,8 @@ async def download_file(
94121
async def _build_https_url(
95122
self,
96123
media_path: str,
97-
endpoint_version: str = "v3",
98-
*,
99-
for_thumbnail: bool,
124+
prefix: str,
125+
endpoint_version: str,
100126
) -> str:
101127
"""Turn a `server_name/media_id` path into an https:// one we can use to fetch
102128
the media.
@@ -107,10 +133,8 @@ async def _build_https_url(
107133
Args:
108134
media_path: The media path to translate.
109135
endpoint_version: The version of the download endpoint to use. As of Matrix
110-
v1.1, this is either "v3" or "r0".
111-
for_thumbnail: True if a server-side thumbnail is desired instead of the full
112-
media. In that case, the URL for the `/thumbnail` endpoint is returned
113-
instead of the `/download` endpoint.
136+
v1.11, this is "v1" for authenticated media. For unauthenticated media
137+
this is either "v3" or "r0".
114138
115139
Returns:
116140
An https URL to use. If `base_homeserver_url` is set in the config, this
@@ -140,10 +164,6 @@ async def _build_https_url(
140164
# didn't find a .well-known file.
141165
base_url = "https://" + server_name
142166

143-
prefix = (
144-
self.MEDIA_THUMBNAIL_PREFIX if for_thumbnail else self.MEDIA_DOWNLOAD_PREFIX
145-
)
146-
147167
# Build the full URL.
148168
path_prefix = prefix % endpoint_version
149169
url = "%s/%s/%s/%s" % (
@@ -159,12 +179,15 @@ async def _get_file_content(
159179
self,
160180
url: str,
161181
thumbnail_params: Optional[MultiMapping[str]],
182+
auth_header: Optional[str] = None,
162183
) -> MediaDescription:
163184
"""Retrieve the content of the file at a given URL.
164185
165186
Args:
166187
url: The URL to query.
167188
thumbnail_params: Query parameters used if the request is for a thumbnail.
189+
auth_header: If present, we forward the given Authorization header, this is
190+
required for authenticated media endpoints.
168191
169192
Returns:
170193
A description of the file (including its full content).
@@ -178,7 +201,9 @@ async def _get_file_content(
178201
ContentScannerRestError: the server returned a non-200 status which cannot
179202
meant that the path wasn't understood.
180203
"""
181-
code, body, headers = await self._get(url, query=thumbnail_params)
204+
code, body, headers = await self._get(
205+
url, query=thumbnail_params, auth_header=auth_header
206+
)
182207

183208
logger.info("Remote server responded with %d", code)
184209

@@ -307,12 +332,15 @@ async def _get(
307332
self,
308333
url: str,
309334
query: Optional[MultiMapping[str]] = None,
335+
auth_header: Optional[str] = None,
310336
) -> Tuple[int, bytes, CIMultiDictProxy[str]]:
311337
"""Sends a GET request to the provided URL.
312338
313339
Args:
314340
url: The URL to send requests to.
315341
query: Optional parameters to use in the request's query string.
342+
auth_header: If present, we forward the given Authorization header, this is
343+
required for authenticated media endpoints.
316344
317345
Returns:
318346
The HTTP status code, body and headers the remote server responded with.
@@ -324,10 +352,15 @@ async def _get(
324352
try:
325353
logger.info("Sending GET request to %s", url)
326354
async with aiohttp.ClientSession() as session:
355+
if auth_header is not None:
356+
request_headers = {"Authorization": auth_header, **self._headers}
357+
else:
358+
request_headers = self._headers
359+
327360
async with session.get(
328361
url,
329362
proxy=self._proxy_url,
330-
headers=self._headers,
363+
headers=request_headers,
331364
params=query,
332365
) as resp:
333366
return resp.status, await resp.read(), resp.headers

src/matrix_content_scanner/scanner/scanner.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ async def scan_file(
100100
media_path: str,
101101
metadata: Optional[JsonDict] = None,
102102
thumbnail_params: Optional["MultiMapping[str]"] = None,
103+
auth_header: Optional[str] = None,
103104
) -> MediaDescription:
104105
"""Download and scan the given media.
105106
@@ -119,6 +120,8 @@ async def scan_file(
119120
the file isn't encrypted.
120121
thumbnail_params: If present, then we want to request and scan a thumbnail
121122
generated with the provided parameters instead of the full media.
123+
auth_header: If present, we forward the given Authorization header, this is
124+
required for authenticated media endpoints.
122125
123126
Returns:
124127
A description of the media.
@@ -141,7 +144,7 @@ async def scan_file(
141144
# Try to download and scan the file.
142145
try:
143146
res = await self._scan_file(
144-
cache_key, media_path, metadata, thumbnail_params
147+
cache_key, media_path, metadata, thumbnail_params, auth_header
145148
)
146149
# Set the future's result, and mark it as done.
147150
f.set_result(res)
@@ -168,6 +171,7 @@ async def _scan_file(
168171
media_path: str,
169172
metadata: Optional[JsonDict] = None,
170173
thumbnail_params: Optional[MultiMapping[str]] = None,
174+
auth_header: Optional[str] = None,
171175
) -> MediaDescription:
172176
"""Download and scan the given media.
173177
@@ -185,6 +189,8 @@ async def _scan_file(
185189
the file isn't encrypted.
186190
thumbnail_params: If present, then we want to request and scan a thumbnail
187191
generated with the provided parameters instead of the full media.
192+
auth_header: If present, we forward the given Authorization header, this is
193+
required for authenticated media endpoints.
188194
189195
Returns:
190196
A description of the media.
@@ -218,6 +224,7 @@ async def _scan_file(
218224
media = await self._file_downloader.download_file(
219225
media_path=media_path,
220226
thumbnail_params=thumbnail_params,
227+
auth_header=auth_header,
221228
)
222229

223230
# Compare the media's hash to ensure the server hasn't changed the file since
@@ -251,6 +258,7 @@ async def _scan_file(
251258
media = await self._file_downloader.download_file(
252259
media_path=media_path,
253260
thumbnail_params=thumbnail_params,
261+
auth_header=auth_header,
254262
)
255263

256264
# Download and scan the file.

src/matrix_content_scanner/servlets/download.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@ async def _scan(
2626
self,
2727
media_path: str,
2828
metadata: Optional[JsonDict] = None,
29+
auth_header: Optional[str] = None,
2930
) -> Tuple[int, _BytesResponse]:
30-
media = await self._scanner.scan_file(media_path, metadata)
31+
media = await self._scanner.scan_file(
32+
media_path, metadata, auth_header=auth_header
33+
)
3134

3235
return 200, _BytesResponse(
3336
headers=media.response_headers,
@@ -38,7 +41,9 @@ async def _scan(
3841
async def handle_plain(self, request: web.Request) -> Tuple[int, _BytesResponse]:
3942
"""Handles GET requests to ../download/serverName/mediaId"""
4043
media_path = request.match_info["media_path"]
41-
return await self._scan(media_path)
44+
return await self._scan(
45+
media_path, auth_header=request.headers.get("Authorization")
46+
)
4247

4348
@web_handler
4449
async def handle_encrypted(
@@ -49,4 +54,6 @@ async def handle_encrypted(
4954
request, self._crypto_handler
5055
)
5156

52-
return await self._scan(media_path, metadata)
57+
return await self._scan(
58+
media_path, metadata, auth_header=request.headers.get("Authorization")
59+
)

src/matrix_content_scanner/servlets/scan.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@ async def _scan_and_format(
2323
self,
2424
media_path: str,
2525
metadata: Optional[JsonDict] = None,
26+
auth_header: Optional[str] = None,
2627
) -> Tuple[int, JsonDict]:
2728
try:
28-
await self._scanner.scan_file(media_path, metadata)
29+
await self._scanner.scan_file(media_path, metadata, auth_header=auth_header)
2930
except FileDirtyError as e:
3031
res = {"clean": False, "info": e.info}
3132
else:
@@ -37,12 +38,16 @@ async def _scan_and_format(
3738
async def handle_plain(self, request: web.Request) -> Tuple[int, JsonDict]:
3839
"""Handles GET requests to ../scan/serverName/mediaId"""
3940
media_path = request.match_info["media_path"]
40-
return await self._scan_and_format(media_path)
41+
return await self._scan_and_format(
42+
media_path, auth_header=request.headers.get("Authorization")
43+
)
4144

4245
@web_handler
4346
async def handle_encrypted(self, request: web.Request) -> Tuple[int, JsonDict]:
4447
"""Handles GET requests to ../scan_encrypted"""
4548
media_path, metadata = await get_media_metadata_from_request(
4649
request, self._crypto_handler
4750
)
48-
return await self._scan_and_format(media_path, metadata)
51+
return await self._scan_and_format(
52+
media_path, metadata, auth_header=request.headers.get("Authorization")
53+
)

src/matrix_content_scanner/servlets/thumbnail.py

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ async def handle_thumbnail(
2626
media = await self._scanner.scan_file(
2727
media_path=media_path,
2828
thumbnail_params=request.query,
29+
auth_header=request.headers.get("Authorization"),
2930
)
3031

3132
return 200, _BytesResponse(

0 commit comments

Comments
 (0)