Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ Refactor logic to handle root_path to keep compatibility with ASGI and compatibility with other non-Starlette-specific libraries like a2wsgi #2400

Merged
merged 14 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions starlette/_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import asyncio
import functools
import re
import sys
import typing
from contextlib import contextmanager

from starlette.types import Scope

if sys.version_info >= (3, 10): # pragma: no cover
from typing import TypeGuard
else: # pragma: no cover
Expand Down Expand Up @@ -86,3 +89,9 @@ def collapse_excgroups() -> typing.Generator[None, None, None]:
exc = exc.exceptions[0] # pragma: no cover

raise exc


def get_route_path(scope: Scope) -> str:
    """
    Return the request path with the scope's ``root_path`` prefix removed.

    Args:
        scope: The ASGI connection scope. ``scope["path"]`` is required;
            ``scope["root_path"]`` is optional and defaults to ``""``.

    Returns:
        ``scope["path"]`` without its leading ``root_path`` when the path
        starts with that prefix, otherwise ``scope["path"]`` unchanged.
    """
    path = scope["path"]
    root_path = scope.get("root_path", "")
    # Compare the prefix literally. root_path is plain data, not a pattern:
    # building a regex from it would misinterpret characters such as ".",
    # "+" or "(" and could either strip the wrong text or raise re.error.
    if root_path and path.startswith(root_path):
        return path[len(root_path) :]
    return path
