Skip to content

Commit 1131b3c

Browse files
Ensure accurate root_path removal in get_route_path function (#2600)
* fix: regex inside function get_route_path to remove root_path * fix: apply format ruff * fix: mypy --------- Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent 1eb4036 commit 1131b3c

File tree

3 files changed

+31
-2
lines changed

3 files changed

+31
-2
lines changed

starlette/_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,5 +85,5 @@ def collapse_excgroups() -> typing.Generator[None, None, None]:
8585

8686
def get_route_path(scope: Scope) -> str:
8787
root_path = scope.get("root_path", "")
88-
route_path = re.sub(r"^" + root_path, "", scope["path"])
88+
route_path = re.sub(r"^" + root_path + r"(?=/|$)", "", scope["path"])
8989
return route_path

tests/test__utils.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import functools
22
from typing import Any
33

4-
from starlette._utils import is_async_callable
4+
import pytest
5+
6+
from starlette._utils import get_route_path, is_async_callable
7+
from starlette.types import Scope
58

69

710
def test_async_func() -> None:
@@ -78,3 +81,15 @@ async def async_func(
7881
partial = functools.partial(async_func, b=2)
7982
nested_partial = functools.partial(partial, a=1)
8083
assert is_async_callable(nested_partial)
84+
85+
86+
@pytest.mark.parametrize(
87+
"scope, expected_result",
88+
[
89+
({"path": "/foo-123/bar", "root_path": "/foo"}, "/foo-123/bar"),
90+
({"path": "/foo/bar", "root_path": "/foo"}, "/bar"),
91+
({"path": "/foo", "root_path": "/foo"}, ""),
92+
],
93+
)
94+
def test_get_route_path(scope: Scope, expected_result: str) -> None:
95+
assert get_route_path(scope) == expected_result

tests/test_routing.py

+14
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,12 @@ async def pure_asgi_echo_paths(scope: Scope, receive: Receive, send: Send, name:
12211221
name="path",
12221222
methods=["GET"],
12231223
),
1224+
Route(
1225+
"/root-queue/path",
1226+
functools.partial(echo_paths, name="queue_path"),
1227+
name="queue_path",
1228+
methods=["POST"],
1229+
),
12241230
Mount("/asgipath", app=functools.partial(pure_asgi_echo_paths, name="asgipath")),
12251231
Mount(
12261232
"/sub",
@@ -1266,3 +1272,11 @@ def test_paths_with_root_path(test_client_factory: TestClientFactory) -> None:
12661272
"path": "/root/sub/path",
12671273
"root_path": "/root/sub",
12681274
}
1275+
1276+
response = client.post("/root/root-queue/path")
1277+
assert response.status_code == 200
1278+
assert response.json() == {
1279+
"name": "queue_path",
1280+
"path": "/root/root-queue/path",
1281+
"root_path": "/root",
1282+
}

0 commit comments

Comments
 (0)