Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shutdown logic: Only wait on handlers #8495

Merged
merged 8 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import warnings
from argparse import ArgumentParser
from collections.abc import Iterable
from contextlib import suppress
from functools import partial
from importlib import import_module
from typing import (
Any,
Expand All @@ -21,7 +19,6 @@
Union,
cast,
)
from weakref import WeakSet

from .abc import AbstractAccessLogger
from .helpers import AppKey
Expand Down Expand Up @@ -300,23 +297,6 @@ async def _run_app(
reuse_port: Optional[bool] = None,
handler_cancellation: bool = False,
) -> None:
async def wait(
starting_tasks: "WeakSet[asyncio.Task[object]]", shutdown_timeout: float
) -> None:
# Wait for pending tasks for a given time limit.
t = asyncio.current_task()
assert t is not None
starting_tasks.add(t)
with suppress(asyncio.TimeoutError):
await asyncio.wait_for(_wait(starting_tasks), timeout=shutdown_timeout)

async def _wait(exclude: "WeakSet[asyncio.Task[object]]") -> None:
t = asyncio.current_task()
assert t is not None
exclude.add(t)
while tasks := asyncio.all_tasks().difference(exclude):
await asyncio.wait(tasks)

# An internal function to actually do all dirty job for application running
if asyncio.iscoroutine(app):
app = await app
Expand All @@ -335,12 +315,6 @@ async def _wait(exclude: "WeakSet[asyncio.Task[object]]") -> None:
)

await runner.setup()
# On shutdown we want to avoid waiting on tasks which run forever.
# It's very likely that all tasks which run forever will have been created by
# the time we have completed the application startup (in runner.setup()),
# so we just record all running tasks here and exclude them later.
starting_tasks: "WeakSet[asyncio.Task[object]]" = WeakSet(asyncio.all_tasks())
runner.shutdown_callback = partial(wait, starting_tasks, shutdown_timeout)

sites: List[BaseSite] = []

Expand Down
8 changes: 6 additions & 2 deletions aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,12 @@ async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
if self._waiter:
self._waiter.cancel()

# wait for handlers
# Wait for graceful disconnection
if self._current_request is not None:
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
async with ceil_timeout(timeout):
await self._current_request.wait_for_disconnection()
# Then cancel handler and wait
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
async with ceil_timeout(timeout):
if self._current_request is not None:
Expand Down Expand Up @@ -461,7 +466,6 @@ async def _handle_request(
start_time: float,
request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]],
) -> Tuple[StreamResponse, bool]:
assert self._request_handler is not None
try:
try:
self._current_request = request
Expand Down
16 changes: 2 additions & 14 deletions aiohttp/web_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import signal
import socket
from abc import ABC, abstractmethod
from typing import Any, Awaitable, Callable, List, Optional, Set, Type
from typing import Any, List, Optional, Set, Type

from yarl import URL

Expand Down Expand Up @@ -230,14 +230,7 @@ async def start(self) -> None:


class BaseRunner(ABC):
__slots__ = (
"shutdown_callback",
"_handle_signals",
"_kwargs",
"_server",
"_sites",
"_shutdown_timeout",
)
__slots__ = ("_handle_signals", "_kwargs", "_server", "_sites", "_shutdown_timeout")

