Skip to content

Commit

Permalink
Ignore the cancellation error on memory stream receive
Browse files Browse the repository at this point in the history
As per the discussion on #147, it's better to ignore the cancellation exception now and have it triggered at the next checkpoint than to push the item to the buffer, potentially going over the buffer's limit.
  • Loading branch information
agronholm committed Aug 17, 2020
1 parent b5a2f08 commit a3af1da
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 64 deletions.
49 changes: 16 additions & 33 deletions src/anyio/streams/memory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import deque, OrderedDict
from dataclasses import dataclass, field
from typing import TypeVar, Generic, List, Deque, Tuple
from typing import TypeVar, Generic, List, Deque

from .. import get_cancelled_exc_class
from .._core._lowlevel import checkpoint
Expand All @@ -18,7 +18,8 @@ class MemoryObjectStreamState(Generic[T_Item]):
buffer: Deque[T_Item] = field(init=False, default_factory=deque)
open_send_channels: int = field(init=False, default=0)
open_receive_channels: int = field(init=False, default=0)
waiting_receivers: Deque[Tuple[Event, List[T_Item]]] = field(init=False, default_factory=deque)
waiting_receivers: 'OrderedDict[Event, List[T_Item]]' = field(init=False,
default_factory=OrderedDict)
waiting_senders: 'OrderedDict[Event, T_Item]' = field(init=False, default_factory=OrderedDict)


Expand Down Expand Up @@ -66,34 +67,17 @@ async def receive(self) -> T_Item:
# Add ourselves in the queue
receive_event = create_event()
container: List[T_Item] = []
ticket = receive_event, container
self._state.waiting_receivers.append(ticket)
self._state.waiting_receivers[receive_event] = container

try:
await receive_event.wait()
except get_cancelled_exc_class():
# If we already received an item in the container, pass it to the next receiver in
# line
index = self._state.waiting_receivers.index(ticket) + 1
if container:
item = container[0]
while index < len(self._state.waiting_receivers):
receive_event, container = self._state.waiting_receivers[index]
if container:
item, container[0] = container[0], item
else:
# Found an untriggered receiver
container.append(item)
await receive_event.set()
break
else:
# Could not find an untriggered receiver, so in order to not lose any
# items, put it in the buffer, even if it exceeds the maximum buffer size
self._state.buffer.append(item)

raise
# Ignore the immediate cancellation if we already received an item, so as not to
# lose it
if not container:
raise
finally:
self._state.waiting_receivers.remove(ticket)
self._state.waiting_receivers.pop(receive_event, None)

if container:
return container[0]
Expand Down Expand Up @@ -151,13 +135,11 @@ async def send_nowait(self, item: T_Item) -> None:
if not self._state.open_receive_channels:
raise BrokenResourceError

for receive_event, container in self._state.waiting_receivers:
if not container:
container.append(item)
await receive_event.set()
return

if len(self._state.buffer) < self._state.max_buffer_size:
if self._state.waiting_receivers:
receive_event, container = self._state.waiting_receivers.popitem(last=False)
container.append(item)
await receive_event.set()
elif len(self._state.buffer) < self._state.max_buffer_size:
self._state.buffer.append(item)
else:
raise WouldBlock
Expand Down Expand Up @@ -199,6 +181,7 @@ async def aclose(self) -> None:
self._closed = True
self._state.open_send_channels -= 1
if self._state.open_send_channels == 0:
receive_events = [event for event, container in self._state.waiting_receivers]
receive_events = list(self._state.waiting_receivers.keys())
self._state.waiting_receivers.clear()
for event in receive_events:
await event.set()
34 changes: 3 additions & 31 deletions tests/streams/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,45 +231,17 @@ async def test_cancel_during_receive():
async def scoped_receiver():
nonlocal receiver_scope
async with open_cancel_scope() as receiver_scope:
await receive.receive()
received.append(await receive.receive())

async def receiver():
received.append(await receive.receive())
assert receiver_scope.cancel_called

receiver_scope = None
received = []
send, receive = create_memory_object_stream()
async with create_task_group() as tg:
await tg.spawn(scoped_receiver)
await wait_all_tasks_blocked()
await tg.spawn(receiver)
await send.send_nowait('hello')
await receiver_scope.cancel()
await send.send('hello')

assert received == ['hello']


async def test_cancel_during_receive_last_receiver():
"""
Test that cancelling a pending receive() operation does not cause an item in the stream to be
lost, even if there are no other receivers waiting.
"""
async def scoped_receiver():
nonlocal receiver_scope
async with open_cancel_scope() as receiver_scope:
await receive.receive()
pytest.fail('This point should never be reached')

receiver_scope = None
send, receive = create_memory_object_stream()
async with create_task_group() as tg:
await tg.spawn(scoped_receiver)
await wait_all_tasks_blocked()
await receiver_scope.cancel()
await send.send_nowait('hello')

with pytest.raises(WouldBlock):
await send.send_nowait('world')

assert await receive.receive_nowait() == 'hello'

0 comments on commit a3af1da

Please sign in to comment.