@@ -33,17 +33,24 @@ class _PathNotFoundException(Exception):
33
33
class FileDownloader :
34
34
MEDIA_DOWNLOAD_PREFIX = "_matrix/media/%s/download"
35
35
MEDIA_THUMBNAIL_PREFIX = "_matrix/media/%s/thumbnail"
36
+ MEDIA_DOWNLOAD_AUTHENTICATED_PREFIX = "_matrix/client/%s/media/download"
37
+ MEDIA_THUMBNAIL_AUTHENTICATED_PREFIX = "_matrix/client/%s/media/thumbnail"
36
38
37
39
def __init__ (self , mcs : "MatrixContentScanner" ):
38
40
self ._base_url = mcs .config .download .base_homeserver_url
39
41
self ._well_known_cache : Dict [str , Optional [str ]] = {}
40
42
self ._proxy_url = mcs .config .download .proxy
41
- self ._headers = mcs .config .download .additional_headers
43
+ self ._headers = (
44
+ mcs .config .download .additional_headers
45
+ if mcs .config .download .additional_headers is not None
46
+ else {}
47
+ )
42
48
43
49
async def download_file (
44
50
self ,
45
51
media_path : str ,
46
52
thumbnail_params : Optional [MultiMapping [str ]] = None ,
53
+ auth_header : Optional [str ] = None ,
47
54
) -> MediaDescription :
48
55
"""Retrieve the file with the given `server_name/media_id` path, and stores it on
49
56
disk.
@@ -52,6 +59,8 @@ async def download_file(
52
59
media_path: The path identifying the media to retrieve.
53
60
thumbnail_params: If present, then we want to request and scan a thumbnail
54
61
generated with the provided parameters instead of the full media.
62
+ auth_header: If present, we forward the given Authorization header, this is
63
+ required for authenticated media endpoints.
55
64
56
65
Returns:
57
66
A description of the file (including its full content).
@@ -60,27 +69,45 @@ async def download_file(
60
69
ContentScannerRestError: The file was not found or could not be downloaded due
61
70
to an error on the remote homeserver's side.
62
71
"""
72
+
73
+ auth_media = True if auth_header is not None else False
74
+
75
+ prefix = (
76
+ self .MEDIA_DOWNLOAD_AUTHENTICATED_PREFIX
77
+ if auth_media
78
+ else self .MEDIA_DOWNLOAD_PREFIX
79
+ )
80
+ if thumbnail_params is not None :
81
+ prefix = (
82
+ self .MEDIA_THUMBNAIL_AUTHENTICATED_PREFIX
83
+ if auth_media
84
+ else self .MEDIA_THUMBNAIL_PREFIX
85
+ )
86
+
63
87
url = await self ._build_https_url (
64
- media_path , for_thumbnail = thumbnail_params is not None
88
+ media_path , prefix , "v1" if auth_media else "v3"
65
89
)
66
90
67
91
# Attempt to retrieve the file at the generated URL.
68
92
try :
69
- file = await self ._get_file_content (url , thumbnail_params )
93
+ file = await self ._get_file_content (url , thumbnail_params , auth_header )
70
94
except _PathNotFoundException :
95
+ if auth_media :
96
+ raise ContentScannerRestError (
97
+ http_status = HTTPStatus .NOT_FOUND ,
98
+ reason = ErrCode .NOT_FOUND ,
99
+ info = "File not found" ,
100
+ )
101
+
71
102
# If the file could not be found, it might be because the homeserver hasn't
72
103
# been upgraded to a version that supports Matrix v1.1 endpoints yet, so try
73
104
# again with an r0 endpoint.
74
105
logger .info ("File not found, trying legacy r0 path" )
75
106
76
- url = await self ._build_https_url (
77
- media_path ,
78
- endpoint_version = "r0" ,
79
- for_thumbnail = thumbnail_params is not None ,
80
- )
107
+ url = await self ._build_https_url (media_path , prefix , "r0" )
81
108
82
109
try :
83
- file = await self ._get_file_content (url , thumbnail_params )
110
+ file = await self ._get_file_content (url , thumbnail_params , auth_header )
84
111
except _PathNotFoundException :
85
112
# If that still failed, raise an error.
86
113
raise ContentScannerRestError (
@@ -94,9 +121,8 @@ async def download_file(
94
121
async def _build_https_url (
95
122
self ,
96
123
media_path : str ,
97
- endpoint_version : str = "v3" ,
98
- * ,
99
- for_thumbnail : bool ,
124
+ prefix : str ,
125
+ endpoint_version : str ,
100
126
) -> str :
101
127
"""Turn a `server_name/media_id` path into an https:// one we can use to fetch
102
128
the media.
@@ -107,10 +133,8 @@ async def _build_https_url(
107
133
Args:
108
134
media_path: The media path to translate.
109
135
endpoint_version: The version of the download endpoint to use. As of Matrix
110
- v1.1, this is either "v3" or "r0".
111
- for_thumbnail: True if a server-side thumbnail is desired instead of the full
112
- media. In that case, the URL for the `/thumbnail` endpoint is returned
113
- instead of the `/download` endpoint.
136
+ v1.11, this is "v1" for authenticated media. For unauthenticated media
137
+ this is either "v3" or "r0".
114
138
115
139
Returns:
116
140
An https URL to use. If `base_homeserver_url` is set in the config, this
@@ -140,10 +164,6 @@ async def _build_https_url(
140
164
# didn't find a .well-known file.
141
165
base_url = "https://" + server_name
142
166
143
- prefix = (
144
- self .MEDIA_THUMBNAIL_PREFIX if for_thumbnail else self .MEDIA_DOWNLOAD_PREFIX
145
- )
146
-
147
167
# Build the full URL.
148
168
path_prefix = prefix % endpoint_version
149
169
url = "%s/%s/%s/%s" % (
@@ -159,12 +179,15 @@ async def _get_file_content(
159
179
self ,
160
180
url : str ,
161
181
thumbnail_params : Optional [MultiMapping [str ]],
182
+ auth_header : Optional [str ] = None ,
162
183
) -> MediaDescription :
163
184
"""Retrieve the content of the file at a given URL.
164
185
165
186
Args:
166
187
url: The URL to query.
167
188
thumbnail_params: Query parameters used if the request is for a thumbnail.
189
+ auth_header: If present, we forward the given Authorization header, this is
190
+ required for authenticated media endpoints.
168
191
169
192
Returns:
170
193
A description of the file (including its full content).
@@ -178,7 +201,9 @@ async def _get_file_content(
178
201
ContentScannerRestError: the server returned a non-200 status which cannot
179
202
meant that the path wasn't understood.
180
203
"""
181
- code , body , headers = await self ._get (url , query = thumbnail_params )
204
+ code , body , headers = await self ._get (
205
+ url , query = thumbnail_params , auth_header = auth_header
206
+ )
182
207
183
208
logger .info ("Remote server responded with %d" , code )
184
209
@@ -307,12 +332,15 @@ async def _get(
307
332
self ,
308
333
url : str ,
309
334
query : Optional [MultiMapping [str ]] = None ,
335
+ auth_header : Optional [str ] = None ,
310
336
) -> Tuple [int , bytes , CIMultiDictProxy [str ]]:
311
337
"""Sends a GET request to the provided URL.
312
338
313
339
Args:
314
340
url: The URL to send requests to.
315
341
query: Optional parameters to use in the request's query string.
342
+ auth_header: If present, we forward the given Authorization header, this is
343
+ required for authenticated media endpoints.
316
344
317
345
Returns:
318
346
The HTTP status code, body and headers the remote server responded with.
@@ -324,10 +352,15 @@ async def _get(
324
352
try :
325
353
logger .info ("Sending GET request to %s" , url )
326
354
async with aiohttp .ClientSession () as session :
355
+ if auth_header is not None :
356
+ request_headers = {"Authorization" : auth_header , ** self ._headers }
357
+ else :
358
+ request_headers = self ._headers
359
+
327
360
async with session .get (
328
361
url ,
329
362
proxy = self ._proxy_url ,
330
- headers = self . _headers ,
363
+ headers = request_headers ,
331
364
params = query ,
332
365
) as resp :
333
366
return resp .status , await resp .read (), resp .headers
0 commit comments