Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not overwrite "path" and "root_path" scope keys #2352

Merged
merged 3 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,7 @@ def base_url(self) -> URL:
base_url_scope = dict(self.scope)
base_url_scope["path"] = "/"
base_url_scope["query_string"] = b""
base_url_scope["root_path"] = base_url_scope.get(
"app_root_path", base_url_scope.get("root_path", "")
)
base_url_scope["root_path"] = base_url_scope.get("root_path", "")
self._base_url = URL(scope=base_url_scope)
return self._base_url

Expand Down
30 changes: 21 additions & 9 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,11 @@ def __init__(
self.path_regex, self.path_format, self.param_convertors = compile_path(path)

def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
path_params: "typing.Dict[str, typing.Any]"
if scope["type"] == "http":
match = self.path_regex.match(scope["path"])
root_path = scope.get("route_root_path", scope.get("root_path", ""))
path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"]))
match = self.path_regex.match(path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
Expand Down Expand Up @@ -338,8 +341,11 @@ def __init__(
self.path_regex, self.path_format, self.param_convertors = compile_path(path)

def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
path_params: "typing.Dict[str, typing.Any]"
if scope["type"] == "websocket":
match = self.path_regex.match(scope["path"])
root_path = scope.get("route_root_path", scope.get("root_path", ""))
path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"]))
match = self.path_regex.match(path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
Expand Down Expand Up @@ -410,23 +416,25 @@ def routes(self) -> typing.List[BaseRoute]:
return getattr(self._base_app, "routes", [])

def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
path_params: "typing.Dict[str, typing.Any]"
if scope["type"] in ("http", "websocket"):
path = scope["path"]
match = self.path_regex.match(path)
root_path = scope.get("route_root_path", scope.get("root_path", ""))
route_path = scope.get("route_path", re.sub(r"^" + root_path, "", path))
match = self.path_regex.match(route_path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
remaining_path = "/" + matched_params.pop("path")
matched_path = path[: -len(remaining_path)]
matched_path = route_path[: -len(remaining_path)]
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
root_path = scope.get("root_path", "")
child_scope = {
"path_params": path_params,
"app_root_path": scope.get("app_root_path", root_path),
"root_path": root_path + matched_path,
"path": remaining_path,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we need to keep the original root_path and path in the child scope?

Copy link
Sponsor Member Author

@Kludex Kludex Dec 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't they always there anyway? Those lines were just modifying them, no?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you're right. My point was if we were using this child_scope as the new scope for inner routers, but it seems actually that it's always merged with the root scope.

"route_root_path": root_path + matched_path,
"route_path": remaining_path,
"endpoint": self.app,
}
return Match.FULL, child_scope
Expand Down Expand Up @@ -767,11 +775,15 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await partial.handle(scope, receive, send)
return

if scope["type"] == "http" and self.redirect_slashes and scope["path"] != "/":
root_path = scope.get("route_root_path", scope.get("root_path", ""))
path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"]))
if scope["type"] == "http" and self.redirect_slashes and path != "/":
redirect_scope = dict(scope)
if scope["path"].endswith("/"):
if path.endswith("/"):
redirect_scope["route_path"] = path.rstrip("/")
redirect_scope["path"] = redirect_scope["path"].rstrip("/")
else:
redirect_scope["route_path"] = path + "/"
redirect_scope["path"] = redirect_scope["path"] + "/"

for route in self.routes:
Expand Down
5 changes: 4 additions & 1 deletion starlette/staticfiles.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib.util
import os
import re
import stat
import typing
from email.utils import parsedate
Expand Down Expand Up @@ -108,7 +109,9 @@ def get_path(self, scope: Scope) -> str:
Given the ASGI scope, return the `path` string to serve up,
with OS specific path separators, and any '..', '.' components removed.
"""
return os.path.normpath(os.path.join(*scope["path"].split("/"))) # type: ignore[no-any-return] # noqa: E501
root_path = scope.get("route_root_path", scope.get("root_path", ""))
path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"]))
return os.path.normpath(os.path.join(*path.split("/"))) # type: ignore[no-any-return] # noqa: E501

async def get_response(self, path: str, scope: Scope) -> Response:
"""
Expand Down
58 changes: 56 additions & 2 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ async def websocket_params(session: WebSocket):


@pytest.fixture
def client(test_client_factory):
def client(test_client_factory: typing.Callable[..., TestClient]):
with test_client_factory(app) as client:
yield client

Expand All @@ -170,7 +170,7 @@ def client(test_client_factory):
r":UserWarning"
r":charset_normalizer.api"
)
def test_router(client):
def test_router(client: TestClient):
response = client.get("/")
assert response.status_code == 200
assert response.text == "Hello, world"
Expand Down Expand Up @@ -1210,3 +1210,57 @@ async def startup() -> None:
... # pragma: nocover

router.on_event("startup")(startup)


async def echo_paths(request: Request, name: str):
return JSONResponse(
{
"name": name,
"path": request.scope["path"],
"root_path": request.scope["root_path"],
}
)


echo_paths_routes = [
Route(
"/path",
functools.partial(echo_paths, name="path"),
name="path",
methods=["GET"],
),
Mount(
"/root",
name="mount",
routes=[
Route(
"/path",
functools.partial(echo_paths, name="subpath"),
name="subpath",
methods=["GET"],
)
],
),
]


def test_paths_with_root_path(test_client_factory: typing.Callable[..., TestClient]):
app = Starlette(routes=echo_paths_routes)
client = test_client_factory(
app, base_url="https://www.example.org/", root_path="/root"
)
response = client.get("/root/path")
assert response.status_code == 200
assert response.json() == {
"name": "path",
"path": "/root/path",
"root_path": "/root",
}

response = client.get("/root/root/path")
assert response.status_code == 200
assert response.json() == {
"name": "subpath",
"path": "/root/root/path",
"root_path": "/root",
}