2 changes: 1 addition & 1 deletion starlette/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(
assert not components, 'Cannot set both "scope" and "**components".'
scheme = scope.get("scheme", "http")
server = scope.get("server", None)
path = scope.get("root_path", "") + scope["path"]
path = scope["path"]
query_string = scope.get("query_string", b"")

host_header = None
Expand Down
10 changes: 8 additions & 2 deletions starlette/middleware/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,16 @@ def build_environ(scope: Scope, body: bytes) -> typing.Dict[str, typing.Any]:
"""
Builds a scope and request body into a WSGI environ object.
"""

script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
path_info = scope["path"].encode("utf8").decode("latin1")
if path_info.startswith(script_name):
path_info = path_info[len(script_name) :]

environ = {
"REQUEST_METHOD": scope["method"],
"SCRIPT_NAME": scope.get("root_path", "").encode("utf8").decode("latin1"),
"PATH_INFO": scope["path"].encode("utf8").decode("latin1"),
"SCRIPT_NAME": script_name,
"PATH_INFO": path_info,
"QUERY_STRING": scope["query_string"].decode("ascii"),
"SERVER_PROTOCOL": f"HTTP/{scope['http_version']}",
"wsgi.version": (1, 0),
Expand Down
13 changes: 11 additions & 2 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,18 @@ def url(self) -> URL:
def base_url(self) -> URL:
if not hasattr(self, "_base_url"):
base_url_scope = dict(self.scope)
base_url_scope["path"] = "/"
# This is used by request.url_for, it might be used inside a Mount which
# would have its own child scope with its own root_path, but the base URL
# for url_for should still be the top level app root path.
app_root_path = base_url_scope.get(
"app_root_path", base_url_scope.get("root_path", "")
)
path = app_root_path
if not path.endswith("/"):
path += "/"
base_url_scope["path"] = path
base_url_scope["query_string"] = b""
base_url_scope["root_path"] = base_url_scope.get("root_path", "")
base_url_scope["root_path"] = app_root_path
self._base_url = URL(scope=base_url_scope)
return self._base_url

Expand Down
41 changes: 22 additions & 19 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from enum import Enum

from starlette._exception_handler import wrap_app_handling_exceptions
from starlette._utils import is_async_callable
from starlette._utils import get_route_path, is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.convertors import CONVERTOR_TYPES, Convertor
from starlette.datastructures import URL, Headers, URLPath
Expand Down Expand Up @@ -255,9 +255,8 @@ def __init__(
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
path_params: "typing.Dict[str, typing.Any]"
if scope["type"] == "http":
root_path = scope.get("route_root_path", scope.get("root_path", ""))
path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"]))
match = self.path_regex.match(path)
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
Expand Down Expand Up @@ -345,9 +344,8 @@ def __init__(
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
path_params: "typing.Dict[str, typing.Any]"
if scope["type"] == "websocket":
root_path = scope.get("route_root_path", scope.get("root_path", ""))
path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"]))
match = self.path_regex.match(path)
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
Expand Down Expand Up @@ -420,9 +418,8 @@ def routes(self) -> typing.List[BaseRoute]:
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
path_params: "typing.Dict[str, typing.Any]"
if scope["type"] in ("http", "websocket"):
path = scope["path"]
root_path = scope.get("route_root_path", scope.get("root_path", ""))
route_path = scope.get("route_path", re.sub(r"^" + root_path, "", path))
root_path = scope.get("root_path", "")
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
if match:
matched_params = match.groupdict()
Expand All @@ -432,11 +429,20 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
matched_path = route_path[: -len(remaining_path)]
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
root_path = scope.get("root_path", "")
child_scope = {
"path_params": path_params,
"route_root_path": root_path + matched_path,
"route_path": remaining_path,
# app_root_path is only set at the top-level scope, initialized
# with the (optional) root_path value set above/before Starlette.
# Every mount gets its own child scope with its own respective
# root_path, but app_root_path keeps the same top-level value in
# all child scopes: it is set only once, here, with a default,
# and any other child scope simply inherits the app_root_path
# value already stored in the scope. All this is needed to
# support Request.url_for(), which uses app_root_path to build
# the URL path.
"app_root_path": scope.get("app_root_path", root_path),
"root_path": root_path + matched_path,
"endpoint": self.app,
}
return Match.FULL, child_scope
Expand Down Expand Up @@ -787,15 +793,12 @@ async def app(self, scope: Scope, receive: Receive, send: Send) -> None:
await partial.handle(scope, receive, send)
return

root_path = scope.get("route_root_path", scope.get("root_path", ""))
path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"]))
if scope["type"] == "http" and self.redirect_slashes and path != "/":
route_path = get_route_path(scope)
if scope["type"] == "http" and self.redirect_slashes and route_path != "/":
redirect_scope = dict(scope)
if path.endswith("/"):
redirect_scope["route_path"] = path.rstrip("/")
if route_path.endswith("/"):
redirect_scope["path"] = redirect_scope["path"].rstrip("/")
else:
redirect_scope["route_path"] = path + "/"
redirect_scope["path"] = redirect_scope["path"] + "/"

for route in self.routes:
Expand Down
7 changes: 3 additions & 4 deletions starlette/staticfiles.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import importlib.util
import os
import re
import stat
import typing
from email.utils import parsedate

import anyio
import anyio.to_thread

from starlette._utils import get_route_path
from starlette.datastructures import URL, Headers
from starlette.exceptions import HTTPException
from starlette.responses import FileResponse, RedirectResponse, Response
Expand Down Expand Up @@ -110,9 +110,8 @@ def get_path(self, scope: Scope) -> str:
Given the ASGI scope, return the `path` string to serve up,
with OS specific path separators, and any '..', '.' components removed.
"""
root_path = scope.get("route_root_path", scope.get("root_path", ""))
path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"]))
return os.path.normpath(os.path.join(*path.split("/"))) # type: ignore[no-any-return] # noqa: E501
route_path = get_route_path(scope)
return os.path.normpath(os.path.join(*route_path.split("/"))) # noqa: E501

async def get_response(self, path: str, scope: Scope) -> Response:
"""
Expand Down
5 changes: 3 additions & 2 deletions tests/middleware/test_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def test_build_environ():
"http_version": "1.1",
"method": "GET",
"scheme": "https",
"path": "/",
"path": "/sub/",
"root_path": "/sub",
"query_string": b"a=123&b=456",
"headers": [
(b"host", b"www.example.org"),
Expand All @@ -117,7 +118,7 @@ def test_build_environ():
"QUERY_STRING": "a=123&b=456",
"REMOTE_ADDR": "134.56.78.4",
"REQUEST_METHOD": "GET",
"SCRIPT_NAME": "",
"SCRIPT_NAME": "/sub",
"SERVER_NAME": "www.example.org",
"SERVER_PORT": 443,
"SERVER_PROTOCOL": "HTTP/1.1",
Expand Down
40 changes: 33 additions & 7 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import functools
import json
import typing
import uuid

Expand Down Expand Up @@ -563,12 +564,12 @@ def test_url_for_with_root_path(test_client_factory):
client = test_client_factory(
app, base_url="https://www.example.org/", root_path="/sub_path"
)
response = client.get("/")
response = client.get("/sub_path/")
assert response.json() == {
"index": "https://www.example.org/sub_path/",
"submount": "https://www.example.org/sub_path/submount/",
}
response = client.get("/submount/")
response = client.get("/sub_path/submount/")
assert response.json() == {
"index": "https://www.example.org/sub_path/",
"submount": "https://www.example.org/sub_path/submount/",
Expand Down Expand Up @@ -1242,23 +1243,37 @@ async def echo_paths(request: Request, name: str):
)


async def pure_asgi_echo_paths(scope: Scope, receive: Receive, send: Send, name: str):
    """
    Minimal raw ASGI app that echoes routing details back as JSON.

    Sends a 200 response whose JSON body contains the given ``name`` plus the
    ``path`` and ``root_path`` values read from ``scope``. ``receive`` is
    accepted to satisfy the ASGI call signature but is not used.
    """
    # Serialize first so a missing scope key fails before anything is sent.
    payload = json.dumps(
        {"name": name, "path": scope["path"], "root_path": scope["root_path"]}
    ).encode("utf-8")

    response_start = {
        "type": "http.response.start",
        "status": 200,
        "headers": [(b"content-type", b"application/json")],
    }
    response_body = {"type": "http.response.body", "body": payload}

    for message in (response_start, response_body):
        await send(message)


echo_paths_routes = [
Route(
"/path",
functools.partial(echo_paths, name="path"),
name="path",
methods=["GET"],
),
Mount("/asgipath", app=functools.partial(pure_asgi_echo_paths, name="asgipath")),
Mount(
"/root",
"/sub",
name="mount",
routes=[
Route(
"/path",
functools.partial(echo_paths, name="subpath"),
name="subpath",
methods=["GET"],
)
),
],
),
]
Expand All @@ -1276,11 +1291,22 @@ def test_paths_with_root_path(test_client_factory: typing.Callable[..., TestClie
"path": "/root/path",
"root_path": "/root",
}
response = client.get("/root/asgipath/")
assert response.status_code == 200
assert response.json() == {
"name": "asgipath",
"path": "/root/asgipath/",
# Things that mount other ASGI apps, like WSGIMiddleware, would not be aware
# of the prefixed path, and would have their own notion of their own paths,
# so they need to be able to rely on the root_path to know the location they
# are mounted on
"root_path": "/root/asgipath",
}

response = client.get("/root/root/path")
response = client.get("/root/sub/path")
assert response.status_code == 200
assert response.json() == {
"name": "subpath",
"path": "/root/root/path",
"root_path": "/root",
"path": "/root/sub/path",
"root_path": "/root/sub",
}