From 85480eb699d6a332e604f071117d30a01409a04c Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sat, 13 Jul 2024 16:00:14 +0100 Subject: [PATCH 1/5] Only wait on handlers --- aiohttp/web.py | 23 ----------------------- aiohttp/web_protocol.py | 8 ++++++-- aiohttp/web_runner.py | 14 +------------- aiohttp/web_server.py | 6 +++++- tests/test_run_app.py | 40 +++++++--------------------------------- 5 files changed, 19 insertions(+), 72 deletions(-) diff --git a/aiohttp/web.py b/aiohttp/web.py index 1a30dd87775..93594f6a902 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -300,23 +300,6 @@ async def _run_app( reuse_port: Optional[bool] = None, handler_cancellation: bool = False, ) -> None: - async def wait( - starting_tasks: "WeakSet[asyncio.Task[object]]", shutdown_timeout: float - ) -> None: - # Wait for pending tasks for a given time limit. - t = asyncio.current_task() - assert t is not None - starting_tasks.add(t) - with suppress(asyncio.TimeoutError): - await asyncio.wait_for(_wait(starting_tasks), timeout=shutdown_timeout) - - async def _wait(exclude: "WeakSet[asyncio.Task[object]]") -> None: - t = asyncio.current_task() - assert t is not None - exclude.add(t) - while tasks := asyncio.all_tasks().difference(exclude): - await asyncio.wait(tasks) - # An internal function to actually do all dirty job for application running if asyncio.iscoroutine(app): app = await app @@ -335,12 +318,6 @@ async def _wait(exclude: "WeakSet[asyncio.Task[object]]") -> None: ) await runner.setup() - # On shutdown we want to avoid waiting on tasks which run forever. - # It's very likely that all tasks which run forever will have been created by - # the time we have completed the application startup (in runner.setup()), - # so we just record all running tasks here and exclude them later. - starting_tasks: "WeakSet[asyncio.Task[object]]" = WeakSet(asyncio.all_tasks()) - runner.shutdown_callback = partial(wait, starting_tasks, shutdown_timeout) sites: List[BaseSite] = [] diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index db15958e88d..1b4e7e66cd8 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -278,7 +278,12 @@ async def shutdown(self, timeout: Optional[float] = 15.0) -> None: if self._waiter: self._waiter.cancel() - # wait for handlers + # Wait for graceful disconnection + if self._current_request is not None: + with suppress(asyncio.CancelledError, asyncio.TimeoutError): + async with ceil_timeout(timeout): + await self._current_request.wait_for_disconnection() + # Then cancel handler and wait with suppress(asyncio.CancelledError, asyncio.TimeoutError): async with ceil_timeout(timeout): if self._current_request is not None: @@ -461,7 +466,6 @@ async def _handle_request( start_time: float, request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]], ) -> Tuple[StreamResponse, bool]: - assert self._request_handler is not None try: try: self._current_request = request diff --git a/aiohttp/web_runner.py b/aiohttp/web_runner.py index 2618875f6bd..b9d13cb5dce 100644 --- a/aiohttp/web_runner.py +++ b/aiohttp/web_runner.py @@ -230,14 +230,7 @@ async def start(self) -> None: class BaseRunner(ABC): - __slots__ = ( - "shutdown_callback", - "_handle_signals", - "_kwargs", - "_server", - "_sites", - "_shutdown_timeout", - ) + __slots__ = ("_handle_signals", "_kwargs", "_server", "_sites", "_shutdown_timeout") def __init__( self, @@ -246,7 +239,6 @@ def __init__( shutdown_timeout: float = 60.0, **kwargs: Any, ) -> None: - self.shutdown_callback: Optional[Callable[[], Awaitable[None]]] = None self._handle_signals = handle_signals self._kwargs = kwargs self._server: Optional[Server] = None @@ -304,10 +296,6 @@ async def cleanup(self) -> None: await asyncio.sleep(0) self._server.pre_shutdown() await self.shutdown() - - if self.shutdown_callback: - await self.shutdown_callback() - await self._server.shutdown(self._shutdown_timeout) await self._cleanup_server() diff --git a/aiohttp/web_server.py b/aiohttp/web_server.py index 9d317bb12e1..a55d399c2bc 100644 --- a/aiohttp/web_server.py +++ b/aiohttp/web_server.py @@ -50,7 +50,10 @@ def connection_lost( self, handler: RequestHandler, exc: Optional[BaseException] = None ) -> None: if handler in self._connections: - del self._connections[handler] + if handler._task_handler: + handler._task_handler.add_done_callback(lambda f: self._connections.pop(handler, None)) + else: + del self._connections[handler] def _make_request( self, @@ -69,6 +72,7 @@ def pre_shutdown(self) -> None: async def shutdown(self, timeout: Optional[float] = None) -> None: coros = (conn.shutdown(timeout) for conn in self._connections) await asyncio.gather(*coros) + print("LENGTH", len(self._connections)) self._connections.clear() def __call__(self) -> RequestHandler: diff --git a/tests/test_run_app.py b/tests/test_run_app.py index b53637ad436..b5b0da1728d 100644 --- a/tests/test_run_app.py +++ b/tests/test_run_app.py @@ -16,7 +16,7 @@ import pytest -from aiohttp import ClientConnectorError, ClientSession, WSCloseCode, web +from aiohttp import ClientConnectorError, ClientSession, ClientTimeout, WSCloseCode, web from aiohttp.test_utils import make_mocked_coro from aiohttp.web_runner import BaseRunner @@ -935,8 +935,9 @@ async def test() -> None: async with ClientSession() as sess: for _ in range(5): # pragma: no cover try: - async with sess.get(f"http://localhost:{port}/"): - pass + with pytest.raises(asyncio.TimeoutError): + async with sess.get(f"http://localhost:{port}/", timeout=ClientTimeout(total=0.1)): + pass except ClientConnectorError: await asyncio.sleep(0.5) else: @@ -956,6 +957,7 @@ async def run_test(app: web.Application) -> None: async def handler(request: web.Request) -> web.Response: nonlocal t t = asyncio.create_task(task()) + await t return web.Response(text="FOO") t = test_task = None @@ -968,7 +970,7 @@ async def handler(request: web.Request) -> web.Response: assert test_task.exception() is None return t - def test_shutdown_wait_for_task( + def test_shutdown_wait_for_handler( self, aiohttp_unused_port: Callable[[], int] ) -> None: port = aiohttp_unused_port() @@ -985,7 +987,7 @@ async def task(): assert t.done() assert not t.cancelled() - def test_shutdown_timeout_task( + def test_shutdown_timeout_handler( self, aiohttp_unused_port: Callable[[], int] ) -> None: port = aiohttp_unused_port() @@ -1002,34 +1004,6 @@ async def task(): assert t.done() assert t.cancelled() - def test_shutdown_wait_for_spawned_task( - self, aiohttp_unused_port: Callable[[], int] - ) -> None: - port = aiohttp_unused_port() - finished = False - finished_sub = False - sub_t = None - - async def sub_task(): - nonlocal finished_sub - await asyncio.sleep(1.5) - finished_sub = True - - async def task(): - nonlocal finished, sub_t - await asyncio.sleep(0.5) - sub_t = asyncio.create_task(sub_task()) - finished = True - - t = self.run_app(port, 3, task) - - assert finished is True - assert t.done() - assert not t.cancelled() - assert finished_sub is True - assert sub_t.done() - assert not sub_t.cancelled() - def test_shutdown_timeout_not_reached( self, aiohttp_unused_port: Callable[[], int] ) -> None: From 1457ba5c147f86dc0eaeeb1f4929479d802ff2d4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 13 Jul 2024 15:07:45 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- aiohttp/web_server.py | 4 +++- tests/test_run_app.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/aiohttp/web_server.py b/aiohttp/web_server.py index a55d399c2bc..3b90f0bd239 100644 --- a/aiohttp/web_server.py +++ b/aiohttp/web_server.py @@ -51,7 +51,9 @@ def connection_lost( ) -> None: if handler in self._connections: if handler._task_handler: - handler._task_handler.add_done_callback(lambda f: self._connections.pop(handler, None)) + handler._task_handler.add_done_callback( + lambda f: self._connections.pop(handler, None) + ) else: del self._connections[handler] diff --git a/tests/test_run_app.py b/tests/test_run_app.py index b5b0da1728d..1c3ba0a6dd5 100644 --- a/tests/test_run_app.py +++ b/tests/test_run_app.py @@ -936,7 +936,10 @@ async def test() -> None: for _ in range(5): # pragma: no cover try: with pytest.raises(asyncio.TimeoutError): - async with sess.get(f"http://localhost:{port}/", timeout=ClientTimeout(total=0.1)): + async with sess.get( + f"http://localhost:{port}/", + timeout=ClientTimeout(total=0.1), + ): pass except ClientConnectorError: await asyncio.sleep(0.5) From ad4ec817d76508d6965578d6b04eb276f778fd1d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 21 Jul 2024 10:24:48 -0500 Subject: [PATCH 3/5] lint --- aiohttp/web.py | 3 --- aiohttp/web_runner.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/aiohttp/web.py b/aiohttp/web.py index 93594f6a902..68b29c79d0b 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -6,8 +6,6 @@ import warnings from argparse import ArgumentParser from collections.abc import Iterable -from contextlib import suppress -from functools import partial from importlib import import_module from typing import ( Any, @@ -21,7 +19,6 @@ Union, cast, ) -from weakref import WeakSet from .abc import AbstractAccessLogger from .helpers import AppKey diff --git a/aiohttp/web_runner.py b/aiohttp/web_runner.py index b9d13cb5dce..f507be60341 100644 --- a/aiohttp/web_runner.py +++ b/aiohttp/web_runner.py @@ -2,7 +2,7 @@ import signal import socket from abc import ABC, abstractmethod -from typing import Any, Awaitable, Callable, List, Optional, Set, Type +from typing import Any, List, Optional, Set, Type from yarl import URL From 5f6b84300782b9bffbf6476653df50b110e013cb Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 22 Jul 2024 09:20:49 -0500 Subject: [PATCH 4/5] small cleanups to make mergable for beta --- aiohttp/web_server.py | 1 - tests/test_web_request_handler.py | 6 ++++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/aiohttp/web_server.py b/aiohttp/web_server.py index 3b90f0bd239..f7dc971c6e1 100644 --- a/aiohttp/web_server.py +++ b/aiohttp/web_server.py @@ -74,7 +74,6 @@ def pre_shutdown(self) -> None: async def shutdown(self, timeout: Optional[float] = None) -> None: coros = (conn.shutdown(timeout) for conn in self._connections) await asyncio.gather(*coros) - print("LENGTH", len(self._connections)) self._connections.clear() def __call__(self) -> RequestHandler: diff --git a/tests/test_web_request_handler.py b/tests/test_web_request_handler.py index 06f99be76c0..e09ea0c5a96 100644 --- a/tests/test_web_request_handler.py +++ b/tests/test_web_request_handler.py @@ -22,7 +22,8 @@ async def test_connections() -> None: manager = web.Server(serve) assert manager.connections == [] - handler = object() + handler = mock.Mock(spec_set=web.RequestHandler) + handler._task_handler = None transport = object() manager.connection_made(handler, transport) # type: ignore[arg-type] assert manager.connections == [handler] @@ -34,7 +35,8 @@ async def test_connections() -> None: async def test_shutdown_no_timeout() -> None: manager = web.Server(serve) - handler = mock.Mock() + handler = mock.Mock(spec_set=web.RequestHandler) + handler._task_handler = None handler.shutdown = make_mocked_coro(mock.Mock()) transport = mock.Mock() manager.connection_made(handler, transport) From 7f303b4cfe087249403389391ade3a574c428e1a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 22 Jul 2024 09:23:38 -0500 Subject: [PATCH 5/5] small cleanups to make mergable for beta --- tests/test_web_request_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_web_request_handler.py b/tests/test_web_request_handler.py index e09ea0c5a96..4837cab030e 100644 --- a/tests/test_web_request_handler.py +++ b/tests/test_web_request_handler.py @@ -28,7 +28,7 @@ async def test_connections() -> None: manager.connection_made(handler, transport) # type: ignore[arg-type] assert manager.connections == [handler] - manager.connection_lost(handler, None) # type: ignore[arg-type] + manager.connection_lost(handler, None) assert manager.connections == []