Skip to content

Commit

Permalink
Fix Event.wait() raising cancelled on asyncio when set() before scope…
Browse files Browse the repository at this point in the history
… cancelled

Fixes agronholm#536.
  • Loading branch information
gschaffner committed Mar 7, 2023
1 parent bfdc46a commit bcdedc4
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 15 deletions.
5 changes: 5 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
the event loop to be closed
- Fixed ``current_effective_deadline()`` not returning ``-inf`` on asyncio when the
currently active cancel scope has been cancelled (PR by Ganden Schaffner)
- Fixed ``Event.set()`` failing to notify a waiter on asyncio if an ``Event.wait()``'s
scope was cancelled after ``Event.set()`` but before the the scheduler resumed the
waiting task. This also fixed a race condition where ``MemoryObjectSendStream.send()``
could raise a ``CancelledError`` on asyncio after successfully delivering an item to a
receiver (PR by Ganden Schaffner)

**3.6.1**

Expand Down
68 changes: 65 additions & 3 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,42 @@
from ..lowlevel import RunVar
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

if sys.version_info < (3, 11):
if sys.version_info >= (3, 11):

def cancelling(task: asyncio.Task) -> bool:
"""
Return ``True`` if the task is cancelling.
NOTE: If the task finished cancelling and is now done, this function can return
anything. This is because on Python >= 3.11, the task can be uncancelled after
it finishes. (One might think we could avoid this by instead returning
``bool(task.cancelling()) or task.cancelled()``, but on Python < 3.8
``task.cancelled()`` can be ``False`` when it should be ``True``. On Python <
3.8 it appears to be impossible to determine whether a done task was cancelled
or not (see https://github.com/python/cpython/pull/16330).)
"""
return bool(task.cancelling())

else:

def cancelling(task: asyncio.Task) -> bool:
if task.cancelled():
return True

if task._must_cancel: # type: ignore[attr-defined]
return True

waiter = task._fut_waiter # type: ignore[attr-defined]
if waiter is None:
return False
if waiter.cancelled():
return True
elif isinstance(waiter, asyncio.Task):
return cancelling(waiter)
else:
return False

from exceptiongroup import BaseExceptionGroup, ExceptionGroup

if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -1474,16 +1509,43 @@ def __new__(cls) -> Event:

def __init__(self) -> None:
self._event = asyncio.Event()
self._waiter_cancelling_when_set: dict[asyncio.Task, bool | None] = {}

def set(self) -> None:
self._event.set()
if not self._event.is_set():
self._event.set()
for waiter in tuple(self._waiter_cancelling_when_set):
self._waiter_cancelling_when_set[waiter] = cancelling(waiter)

def is_set(self) -> bool:
return self._event.is_set()

async def wait(self) -> None:
if await self._event.wait():
if self._event.is_set():
await AsyncIOBackend.checkpoint()
else:
waiter = cast(asyncio.Task, current_task())
self._waiter_cancelling_when_set[waiter] = None
try:
if await self._event.wait():
await AsyncIOBackend.checkpoint()
except CancelledError:
if not self._event.is_set():
raise
else:
# If we are here, then the event was not set before `wait()`. Then,
# in either order:
#
# * the event got set.
# * the current cancel scope was cancelled.
#
# To match trio, `Event.wait()` must raise a cancellation exception
# now if and only if the current scope was cancelled *before* the
# event was set.
if self._waiter_cancelling_when_set[waiter]:
raise
finally:
del self._waiter_cancelling_when_set[waiter]

def statistics(self) -> EventStatistics:
return EventStatistics(len(self._event._waiters)) # type: ignore[attr-defined]
Expand Down
13 changes: 1 addition & 12 deletions src/anyio/streams/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,7 @@
from types import TracebackType
from typing import Generic, NamedTuple, TypeVar

from .. import (
BrokenResourceError,
ClosedResourceError,
EndOfStream,
WouldBlock,
get_cancelled_exc_class,
)
from .. import BrokenResourceError, ClosedResourceError, EndOfStream, WouldBlock
from ..abc import Event, ObjectReceiveStream, ObjectSendStream
from ..lowlevel import checkpoint

Expand Down Expand Up @@ -104,11 +98,6 @@ async def receive(self) -> T_co:

try:
await receive_event.wait()
except get_cancelled_exc_class():
# Ignore the immediate cancellation if we already received an item, so
# as not to lose it
if not container:
raise
finally:
self._state.waiting_receivers.pop(receive_event, None)

Expand Down
Empty file added tests/_backends/__init__.py
Empty file.
49 changes: 49 additions & 0 deletions tests/_backends/test_asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

import asyncio
from typing import Any, cast

import pytest
from _pytest.fixtures import getfixturemarker

from anyio import create_task_group
from anyio._backends._asyncio import cancelling
from anyio.abc import TaskStatus
from anyio.lowlevel import cancel_shielded_checkpoint, checkpoint
from anyio.pytest_plugin import anyio_backend_name

try:
from .conftest import anyio_backend as parent_anyio_backend
except ImportError:
from ..conftest import anyio_backend as parent_anyio_backend

pytestmark = pytest.mark.anyio

# Use the inherited anyio_backend, but filter out non-asyncio
anyio_backend = pytest.fixture(
params=[
param
for param in cast(Any, getfixturemarker(parent_anyio_backend)).params
if any(
"asyncio"
in anyio_backend_name.__wrapped__(backend) # type: ignore[attr-defined]
for backend in param.values
)
]
)(parent_anyio_backend.__wrapped__)


async def test_cancelling() -> None:
async def func(*, task_status: TaskStatus[asyncio.Task]) -> None:
task = cast(asyncio.Task, asyncio.current_task())
task_status.started(task)
try:
await checkpoint()
finally:
await cancel_shielded_checkpoint()

async with create_task_group() as tg:
task = cast(asyncio.Task, await tg.start(func))
assert not cancelling(task)
tg.cancel_scope.cancel()
assert cancelling(task)

0 comments on commit bcdedc4

Please sign in to comment.