Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed EOF detection on asyncio when there is also data in the buffer #703

Merged
merged 7 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
``KeyError``
- Fixed the asyncio backend not respecting the ``PYTHONASYNCIODEBUG`` environment
variable when setting the ``debug`` flag in ``anyio.run()``
- Fixed ``SocketStream.receive()`` not detecting EOF on asyncio if there is also data in
the read buffer (`#701 <https://github.com/agronholm/anyio/issues/701>`_)

**4.3.0**

Expand Down
7 changes: 5 additions & 2 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,7 @@ class StreamProtocol(asyncio.Protocol):
read_event: asyncio.Event
write_event: asyncio.Event
exception: Exception | None = None
is_at_eof: bool = False

def connection_made(self, transport: asyncio.BaseTransport) -> None:
self.read_queue = deque()
Expand All @@ -1068,6 +1069,7 @@ def data_received(self, data: bytes) -> None:
self.read_event.set()

def eof_received(self) -> bool | None:
self.is_at_eof = True
self.read_event.set()
return True

Expand Down Expand Up @@ -1123,15 +1125,16 @@ def _raw_socket(self) -> socket.socket:

async def receive(self, max_bytes: int = 65536) -> bytes:
with self._receive_guard:
await AsyncIOBackend.checkpoint()

if (
not self._protocol.read_event.is_set()
and not self._transport.is_closing()
and not self._protocol.is_at_eof
):
self._transport.resume_reading()
await self._protocol.read_event.wait()
self._transport.pause_reading()
else:
await AsyncIOBackend.checkpoint()

try:
chunk = self._protocol.read_queue.popleft()
Expand Down
24 changes: 24 additions & 0 deletions tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
BrokenResourceError,
BusyResourceError,
ClosedResourceError,
EndOfStream,
Event,
TypedAttributeLookupError,
connect_tcp,
Expand Down Expand Up @@ -681,6 +682,29 @@ async def handle(stream: SocketStream) -> None:

tg.cancel_scope.cancel()

async def test_eof_after_send(self, family: AnyIPAddressFamily) -> None:
"""Regression test for #701."""
received_bytes = b""

async def handle(stream: SocketStream) -> None:
nonlocal received_bytes
async with stream:
received_bytes = await stream.receive()
with pytest.raises(EndOfStream), fail_after(1):
await stream.receive()

tg.cancel_scope.cancel()

multi = await create_tcp_listener(family=family, local_host="localhost")
async with multi, create_task_group() as tg:
with socket.socket(family) as client:
client.connect(multi.extra(SocketAttribute.local_address))
client.send(b"Hello")
client.shutdown(socket.SHUT_WR)
await multi.serve(handle)

assert received_bytes == b"Hello"

@skip_ipv6_mark
@pytest.mark.skipif(
sys.platform == "win32",
Expand Down