Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass along the received item to the next receiver if the task was cancelled #147

Merged
merged 2 commits into from
Aug 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ async def wrapper():
#

async def sleep(delay: float) -> None:
await check_cancelled()
await checkpoint()
await asyncio.sleep(delay)


Expand Down Expand Up @@ -310,7 +310,7 @@ def shield(self) -> bool:
return self._shield


async def check_cancelled():
async def checkpoint():
try:
cancel_scope = _task_states[current_task()].cancel_scope
except KeyError:
Expand Down Expand Up @@ -509,7 +509,7 @@ def thread_worker():
if not cancelled:
loop.call_soon_threadsafe(queue.put_nowait, (result, None))

await check_cancelled()
await checkpoint()
loop = get_running_loop()
task = current_task()
queue: asyncio.Queue[_Retval_Queue_Type] = asyncio.Queue(1)
Expand Down Expand Up @@ -630,7 +630,7 @@ def stderr(self) -> Optional[abc.ByteReceiveStream]:


async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr: int):
await check_cancelled()
await checkpoint()
if shell:
process = await asyncio.create_subprocess_shell(command, stdin=stdin, stdout=stdout,
stderr=stderr)
Expand Down Expand Up @@ -725,7 +725,7 @@ def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):

async def receive(self, max_bytes: int = 65536) -> bytes:
with self._receive_guard:
await check_cancelled()
await checkpoint()
if not self._protocol.read_queue and not self._transport.is_closing():
self._protocol.read_event.clear()
self._transport.resume_reading()
Expand All @@ -751,7 +751,7 @@ async def receive(self, max_bytes: int = 65536) -> bytes:

async def send(self, item: bytes) -> None:
with self._send_guard:
await check_cancelled()
await checkpoint()
try:
self._transport.write(item)
except RuntimeError as exc:
Expand Down Expand Up @@ -826,7 +826,7 @@ def local_address(self) -> SockAddrType:

async def accept(self) -> abc.SocketStream:
with self._accept_guard:
await check_cancelled()
await checkpoint()
try:
client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
except asyncio.CancelledError:
Expand Down Expand Up @@ -880,7 +880,7 @@ def setsockopt(self, level, optname, value, *args) -> None:

async def receive(self) -> Tuple[bytes, IPSockAddrType]:
with self._receive_guard:
await check_cancelled()
await checkpoint()

# If the buffer is empty, ask for more data
if not self._protocol.read_queue and not self._transport.is_closing():
Expand All @@ -897,7 +897,7 @@ async def receive(self) -> Tuple[bytes, IPSockAddrType]:

async def send(self, item: UDPPacketType) -> None:
with self._send_guard:
await check_cancelled()
await checkpoint()
await self._protocol.write_event.wait()
if self._closed:
raise ClosedResourceError
Expand Down Expand Up @@ -943,7 +943,7 @@ def setsockopt(self, level, optname, value, *args) -> None:

async def receive(self) -> bytes:
with self._receive_guard:
await check_cancelled()
await checkpoint()

# If the buffer is empty, ask for more data
if not self._protocol.read_queue and not self._transport.is_closing():
Expand All @@ -962,7 +962,7 @@ async def receive(self) -> bytes:

async def send(self, item: bytes) -> None:
with self._send_guard:
await check_cancelled()
await checkpoint()
await self._protocol.write_event.wait()
if self._closed:
raise ClosedResourceError
Expand Down Expand Up @@ -1029,7 +1029,7 @@ async def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Tuple[str, st


async def wait_socket_readable(sock: socket.SocketType) -> None:
await check_cancelled()
await checkpoint()
if _read_events.get(sock):
raise BusyResourceError('reading from') from None

Expand All @@ -1050,7 +1050,7 @@ async def wait_socket_readable(sock: socket.SocketType) -> None:


async def wait_socket_writable(sock: socket.SocketType) -> None:
await check_cancelled()
await checkpoint()
if _write_events.get(sock):
raise BusyResourceError('writing to') from None

Expand Down Expand Up @@ -1082,7 +1082,7 @@ def locked(self) -> bool:
return self._lock.locked()

async def acquire(self) -> None:
await check_cancelled()
await checkpoint()
await self._lock.acquire()

async def release(self) -> None:
Expand All @@ -1095,7 +1095,7 @@ def __init__(self, lock: Optional[Lock]):
self._condition = asyncio.Condition(asyncio_lock)

async def acquire(self) -> None:
await check_cancelled()
await checkpoint()
await self._condition.acquire()

async def release(self) -> None:
Expand All @@ -1111,7 +1111,7 @@ async def notify_all(self):
self._condition.notify_all()

async def wait(self):
await check_cancelled()
await checkpoint()
return await self._condition.wait()


Expand All @@ -1126,7 +1126,7 @@ def is_set(self) -> bool:
return self._event.is_set()

async def wait(self):
await check_cancelled()
await checkpoint()
await self._event.wait()


Expand All @@ -1135,7 +1135,7 @@ def __init__(self, value: int):
self._semaphore = asyncio.Semaphore(value)

async def acquire(self) -> None:
await check_cancelled()
await checkpoint()
await self._semaphore.acquire()

async def release(self) -> None:
Expand Down
36 changes: 18 additions & 18 deletions src/anyio/_backends/_curio.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def wrapper():
#

async def sleep(delay: float):
await check_cancelled()
await checkpoint()
await curio.sleep(delay)


Expand Down Expand Up @@ -215,7 +215,7 @@ def shield(self) -> bool:
return self._shield


async def check_cancelled():
async def checkpoint():
try:
cancel_scope = _task_states[await curio.current_task()].cancel_scope
except KeyError:
Expand Down Expand Up @@ -421,7 +421,7 @@ def thread_worker():
if not helper_task.cancelled:
queue.put(None)

await check_cancelled()
await checkpoint()
task = await curio.current_task()
queue = curio.UniversalQueue(maxsize=1)
finish_event = curio.Event()
Expand Down Expand Up @@ -554,7 +554,7 @@ def stderr(self) -> Optional[abc.ByteReceiveStream]:


async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr: int):
await check_cancelled()
await checkpoint()
process = curio.subprocess.Popen(command, stdin=stdin, stdout=stdout, stderr=stderr,
shell=shell)
stdin_stream = FileStreamWrapper(process.stdin) if process.stdin else None
Expand Down Expand Up @@ -634,7 +634,7 @@ def remote_address(self) -> Union[IPSockAddrType, str]:

async def receive(self, max_bytes: int = 65536) -> bytes:
with self._receive_guard:
await check_cancelled()
await checkpoint()
try:
data = await self._curio_socket.recv(max_bytes)
except (OSError, AttributeError) as exc:
Expand All @@ -647,7 +647,7 @@ async def receive(self, max_bytes: int = 65536) -> bytes:

async def send(self, item: bytes) -> None:
with self._send_guard:
await check_cancelled()
await checkpoint()
try:
await self._curio_socket.sendall(item)
except (OSError, AttributeError) as exc:
Expand All @@ -664,7 +664,7 @@ def __init__(self, raw_socket: socket.SocketType):

async def accept(self) -> SocketStream:
with self._accept_guard:
await check_cancelled()
await checkpoint()
try:
curio_socket, _addr = await self._curio_socket.accept()
except (OSError, AttributeError) as exc:
Expand All @@ -684,15 +684,15 @@ def __init__(self, curio_socket: curio.io.Socket):

async def receive(self) -> Tuple[bytes, IPSockAddrType]:
with self._receive_guard:
await check_cancelled()
await checkpoint()
try:
return await self._curio_socket.recvfrom(65536)
except (OSError, AttributeError) as exc:
self._convert_socket_error(exc)

async def send(self, item: UDPPacketType) -> None:
with self._send_guard:
await check_cancelled()
await checkpoint()
try:
await self._curio_socket.sendto(*item)
except (OSError, AttributeError) as exc:
Expand All @@ -714,15 +714,15 @@ def remote_address(self) -> IPSockAddrType:

async def receive(self) -> bytes:
with self._receive_guard:
await check_cancelled()
await checkpoint()
try:
return await self._curio_socket.recv(65536)
except (OSError, AttributeError) as exc:
self._convert_socket_error(exc)

async def send(self, item: bytes) -> None:
with self._send_guard:
await check_cancelled()
await checkpoint()
try:
await self._curio_socket.send(item)
except (OSError, AttributeError) as exc:
Expand Down Expand Up @@ -772,7 +772,7 @@ def getaddrinfo(host: Union[bytearray, bytes, str], port: Union[str, int, None],


async def wait_socket_readable(sock):
await check_cancelled()
await checkpoint()
if _reader_tasks.get(sock):
raise BusyResourceError('reading from') from None

Expand All @@ -789,7 +789,7 @@ async def wait_socket_readable(sock):


async def wait_socket_writable(sock):
await check_cancelled()
await checkpoint()
if _writer_tasks.get(sock):
raise BusyResourceError('writing to') from None

Expand Down Expand Up @@ -817,7 +817,7 @@ def locked(self) -> bool:
return self._lock.locked()

async def acquire(self) -> None:
await check_cancelled()
await checkpoint()
await self._lock.acquire()

async def release(self) -> None:
Expand All @@ -830,7 +830,7 @@ def __init__(self, lock: Optional[Lock]):
self._condition = curio.Condition(curio_lock)

async def acquire(self) -> None:
await check_cancelled()
await checkpoint()
await self._condition.acquire()

async def release(self) -> None:
Expand All @@ -846,7 +846,7 @@ async def notify_all(self):
await self._condition.notify_all()

async def wait(self):
await check_cancelled()
await checkpoint()
return await self._condition.wait()


Expand All @@ -861,7 +861,7 @@ def is_set(self) -> bool:
return self._event.is_set()

async def wait(self):
await check_cancelled()
await checkpoint()
return await self._event.wait()


Expand All @@ -870,7 +870,7 @@ def __init__(self, value: int):
self._semaphore = curio.Semaphore(value)

async def acquire(self) -> None:
await check_cancelled()
await checkpoint()
await self._semaphore.acquire()

async def release(self) -> None:
Expand Down
1 change: 1 addition & 0 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
#

CancelledError = trio.Cancelled
checkpoint = trio.lowlevel.checkpoint


class CancelScope(abc.CancelScope):
Expand Down
6 changes: 6 additions & 0 deletions src/anyio/_core/_lowlevel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from anyio._core._eventloop import get_asynclib


async def checkpoint():
"""Checks for cancellation and allows the scheduler to switch to another task."""
await get_asynclib().checkpoint()
Loading