diff --git a/starlette/_utils.py b/starlette/_utils.py index 26854f3d4..15ccd92a4 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -1,9 +1,12 @@ import asyncio import functools +import re import sys import typing from contextlib import contextmanager +from starlette.types import Scope + if sys.version_info >= (3, 10): # pragma: no cover from typing import TypeGuard else: # pragma: no cover @@ -86,3 +89,9 @@ def collapse_excgroups() -> typing.Generator[None, None, None]: exc = exc.exceptions[0] # pragma: no cover raise exc + + +def get_route_path(scope: Scope) -> str: + root_path = scope.get("root_path", "") + route_path = re.sub(r"^" + root_path, "", scope["path"]) + return route_path diff --git a/starlette/datastructures.py b/starlette/datastructures.py index a0c3ba140..e12957f50 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -30,7 +30,7 @@ def __init__( assert not components, 'Cannot set both "scope" and "**components".' scheme = scope.get("scheme", "http") server = scope.get("server", None) - path = scope.get("root_path", "") + scope["path"] + path = scope["path"] query_string = scope.get("query_string", b"") host_header = None diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 95578c9d2..2ce83b074 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -20,10 +20,16 @@ def build_environ(scope: Scope, body: bytes) -> typing.Dict[str, typing.Any]: """ Builds a scope and request body into a WSGI environ object. """ + + script_name = scope.get("root_path", "").encode("utf8").decode("latin1") + path_info = scope["path"].encode("utf8").decode("latin1") + if path_info.startswith(script_name): + path_info = path_info[len(script_name) :] + environ = { "REQUEST_METHOD": scope["method"], - "SCRIPT_NAME": scope.get("root_path", "").encode("utf8").decode("latin1"), - "PATH_INFO": scope["path"].encode("utf8").decode("latin1"), + "SCRIPT_NAME": script_name, + "PATH_INFO": path_info, "QUERY_STRING": scope["query_string"].decode("ascii"), "SERVER_PROTOCOL": f"HTTP/{scope['http_version']}", "wsgi.version": (1, 0), diff --git a/starlette/requests.py b/starlette/requests.py index 83a52aca1..e51223bab 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -99,9 +99,18 @@ def url(self) -> URL: def base_url(self) -> URL: if not hasattr(self, "_base_url"): base_url_scope = dict(self.scope) - base_url_scope["path"] = "/" + # This is used by request.url_for, it might be used inside a Mount which + # would have its own child scope with its own root_path, but the base URL + # for url_for should still be the top level app root path. + app_root_path = base_url_scope.get( + "app_root_path", base_url_scope.get("root_path", "") + ) + path = app_root_path + if not path.endswith("/"): + path += "/" + base_url_scope["path"] = path base_url_scope["query_string"] = b"" - base_url_scope["root_path"] = base_url_scope.get("root_path", "") + base_url_scope["root_path"] = app_root_path self._base_url = URL(scope=base_url_scope) return self._base_url diff --git a/starlette/routing.py b/starlette/routing.py index 0ced9e667..d718bb921 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -10,7 +10,7 @@ from enum import Enum from starlette._exception_handler import wrap_app_handling_exceptions -from starlette._utils import is_async_callable +from starlette._utils import get_route_path, is_async_callable from starlette.concurrency import run_in_threadpool from starlette.convertors import CONVERTOR_TYPES, Convertor from starlette.datastructures import URL, Headers, URLPath @@ -255,9 +255,8 @@ def __init__( def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: path_params: "typing.Dict[str, typing.Any]" if scope["type"] == "http": - root_path = scope.get("route_root_path", scope.get("root_path", "")) - path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"])) - match = self.path_regex.match(path) + route_path = get_route_path(scope) + match = self.path_regex.match(route_path) if match: matched_params = match.groupdict() for key, value in matched_params.items(): @@ -345,9 +344,8 @@ def __init__( def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: path_params: "typing.Dict[str, typing.Any]" if scope["type"] == "websocket": - root_path = scope.get("route_root_path", scope.get("root_path", "")) - path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"])) - match = self.path_regex.match(path) + route_path = get_route_path(scope) + match = self.path_regex.match(route_path) if match: matched_params = match.groupdict() for key, value in matched_params.items(): @@ -420,9 +418,8 @@ def routes(self) -> typing.List[BaseRoute]: def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: path_params: "typing.Dict[str, typing.Any]" if scope["type"] in ("http", "websocket"): - path = scope["path"] - root_path = scope.get("route_root_path", scope.get("root_path", "")) - route_path = scope.get("route_path", re.sub(r"^" + root_path, "", path)) + root_path = scope.get("root_path", "") + route_path = get_route_path(scope) match = self.path_regex.match(route_path) if match: matched_params = match.groupdict() @@ -432,11 +429,20 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: matched_path = route_path[: -len(remaining_path)] path_params = dict(scope.get("path_params", {})) path_params.update(matched_params) - root_path = scope.get("root_path", "") child_scope = { "path_params": path_params, - "route_root_path": root_path + matched_path, - "route_path": remaining_path, + # app_root_path will only be set at the top level scope, + # initialized with the (optional) value of a root_path + # set above/before Starlette. And even though any + # mount will have its own child scope with its own respective + # root_path, the app_root_path will always be available in all + # the child scopes with the same top level value because it's + # set only once here with a default, any other child scope will + # just inherit that app_root_path default value stored in the + # scope. All this is needed to support Request.url_for(), as it + # uses the app_root_path to build the URL path. + "app_root_path": scope.get("app_root_path", root_path), + "root_path": root_path + matched_path, "endpoint": self.app, } return Match.FULL, child_scope @@ -787,15 +793,12 @@ async def app(self, scope: Scope, receive: Receive, send: Send) -> None: await partial.handle(scope, receive, send) return - root_path = scope.get("route_root_path", scope.get("root_path", "")) - path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"])) - if scope["type"] == "http" and self.redirect_slashes and path != "/": + route_path = get_route_path(scope) + if scope["type"] == "http" and self.redirect_slashes and route_path != "/": redirect_scope = dict(scope) - if path.endswith("/"): - redirect_scope["route_path"] = path.rstrip("/") + if route_path.endswith("/"): redirect_scope["path"] = redirect_scope["path"].rstrip("/") else: - redirect_scope["route_path"] = path + "/" redirect_scope["path"] = redirect_scope["path"] + "/" for route in self.routes: diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index 895105a7d..0101b11bc 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -1,6 +1,5 @@ import importlib.util import os -import re import stat import typing from email.utils import parsedate @@ -8,6 +7,7 @@ import anyio import anyio.to_thread +from starlette._utils import get_route_path from starlette.datastructures import URL, Headers from starlette.exceptions import HTTPException from starlette.responses import FileResponse, RedirectResponse, Response @@ -110,9 +110,8 @@ def get_path(self, scope: Scope) -> str: Given the ASGI scope, return the `path` string to serve up, with OS specific path separators, and any '..', '.' components removed. """ - root_path = scope.get("route_root_path", scope.get("root_path", "")) - path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"])) - return os.path.normpath(os.path.join(*path.split("/"))) # type: ignore[no-any-return] # noqa: E501 + route_path = get_route_path(scope) + return os.path.normpath(os.path.join(*route_path.split("/"))) # noqa: E501 async def get_response(self, path: str, scope: Scope) -> Response: """ diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index fe527e373..316cb191f 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -92,7 +92,8 @@ def test_build_environ(): "http_version": "1.1", "method": "GET", "scheme": "https", - "path": "/", + "path": "/sub/", + "root_path": "/sub", "query_string": b"a=123&b=456", "headers": [ (b"host", b"www.example.org"), @@ -117,7 +118,7 @@ def test_build_environ(): "QUERY_STRING": "a=123&b=456", "REMOTE_ADDR": "134.56.78.4", "REQUEST_METHOD": "GET", - "SCRIPT_NAME": "", + "SCRIPT_NAME": "/sub", "SERVER_NAME": "www.example.org", "SERVER_PORT": 443, "SERVER_PROTOCOL": "HTTP/1.1", diff --git a/tests/test_routing.py b/tests/test_routing.py index cd6f37880..128f06674 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,5 +1,6 @@ import contextlib import functools +import json import typing import uuid @@ -563,12 +564,12 @@ def test_url_for_with_root_path(test_client_factory): client = test_client_factory( app, base_url="https://www.example.org/", root_path="/sub_path" ) - response = client.get("/") + response = client.get("/sub_path/") assert response.json() == { "index": "https://www.example.org/sub_path/", "submount": "https://www.example.org/sub_path/submount/", } - response = client.get("/submount/") + response = client.get("/sub_path/submount/") assert response.json() == { "index": "https://www.example.org/sub_path/", "submount": "https://www.example.org/sub_path/submount/", @@ -1242,6 +1243,19 @@ async def echo_paths(request: Request, name: str): ) +async def pure_asgi_echo_paths(scope: Scope, receive: Receive, send: Send, name: str): + data = {"name": name, "path": scope["path"], "root_path": scope["root_path"]} + content = json.dumps(data).encode("utf-8") + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [(b"content-type", b"application/json")], + } + ) + await send({"type": "http.response.body", "body": content}) + + echo_paths_routes = [ Route( "/path", @@ -1249,8 +1263,9 @@ async def echo_paths(request: Request, name: str): name="path", methods=["GET"], ), + Mount("/asgipath", app=functools.partial(pure_asgi_echo_paths, name="asgipath")), Mount( - "/root", + "/sub", name="mount", routes=[ Route( @@ -1258,7 +1273,7 @@ async def echo_paths(request: Request, name: str): functools.partial(echo_paths, name="subpath"), name="subpath", methods=["GET"], - ) + ), ], ), ] @@ -1276,11 +1291,22 @@ def test_paths_with_root_path(test_client_factory: typing.Callable[..., TestClie "path": "/root/path", "root_path": "/root", } + response = client.get("/root/asgipath/") + assert response.status_code == 200 + assert response.json() == { + "name": "asgipath", + "path": "/root/asgipath/", + # Things that mount other ASGI apps, like WSGIMiddleware, would not be aware + # of the prefixed path, and would have their own notion of their own paths, + # so they need to be able to rely on the root_path to know the location they + # are mounted on + "root_path": "/root/asgipath", + } - response = client.get("/root/root/path") + response = client.get("/root/sub/path") assert response.status_code == 200 assert response.json() == { "name": "subpath", - "path": "/root/root/path", - "root_path": "/root", + "path": "/root/sub/path", + "root_path": "/root/sub", }