diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c1313cf3..f0b604bb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.263 + rev: v0.0.265 hooks: - id: ruff args: [--fix, --show-fixes] diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 2b67bd9f..0013f48f 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -21,6 +21,11 @@ This library adheres to `Semantic Versioning 2.0 `_. Ganden Schaffner) - Several functions and methods that previously only accepted coroutines as the return type of the callable have been amended to accept any awaitables: + - Several functions and methods that were previously annotated as accepting + ``Coroutine[Any, Any, Any]`` as the return type of the callable have been amended to + accept ``Awaitable[Any]`` instead, to allow a slightly broader set of coroutine-like + inputs, like ``async_generator_asend`` objects returned from the ``asend()`` method + of async generators, and to match the ``trio`` annotations: - ``anyio.run()`` - ``anyio.from_thread.run()`` @@ -30,8 +35,11 @@ This library adheres to `Semantic Versioning 2.0 `_. - ``BlockingPortal.start_task_soon()`` - ``BlockingPortal.start_task()`` - - The ``TaskStatus`` class is now generic, and should be parametrized to indicate the - type of the value passed to ``task_status.started()`` + Note that this change involved only changing the type annotations; run-time + functionality was not altered. + + - The ``TaskStatus`` class is now a generic protocol, and should be parametrized to + indicate the type of the value passed to ``task_status.started()`` - The ``Listener`` class is now covariant in its stream type - Object receive streams are now covariant and object send streams are correspondingly contravariant @@ -54,6 +62,15 @@ This library adheres to `Semantic Versioning 2.0 `_. ``TLSStream.wrap()`` being inadvertently set on Python 3.11.3 and 3.10.11 - Fixed ``CancelScope`` to properly handle asyncio task uncancellation on Python 3.11 (PR by Nikolay Bryskin) +- Fixed ``from_thread.run`` and ``from_thread.run_sync`` not setting sniffio on asyncio. + As a result: + + - Fixed ``from_thread.run_sync`` failing when used to call sniffio-dependent functions + on asyncio + - Fixed ``from_thread.run`` failing when used to call sniffio-dependent functions on + asyncio from a thread running trio or curio + - Fixed deadlock when using ``from_thread.start_blocking_portal(backend="asyncio")`` + in a thread running trio or curio (PR by Ganden Schaffner) **3.6.1** diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index cb5fe5b4..e2cd1cbf 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -95,6 +95,7 @@ def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]: T_Retval = TypeVar("T_Retval") +T_contra = TypeVar("T_contra", contravariant=True) # Check whether there is native support for task names in asyncio (3.8+) _native_task_names = hasattr(asyncio.Task, "get_name") @@ -442,7 +443,7 @@ def __init__(self, future: asyncio.Future, parent_id: int): self._future = future self._parent_id = parent_id - def started(self, value: object = None) -> None: + def started(self, value: T_contra | None = None) -> None: try: self._future.set_result(value) except asyncio.InvalidStateError: @@ -2051,8 +2052,10 @@ def run_async_from_thread( token: object, ) -> T_Retval: loop = cast(AbstractEventLoop, token) - f: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe( - func(*args), loop + context = copy_context() + context.run(sniffio.current_async_library_cvar.set, "asyncio") + f: concurrent.futures.Future[T_Retval] = context.run( + asyncio.run_coroutine_threadsafe, func(*args), loop ) return f.result() @@ -2063,6 +2066,7 @@ def run_sync_from_thread( @wraps(func) def wrapper() -> None: try: + sniffio.current_async_library_cvar.set("asyncio") f.set_result(func(*args)) except BaseException as exc: f.set_exception(exc) diff --git a/src/anyio/abc/_tasks.py b/src/anyio/abc/_tasks.py index dad4b0e9..3c8a06bc 100644 --- a/src/anyio/abc/_tasks.py +++ b/src/anyio/abc/_tasks.py @@ -1,19 +1,33 @@ from __future__ import annotations +import sys from abc import ABCMeta, abstractmethod from collections.abc import Awaitable, Callable, Coroutine from types import TracebackType -from typing import TYPE_CHECKING, Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar, overload + +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol if TYPE_CHECKING: from .._core._tasks import CancelScope T_Retval = TypeVar("T_Retval") +T_contra = TypeVar("T_contra", contravariant=True) -class TaskStatus(Generic[T_Retval]): - @abstractmethod - def started(self, value: T_Retval | None = None) -> None: +class TaskStatus(Protocol[T_contra]): + @overload + def started(self: TaskStatus[None]) -> None: + ... + + @overload + def started(self, value: T_contra) -> None: + ... + + def started(self, value: T_contra | None = None) -> None: """ Signal that the task has started. @@ -54,7 +68,7 @@ async def start( func: Callable[..., Awaitable[Any]], *args: object, name: object = None, - ) -> object: + ) -> Any: """ Start a new task and wait until it signals for readiness. diff --git a/tests/test_from_thread.py b/tests/test_from_thread.py index 93886f53..0d0b003d 100644 --- a/tests/test_from_thread.py +++ b/tests/test_from_thread.py @@ -10,12 +10,14 @@ from typing import Any, AsyncGenerator, NoReturn, TypeVar import pytest +import sniffio from _pytest.logging import LogCaptureFixture from anyio import ( Event, create_task_group, from_thread, + get_all_backends, get_cancelled_exc_class, get_current_task, run, @@ -145,6 +147,16 @@ def worker() -> int: assert await to_thread.run_sync(worker) == 6 + async def test_sniffio(self, anyio_backend_name: str) -> None: + async def async_func() -> str: + return sniffio.current_async_library() + + def worker() -> str: + sniffio.current_async_library_cvar.set("something invalid for async_func") + return from_thread.run(async_func) + + assert await to_thread.run_sync(worker) == anyio_backend_name + class TestRunSyncFromThread: def test_run_sync_from_unclaimed_thread(self) -> None: @@ -163,6 +175,13 @@ def worker() -> int: assert await to_thread.run_sync(worker) == 6 + async def test_sniffio(self, anyio_backend_name: str) -> None: + def worker() -> str: + sniffio.current_async_library_cvar.set("something invalid for async_func") + return from_thread.run_sync(sniffio.current_async_library) + + assert await to_thread.run_sync(worker) == anyio_backend_name + class TestBlockingPortal: class AsyncCM: @@ -524,3 +543,21 @@ async def raise_baseexception() -> None: portal.call(raise_baseexception) assert exc.value.__context__ is None + + @pytest.mark.parametrize("portal_backend_name", get_all_backends()) + async def test_from_async( + self, anyio_backend_name: str, portal_backend_name: str + ) -> None: + """ + Test that portals don't deadlock when started/used from async code. + + Note: This test will deadlock if there is a regression. A deadlock should be + treated as a failure. See also + https://github.com/agronholm/anyio/pull/524#discussion_r1183080886. + + """ + if anyio_backend_name == "trio" and portal_backend_name == "trio": + pytest.xfail("known bug (#525)") + + with start_blocking_portal(portal_backend_name) as portal: + portal.call(checkpoint) diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index e9e61e5c..a5fc3c49 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -1126,3 +1126,40 @@ async def test_uncancel_after_scope_and_native_cancel(self) -> None: assert task.cancelling() == 1 task.uncancel() + + +class TestTaskStatusTyping: + """ + These tests do not do anything at run time, but since the test suite is also checked + with a static type checker, it ensures that the `TaskStatus` typing works as + intended. + """ + + async def typetest_None(*, task_status: TaskStatus[None]) -> None: + task_status.started() + task_status.started(None) + + async def typetest_None_Union(*, task_status: TaskStatus[int | None]) -> None: + task_status.started() + task_status.started(None) + + async def typetest_non_None(*, task_status: TaskStatus[int]) -> None: + # We use `type: ignore` and `--warn-unused-ignores` to get type checking errors + # if these ever stop failing. + task_status.started() # type: ignore[call-arg] + task_status.started(None) # type: ignore[arg-type] + + async def typetest_variance_good(*, task_status: TaskStatus[float]) -> None: + task_status2: TaskStatus[int] = task_status + task_status2.started(int()) + + async def typetest_variance_bad(*, task_status: TaskStatus[int]) -> None: + # We use `type: ignore` and `--warn-unused-ignores` to get type checking errors + # if these ever stop failing. + task_status2: TaskStatus[float] = task_status # type: ignore[assignment] + task_status2.started(float()) + + async def typetest_optional_status( + *, task_status: TaskStatus[int] = TASK_STATUS_IGNORED + ) -> None: + task_status.started(1)