From 98028f5eaf18072889a51cea2e40c6dcd5413745 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sun, 16 Aug 2020 13:44:03 +0300 Subject: [PATCH 1/2] Renamed check_cancelled() to checkpoint() and exposed it internally --- src/anyio/_backends/_asyncio.py | 36 ++++++++++++++++----------------- src/anyio/_backends/_curio.py | 36 ++++++++++++++++----------------- src/anyio/_backends/_trio.py | 1 + src/anyio/_core/_lowlevel.py | 6 ++++++ 4 files changed, 43 insertions(+), 36 deletions(-) create mode 100644 src/anyio/_core/_lowlevel.py diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 58272769..e0790c1c 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -165,7 +165,7 @@ async def wrapper(): # async def sleep(delay: float) -> None: - await check_cancelled() + await checkpoint() await asyncio.sleep(delay) @@ -310,7 +310,7 @@ def shield(self) -> bool: return self._shield -async def check_cancelled(): +async def checkpoint(): try: cancel_scope = _task_states[current_task()].cancel_scope except KeyError: @@ -509,7 +509,7 @@ def thread_worker(): if not cancelled: loop.call_soon_threadsafe(queue.put_nowait, (result, None)) - await check_cancelled() + await checkpoint() loop = get_running_loop() task = current_task() queue: asyncio.Queue[_Retval_Queue_Type] = asyncio.Queue(1) @@ -630,7 +630,7 @@ def stderr(self) -> Optional[abc.ByteReceiveStream]: async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr: int): - await check_cancelled() + await checkpoint() if shell: process = await asyncio.create_subprocess_shell(command, stdin=stdin, stdout=stdout, stderr=stderr) @@ -725,7 +725,7 @@ def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol): async def receive(self, max_bytes: int = 65536) -> bytes: with self._receive_guard: - await check_cancelled() + await checkpoint() if not self._protocol.read_queue and not self._transport.is_closing(): self._protocol.read_event.clear() self._transport.resume_reading() @@ -751,7 +751,7 @@ async def receive(self, max_bytes: int = 65536) -> bytes: async def send(self, item: bytes) -> None: with self._send_guard: - await check_cancelled() + await checkpoint() try: self._transport.write(item) except RuntimeError as exc: @@ -826,7 +826,7 @@ def local_address(self) -> SockAddrType: async def accept(self) -> abc.SocketStream: with self._accept_guard: - await check_cancelled() + await checkpoint() try: client_sock, _addr = await self._loop.sock_accept(self._raw_socket) except asyncio.CancelledError: @@ -880,7 +880,7 @@ def setsockopt(self, level, optname, value, *args) -> None: async def receive(self) -> Tuple[bytes, IPSockAddrType]: with self._receive_guard: - await check_cancelled() + await checkpoint() # If the buffer is empty, ask for more data if not self._protocol.read_queue and not self._transport.is_closing(): @@ -897,7 +897,7 @@ async def receive(self) -> Tuple[bytes, IPSockAddrType]: async def send(self, item: UDPPacketType) -> None: with self._send_guard: - await check_cancelled() + await checkpoint() await self._protocol.write_event.wait() if self._closed: raise ClosedResourceError @@ -943,7 +943,7 @@ def setsockopt(self, level, optname, value, *args) -> None: async def receive(self) -> bytes: with self._receive_guard: - await check_cancelled() + await checkpoint() # If the buffer is empty, ask for more data if not self._protocol.read_queue and not self._transport.is_closing(): @@ -962,7 +962,7 @@ async def receive(self) -> bytes: async def send(self, item: bytes) -> None: with self._send_guard: - await check_cancelled() + await checkpoint() await self._protocol.write_event.wait() if self._closed: raise ClosedResourceError @@ -1029,7 +1029,7 @@ async def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Tuple[str, st async def wait_socket_readable(sock: socket.SocketType) -> None: - await check_cancelled() + await checkpoint() if _read_events.get(sock): raise BusyResourceError('reading from') from None @@ -1050,7 +1050,7 @@ async def wait_socket_readable(sock: socket.SocketType) -> None: async def wait_socket_writable(sock: socket.SocketType) -> None: - await check_cancelled() + await checkpoint() if _write_events.get(sock): raise BusyResourceError('writing to') from None @@ -1082,7 +1082,7 @@ def locked(self) -> bool: return self._lock.locked() async def acquire(self) -> None: - await check_cancelled() + await checkpoint() await self._lock.acquire() async def release(self) -> None: @@ -1095,7 +1095,7 @@ def __init__(self, lock: Optional[Lock]): self._condition = asyncio.Condition(asyncio_lock) async def acquire(self) -> None: - await check_cancelled() + await checkpoint() await self._condition.acquire() async def release(self) -> None: @@ -1111,7 +1111,7 @@ async def notify_all(self): self._condition.notify_all() async def wait(self): - await check_cancelled() + await checkpoint() return await self._condition.wait() @@ -1126,7 +1126,7 @@ def is_set(self) -> bool: return self._event.is_set() async def wait(self): - await check_cancelled() + await checkpoint() await self._event.wait() @@ -1135,7 +1135,7 @@ def __init__(self, value: int): self._semaphore = asyncio.Semaphore(value) async def acquire(self) -> None: - await check_cancelled() + await checkpoint() await self._semaphore.acquire() async def release(self) -> None: diff --git a/src/anyio/_backends/_curio.py b/src/anyio/_backends/_curio.py index 68daf545..9537e740 100644 --- a/src/anyio/_backends/_curio.py +++ b/src/anyio/_backends/_curio.py @@ -72,7 +72,7 @@ async def wrapper(): # async def sleep(delay: float): - await check_cancelled() + await checkpoint() await curio.sleep(delay) @@ -215,7 +215,7 @@ def shield(self) -> bool: return self._shield -async def check_cancelled(): +async def checkpoint(): try: cancel_scope = _task_states[await curio.current_task()].cancel_scope except KeyError: @@ -421,7 +421,7 @@ def thread_worker(): if not helper_task.cancelled: queue.put(None) - await check_cancelled() + await checkpoint() task = await curio.current_task() queue = curio.UniversalQueue(maxsize=1) finish_event = curio.Event() @@ -554,7 +554,7 @@ def stderr(self) -> Optional[abc.ByteReceiveStream]: async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr: int): - await check_cancelled() + await checkpoint() process = curio.subprocess.Popen(command, stdin=stdin, stdout=stdout, stderr=stderr, shell=shell) stdin_stream = FileStreamWrapper(process.stdin) if process.stdin else None @@ -634,7 +634,7 @@ def remote_address(self) -> Union[IPSockAddrType, str]: async def receive(self, max_bytes: int = 65536) -> bytes: with self._receive_guard: - await check_cancelled() + await checkpoint() try: data = await self._curio_socket.recv(max_bytes) except (OSError, AttributeError) as exc: @@ -647,7 +647,7 @@ async def receive(self, max_bytes: int = 65536) -> bytes: async def send(self, item: bytes) -> None: with self._send_guard: - await check_cancelled() + await checkpoint() try: await self._curio_socket.sendall(item) except (OSError, AttributeError) as exc: @@ -664,7 +664,7 @@ def __init__(self, raw_socket: socket.SocketType): async def accept(self) -> SocketStream: with self._accept_guard: - await check_cancelled() + await checkpoint() try: curio_socket, _addr = await self._curio_socket.accept() except (OSError, AttributeError) as exc: @@ -684,7 +684,7 @@ def __init__(self, curio_socket: curio.io.Socket): async def receive(self) -> Tuple[bytes, IPSockAddrType]: with self._receive_guard: - await check_cancelled() + await checkpoint() try: return await self._curio_socket.recvfrom(65536) except (OSError, AttributeError) as exc: @@ -692,7 +692,7 @@ async def receive(self) -> Tuple[bytes, IPSockAddrType]: async def send(self, item: UDPPacketType) -> None: with self._send_guard: - await check_cancelled() + await checkpoint() try: await self._curio_socket.sendto(*item) except (OSError, AttributeError) as exc: @@ -714,7 +714,7 @@ def remote_address(self) -> IPSockAddrType: async def receive(self) -> bytes: with self._receive_guard: - await check_cancelled() + await checkpoint() try: return await self._curio_socket.recv(65536) except (OSError, AttributeError) as exc: @@ -722,7 +722,7 @@ async def receive(self) -> bytes: async def send(self, item: bytes) -> None: with self._send_guard: - await check_cancelled() + await checkpoint() try: await self._curio_socket.send(item) except (OSError, AttributeError) as exc: @@ -772,7 +772,7 @@ def getaddrinfo(host: Union[bytearray, bytes, str], port: Union[str, int, None], async def wait_socket_readable(sock): - await check_cancelled() + await checkpoint() if _reader_tasks.get(sock): raise BusyResourceError('reading from') from None @@ -789,7 +789,7 @@ async def wait_socket_readable(sock): async def wait_socket_writable(sock): - await check_cancelled() + await checkpoint() if _writer_tasks.get(sock): raise BusyResourceError('writing to') from None @@ -817,7 +817,7 @@ def locked(self) -> bool: return self._lock.locked() async def acquire(self) -> None: - await check_cancelled() + await checkpoint() await self._lock.acquire() async def release(self) -> None: @@ -830,7 +830,7 @@ def __init__(self, lock: Optional[Lock]): self._condition = curio.Condition(curio_lock) async def acquire(self) -> None: - await check_cancelled() + await checkpoint() await self._condition.acquire() async def release(self) -> None: @@ -846,7 +846,7 @@ async def notify_all(self): await self._condition.notify_all() async def wait(self): - await check_cancelled() + await checkpoint() return await self._condition.wait() @@ -861,7 +861,7 @@ def is_set(self) -> bool: return self._event.is_set() async def wait(self): - await check_cancelled() + await checkpoint() return await self._event.wait() @@ -870,7 +870,7 @@ def __init__(self, value: int): self._semaphore = curio.Semaphore(value) async def acquire(self) -> None: - await check_cancelled() + await checkpoint() await self._semaphore.acquire() async def release(self) -> None: diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index 8f3d3a54..d9325794 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -54,6 +54,7 @@ # CancelledError = trio.Cancelled +checkpoint = trio.lowlevel.checkpoint class CancelScope(abc.CancelScope): diff --git a/src/anyio/_core/_lowlevel.py b/src/anyio/_core/_lowlevel.py new file mode 100644 index 00000000..1fa87378 --- /dev/null +++ b/src/anyio/_core/_lowlevel.py @@ -0,0 +1,6 @@ +from anyio._core._eventloop import get_asynclib + + +async def checkpoint(): + """Checks for cancellation and allows the scheduler to switch to another task.""" + await get_asynclib().checkpoint() From 42b72825a4701514c0541373ed170db9b1ac12dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sun, 16 Aug 2020 13:56:06 +0300 Subject: [PATCH 2/2] Pass along the received item to the next receiver if the task was cancelled --- src/anyio/streams/memory.py | 59 ++++++++++++++------- tests/streams/test_memory.py | 99 +++++++++++++++++++++++++++++++++++- 2 files changed, 139 insertions(+), 19 deletions(-) diff --git a/src/anyio/streams/memory.py b/src/anyio/streams/memory.py index 9b8e8ee2..916dfa0a 100644 --- a/src/anyio/streams/memory.py +++ b/src/anyio/streams/memory.py @@ -1,8 +1,10 @@ from collections import deque, OrderedDict from dataclasses import dataclass, field -from typing import TypeVar, Generic, List, Deque +from typing import TypeVar, Generic, List, Deque, Tuple -import anyio +from .. import get_cancelled_exc_class +from .._core._lowlevel import checkpoint +from .._core._synchronization import create_event from ..abc.synchronization import Event from ..abc.streams import ObjectSendStream, ObjectReceiveStream from ..exceptions import ClosedResourceError, BrokenResourceError, WouldBlock, EndOfStream @@ -16,8 +18,7 @@ class MemoryObjectStreamState(Generic[T_Item]): buffer: Deque[T_Item] = field(init=False, default_factory=deque) open_send_channels: int = field(init=False, default=0) open_receive_channels: int = field(init=False, default=0) - waiting_receivers: 'OrderedDict[Event, List[T_Item]]' = field(init=False, - default_factory=OrderedDict) + waiting_receivers: Deque[Tuple[Event, List[T_Item]]] = field(init=False, default_factory=deque) waiting_senders: 'OrderedDict[Event, T_Item]' = field(init=False, default_factory=OrderedDict) @@ -58,20 +59,41 @@ async def receive_nowait(self) -> T_Item: raise WouldBlock async def receive(self) -> T_Item: - # anyio.check_cancelled() + await checkpoint() try: return await self.receive_nowait() except WouldBlock: # Add ourselves in the queue - receive_event = anyio.create_event() + receive_event = create_event() container: List[T_Item] = [] - self._state.waiting_receivers[receive_event] = container + ticket = receive_event, container + self._state.waiting_receivers.append(ticket) try: await receive_event.wait() - except BaseException: - self._state.waiting_receivers.pop(receive_event, None) + except get_cancelled_exc_class(): + # If we already received an item in the container, pass it to the next receiver in + # line + index = self._state.waiting_receivers.index(ticket) + 1 + if container: + item = container[0] + while index < len(self._state.waiting_receivers): + receive_event, container = self._state.waiting_receivers[index] + if container: + item, container[0] = container[0], item + else: + # Found an untriggered receiver + container.append(item) + await receive_event.set() + break + else: + # Could not find an untriggered receiver, so in order to not lose any + # items, put it in the buffer, even if it exceeds the maximum buffer size + self._state.buffer.append(item) + raise + finally: + self._state.waiting_receivers.remove(ticket) if container: return container[0] @@ -129,22 +151,24 @@ async def send_nowait(self, item: T_Item) -> None: if not self._state.open_receive_channels: raise BrokenResourceError - if self._state.waiting_receivers: - receive_event, container = self._state.waiting_receivers.popitem(last=False) - container.append(item) - await receive_event.set() - elif len(self._state.buffer) < self._state.max_buffer_size: + for receive_event, container in self._state.waiting_receivers: + if not container: + container.append(item) + await receive_event.set() + return + + if len(self._state.buffer) < self._state.max_buffer_size: self._state.buffer.append(item) else: raise WouldBlock async def send(self, item: T_Item) -> None: - # await check_cancelled() + await checkpoint() try: await self.send_nowait(item) except WouldBlock: # Wait until there's someone on the receiving end - send_event = anyio.create_event() + send_event = create_event() self._state.waiting_senders[send_event] = item try: await send_event.wait() @@ -175,7 +199,6 @@ async def aclose(self) -> None: self._closed = True self._state.open_send_channels -= 1 if self._state.open_send_channels == 0: - receive_events = list(self._state.waiting_receivers.keys()) - self._state.waiting_receivers.clear() + receive_events = [event for event, container in self._state.waiting_receivers] for event in receive_events: await event.set() diff --git a/tests/streams/test_memory.py b/tests/streams/test_memory.py index 61778341..4b95c351 100644 --- a/tests/streams/test_memory.py +++ b/tests/streams/test_memory.py @@ -1,7 +1,8 @@ import pytest from anyio import ( - create_task_group, wait_all_tasks_blocked, create_memory_object_stream, fail_after) + create_task_group, wait_all_tasks_blocked, create_memory_object_stream, fail_after, + open_cancel_scope) from anyio.exceptions import EndOfStream, ClosedResourceError, BrokenResourceError, WouldBlock pytestmark = pytest.mark.anyio @@ -177,3 +178,99 @@ async def test_receive_after_send_closed(): await send.send('hello') await send.aclose() assert await receive.receive() == 'hello' + + +async def test_receive_when_cancelled(): + """ + Test that calling receive() in a cancelled scope prevents it from going through with the + operation. + + """ + send, receive = create_memory_object_stream() + async with create_task_group() as tg: + await tg.spawn(send.send, 'hello') + await wait_all_tasks_blocked() + await tg.spawn(send.send, 'world') + await wait_all_tasks_blocked() + + async with open_cancel_scope() as scope: + await scope.cancel() + await receive.receive() + + assert await receive.receive() == 'hello' + assert await receive.receive() == 'world' + + +async def test_send_when_cancelled(): + """ + Test that calling send() in a cancelled scope prevents it from going through with the + operation. + + """ + async def receiver(): + received.append(await receive.receive()) + + received = [] + send, receive = create_memory_object_stream() + async with create_task_group() as tg: + await tg.spawn(receiver) + async with open_cancel_scope() as scope: + await scope.cancel() + await send.send('hello') + + await send.send('world') + + assert received == ['world'] + + +async def test_cancel_during_receive(): + """ + Test that cancelling a pending receive() operation does not cause an item in the stream to be + lost. + + """ + async def scoped_receiver(): + nonlocal receiver_scope + async with open_cancel_scope() as receiver_scope: + await receive.receive() + + async def receiver(): + received.append(await receive.receive()) + + receiver_scope = None + received = [] + send, receive = create_memory_object_stream() + async with create_task_group() as tg: + await tg.spawn(scoped_receiver) + await wait_all_tasks_blocked() + await tg.spawn(receiver) + await receiver_scope.cancel() + await send.send('hello') + + assert received == ['hello'] + + +async def test_cancel_during_receive_last_receiver(): + """ + Test that cancelling a pending receive() operation does not cause an item in the stream to be + lost, even if there are no other receivers waiting. + + """ + async def scoped_receiver(): + nonlocal receiver_scope + async with open_cancel_scope() as receiver_scope: + await receive.receive() + pytest.fail('This point should never be reached') + + receiver_scope = None + send, receive = create_memory_object_stream() + async with create_task_group() as tg: + await tg.spawn(scoped_receiver) + await wait_all_tasks_blocked() + await receiver_scope.cancel() + await send.send_nowait('hello') + + with pytest.raises(WouldBlock): + await send.send_nowait('world') + + assert await receive.receive_nowait() == 'hello'