Skip to content

Commit 1dbdc3f

Browse files
committed
Add support for authenticated media downloads
1 parent 55c53e0 commit 1dbdc3f

File tree

5 files changed

+31
-6
lines changed

5 files changed

+31
-6
lines changed

mautrix/api.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,7 @@ def get_download_url(
462462
mxc_uri: str,
463463
download_type: Literal["download", "thumbnail"] = "download",
464464
file_name: str | None = None,
465+
authenticated: bool = False,
465466
) -> URL:
466467
"""
467468
Get the full HTTP URL to download a ``mxc://`` URI.
@@ -470,6 +471,7 @@ def get_download_url(
470471
mxc_uri: The MXC URI whose full URL to get.
471472
download_type: The type of download ("download" or "thumbnail").
472473
file_name: Optionally, a file name to include in the download URL.
474+
authenticated: Whether to use the new authenticated download endpoint in Matrix v1.11.
473475
474476
Returns:
475477
The full HTTP URL.
@@ -485,7 +487,11 @@ def get_download_url(
485487
"https://matrix-client.matrix.org/_matrix/media/v3/download/matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6/hello.png"
486488
"""
487489
server_name, media_id = self.parse_mxc_uri(mxc_uri)
488-
url = self.base_url / str(APIPath.MEDIA) / "v3" / download_type / server_name / media_id
490+
if authenticated:
491+
url = self.base_url / str(APIPath.CLIENT) / "v1" / "media"
492+
else:
493+
url = self.base_url / str(APIPath.MEDIA) / "v3"
494+
url = url / download_type / server_name / media_id
489495
if file_name:
490496
url /= file_name
491497
return url

mautrix/appservice/api/intent.py

+2
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ def __init__(
118118
) -> None:
119119
super().__init__(mxid=mxid, api=api, state_store=state_store)
120120
self.bot = bot
121+
if bot is not None:
122+
self.versions_cache = bot.versions_cache
121123
self.log = api.base_log.getChild("intent")
122124

123125
for method in ENSURE_REGISTERED_METHODS:

mautrix/appservice/appservice.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from aiohttp import web
1414
import aiohttp
1515

16-
from mautrix.types import JSON, RoomAlias, UserID
16+
from mautrix.types import JSON, RoomAlias, UserID, VersionsResponse
1717
from mautrix.util.logging import TraceLogger
1818

1919
from ..api import HTTPAPI

mautrix/client/api/modules/media_repository.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import asyncio
1111
import time
1212

13+
from yarl import URL
14+
1315
from mautrix import __optional_imports__
1416
from mautrix.api import MediaPath, Method
1517
from mautrix.errors import MatrixResponseError, make_request_error
@@ -19,6 +21,7 @@
1921
MediaRepoConfig,
2022
MXOpenGraph,
2123
SerializerError,
24+
SpecVersions,
2225
)
2326
from mautrix.util import background_task
2427
from mautrix.util.async_body import async_iter_bytes
@@ -178,13 +181,17 @@ async def download_media(self, url: ContentURI, timeout_ms: int | None = None) -
178181
Returns:
179182
The raw downloaded data.
180183
"""
181-
url = self.api.get_download_url(url)
184+
authenticated = (await self.versions()).supports(SpecVersions.V111)
185+
url = self.api.get_download_url(url, authenticated=authenticated)
182186
query_params: dict[str, Any] = {"allow_redirect": "true"}
183187
if timeout_ms is not None:
184188
query_params["timeout_ms"] = timeout_ms
189+
headers: dict[str, str] = {}
190+
if authenticated:
191+
headers["Authorization"] = f"Bearer {self.api.token}"
185192
req_id = self.api.log_download_request(url, query_params)
186193
start = time.monotonic()
187-
async with self.api.session.get(url, params=query_params) as response:
194+
async with self.api.session.get(url, params=query_params, headers=headers) as response:
188195
try:
189196
response.raise_for_status()
190197
return await response.read()
@@ -223,7 +230,10 @@ async def download_thumbnail(
223230
Returns:
224231
The raw downloaded data.
225232
"""
226-
url = self.api.get_download_url(url, download_type="thumbnail")
233+
authenticated = (await self.versions()).supports(SpecVersions.V111)
234+
url = self.api.get_download_url(
235+
url, download_type="thumbnail", authenticated=authenticated
236+
)
227237
query_params: dict[str, Any] = {"allow_redirect": "true"}
228238
if width is not None:
229239
query_params["width"] = width
@@ -235,9 +245,12 @@ async def download_thumbnail(
235245
query_params["allow_remote"] = str(allow_remote).lower()
236246
if timeout_ms is not None:
237247
query_params["timeout_ms"] = timeout_ms
248+
headers: dict[str, str] = {}
249+
if authenticated:
250+
headers["Authorization"] = f"Bearer {self.api.token}"
238251
req_id = self.api.log_download_request(url, query_params)
239252
start = time.monotonic()
240-
async with self.api.session.get(url, params=query_params) as response:
253+
async with self.api.session.get(url, params=query_params, headers=headers) as response:
241254
try:
242255
response.raise_for_status()
243256
return await response.read()

mautrix/types/versions.py

+4
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ class SpecVersions:
7474
V15 = Version.deserialize("v1.5")
7575
V16 = Version.deserialize("v1.6")
7676
V17 = Version.deserialize("v1.7")
77+
V18 = Version.deserialize("v1.8")
78+
V19 = Version.deserialize("v1.9")
79+
V110 = Version.deserialize("v1.10")
80+
V111 = Version.deserialize("v1.11")
7781

7882

7983
@dataclass

0 commit comments

Comments
 (0)