def __init__(
self,
Expand All @@ -246,7 +239,6 @@ def __init__(
shutdown_timeout: float = 60.0,
**kwargs: Any,
) -> None:
self.shutdown_callback: Optional[Callable[[], Awaitable[None]]] = None
self._handle_signals = handle_signals
self._kwargs = kwargs
self._server: Optional[Server] = None
Expand Down Expand Up @@ -304,10 +296,6 @@ async def cleanup(self) -> None:
await asyncio.sleep(0)
self._server.pre_shutdown()
await self.shutdown()

if self.shutdown_callback:
await self.shutdown_callback()

await self._server.shutdown(self._shutdown_timeout)
await self._cleanup_server()

Expand Down
7 changes: 6 additions & 1 deletion aiohttp/web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ def connection_lost(
self, handler: RequestHandler, exc: Optional[BaseException] = None
) -> None:
if handler in self._connections:
del self._connections[handler]
if handler._task_handler:
handler._task_handler.add_done_callback(
lambda f: self._connections.pop(handler, None)
)
else:
del self._connections[handler]

def _make_request(
self,
Expand Down
43 changes: 10 additions & 33 deletions tests/test_run_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import pytest

from aiohttp import ClientConnectorError, ClientSession, WSCloseCode, web
from aiohttp import ClientConnectorError, ClientSession, ClientTimeout, WSCloseCode, web
from aiohttp.test_utils import make_mocked_coro
from aiohttp.web_runner import BaseRunner

Expand Down Expand Up @@ -935,8 +935,12 @@ async def test() -> None:
async with ClientSession() as sess:
for _ in range(5): # pragma: no cover
try:
async with sess.get(f"http://localhost:{port}/"):
pass
with pytest.raises(asyncio.TimeoutError):
async with sess.get(
f"http://localhost:{port}/",
timeout=ClientTimeout(total=0.1),
):
pass
except ClientConnectorError:
await asyncio.sleep(0.5)
else:
Expand All @@ -956,6 +960,7 @@ async def run_test(app: web.Application) -> None:
async def handler(request: web.Request) -> web.Response:
nonlocal t
t = asyncio.create_task(task())
await t
Dismissed Show dismissed Hide dismissed
return web.Response(text="FOO")

t = test_task = None
Expand All @@ -968,7 +973,7 @@ async def handler(request: web.Request) -> web.Response:
assert test_task.exception() is None
return t

def test_shutdown_wait_for_task(
def test_shutdown_wait_for_handler(
self, aiohttp_unused_port: Callable[[], int]
) -> None:
port = aiohttp_unused_port()
Expand All @@ -985,7 +990,7 @@ async def task():
assert t.done()
assert not t.cancelled()

def test_shutdown_timeout_task(
def test_shutdown_timeout_handler(
self, aiohttp_unused_port: Callable[[], int]
) -> None:
port = aiohttp_unused_port()
Expand All @@ -1002,34 +1007,6 @@ async def task():
assert t.done()
assert t.cancelled()

def test_shutdown_wait_for_spawned_task(
self, aiohttp_unused_port: Callable[[], int]
) -> None:
port = aiohttp_unused_port()
finished = False
finished_sub = False
sub_t = None

async def sub_task():
nonlocal finished_sub
await asyncio.sleep(1.5)
finished_sub = True

async def task():
nonlocal finished, sub_t
await asyncio.sleep(0.5)
sub_t = asyncio.create_task(sub_task())
finished = True

t = self.run_app(port, 3, task)

assert finished is True
assert t.done()
assert not t.cancelled()
assert finished_sub is True
assert sub_t.done()
assert not sub_t.cancelled()

def test_shutdown_timeout_not_reached(
self, aiohttp_unused_port: Callable[[], int]
) -> None:
Expand Down
6 changes: 4 additions & 2 deletions tests/test_web_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ async def test_connections() -> None:
manager = web.Server(serve)
assert manager.connections == []

handler = object()
handler = mock.Mock(spec_set=web.RequestHandler)
handler._task_handler = None
transport = object()
manager.connection_made(handler, transport) # type: ignore[arg-type]
assert manager.connections == [handler]
Expand All @@ -34,7 +35,8 @@ async def test_connections() -> None:
async def test_shutdown_no_timeout() -> None:
manager = web.Server(serve)

handler = mock.Mock()
handler = mock.Mock(spec_set=web.RequestHandler)
handler._task_handler = None
handler.shutdown = make_mocked_coro(mock.Mock())
transport = mock.Mock()
manager.connection_made(handler, transport)
Expand Down
Loading