From f76b139ac8e30d22f3af10e552129f3aa4ec0ef0 Mon Sep 17 00:00:00 2001 From: Tobias Petersen Date: Wed, 4 Sep 2024 13:14:20 +0200 Subject: [PATCH] Accept abstract namespace paths for unix domain sockets Accept paths starting with a null byte in create_unix_listener and connect_unix_socket to allow creating abstract unix sockets. Fixes #781 --- docs/versionhistory.rst | 2 + src/anyio/_core/_sockets.py | 30 ++++++--- tests/test_sockets.py | 121 +++++++++++++++++++++++++++--------- 3 files changed, 117 insertions(+), 36 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index b2d87857..3c9d1e8c 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -37,6 +37,8 @@ This library adheres to `Semantic Versioning 2.0 `_. - Fixed quitting the debugger in a pytest test session while in an active task group failing the test instead of exiting the test session (because the exit exception arrives in an exception group) +- Re-add support for linux abstract namespace sockets (#781 + _; PR by @tapetersen) **4.4.0** diff --git a/src/anyio/_core/_sockets.py b/src/anyio/_core/_sockets.py index 5e09cdbf..34901a64 100644 --- a/src/anyio/_core/_sockets.py +++ b/src/anyio/_core/_sockets.py @@ -683,16 +683,30 @@ async def setup_unix_local_socket( path_str: str | bytes | None if path is not None: path_str = os.fspath(path) + is_abstract = ( + path_str.startswith(b"\0") + if isinstance(path_str, bytes) + else path_str.startswith("\0") + ) - # Copied from pathlib... - try: - stat_result = os.stat(path) - except OSError as e: - if e.errno not in (errno.ENOENT, errno.ENOTDIR, errno.EBADF, errno.ELOOP): - raise + if is_abstract: + # Unix abstract namespace socket. No file backing so skip stat call + pass else: - if stat.S_ISSOCK(stat_result.st_mode): - os.unlink(path) + # Copied from pathlib... + try: + stat_result = os.stat(path) + except OSError as e: + if e.errno not in ( + errno.ENOENT, + errno.ENOTDIR, + errno.EBADF, + errno.ELOOP, + ): + raise + else: + if stat.S_ISSOCK(stat_result.st_mode): + os.unlink(path) else: path_str = None diff --git a/tests/test_sockets.py b/tests/test_sockets.py index 43738eec..832ae6bc 100644 --- a/tests/test_sockets.py +++ b/tests/test_sockets.py @@ -83,6 +83,10 @@ has_ipv6 = True skip_ipv6_mark = pytest.mark.skipif(not has_ipv6, reason="IPv6 is not available") +skip_unix_abstract_mark = pytest.mark.skipif( + not sys.platform.startswith("linux"), + reason="Abstract namespace sockets is a Linux only feature", +) @pytest.fixture @@ -735,12 +739,20 @@ async def test_bind_link_local(self) -> None: sys.platform == "win32", reason="UNIX sockets are not available on Windows" ) class TestUNIXStream: - @pytest.fixture - def socket_path(self) -> Generator[Path, None, None]: + @pytest.fixture( + params=[ + "path", + pytest.param("abstract", marks=[skip_unix_abstract_mark]), + ] + ) + def socket_path(self, request: SubRequest) -> Generator[Path, None, None]: # Use stdlib tempdir generation # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path with tempfile.TemporaryDirectory() as path: - yield Path(path) / "socket" + if request.param == "path": + yield Path(path) / "socket" + else: + yield Path(f"\0{path}") / "socket" @pytest.fixture(params=[False, True], ids=["str", "path"]) def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str: @@ -764,7 +776,15 @@ async def test_extra_attributes( assert ( stream.extra(SocketAttribute.local_address) == raw_socket.getsockname() ) - assert stream.extra(SocketAttribute.remote_address) == str(socket_path) + remote_addr = stream.extra(SocketAttribute.remote_address) + if isinstance(remote_addr, str): + assert stream.extra(SocketAttribute.remote_address) == str(socket_path) + else: + assert isinstance(remote_addr, bytes) + assert stream.extra(SocketAttribute.remote_address) == bytes( + socket_path + ) + pytest.raises( TypedAttributeLookupError, stream.extra, SocketAttribute.local_port ) @@ -1031,8 +1051,12 @@ async def test_send_after_close( await stream.send(b"foo") async def test_cannot_connect(self, socket_path: Path) -> None: - with pytest.raises(FileNotFoundError): - await connect_unix(socket_path) + if str(socket_path).startswith("\0"): + with pytest.raises(ConnectionRefusedError): + await connect_unix(socket_path) + else: + with pytest.raises(FileNotFoundError): + await connect_unix(socket_path) async def test_connecting_using_bytes( self, server_sock: socket.socket, socket_path: Path @@ -1057,12 +1081,20 @@ async def test_connecting_with_non_utf8(self, socket_path: Path) -> None: sys.platform == "win32", reason="UNIX sockets are not available on Windows" ) class TestUNIXListener: - @pytest.fixture - def socket_path(self) -> Generator[Path, None, None]: + @pytest.fixture( + params=[ + "path", + pytest.param("abstract", marks=[skip_unix_abstract_mark]), + ] + ) + def socket_path(self, request: SubRequest) -> Generator[Path, None, None]: # Use stdlib tempdir generation # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path with tempfile.TemporaryDirectory() as path: - yield Path(path) / "socket" + if request.param == "path": + yield Path(path) / "socket" + else: + yield Path(f"\0{path}") / "socket" @pytest.fixture(params=[False, True], ids=["str", "path"]) def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str: @@ -1461,12 +1493,20 @@ async def test_send_after_close(self, family: AnyIPAddressFamily) -> None: sys.platform == "win32", reason="UNIX sockets are not available on Windows" ) class TestUNIXDatagramSocket: - @pytest.fixture - def socket_path(self) -> Generator[Path, None, None]: + @pytest.fixture( + params=[ + "path", + pytest.param("abstract", marks=[skip_unix_abstract_mark]), + ] + ) + def socket_path(self, request: SubRequest) -> Generator[Path, None, None]: # Use stdlib tempdir generation # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path with tempfile.TemporaryDirectory() as path: - yield Path(path) / "socket" + if request.param == "path": + yield Path(path) / "socket" + else: + yield Path(f"\0{path}") / "socket" @pytest.fixture(params=[False, True], ids=["str", "path"]) def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str: @@ -1506,12 +1546,18 @@ async def test_send_receive(self, socket_path_or_str: Path | str) -> None: await sock.sendto(b"blah", path) request, addr = await sock.receive() assert request == b"blah" - assert addr == path + if isinstance(addr, bytes): + assert addr == path.encode() + else: + assert addr == path await sock.sendto(b"halb", path) response, addr = await sock.receive() assert response == b"halb" - assert addr == path + if isinstance(addr, bytes): + assert addr == path.encode() + else: + assert addr == path async def test_iterate(self, peer_socket_path: Path, socket_path: Path) -> None: async def serve() -> None: @@ -1589,18 +1635,33 @@ async def test_local_path_invalid_ascii(self, socket_path: Path) -> None: sys.platform == "win32", reason="UNIX sockets are not available on Windows" ) class TestConnectedUNIXDatagramSocket: - @pytest.fixture - def socket_path(self) -> Generator[Path, None, None]: + @pytest.fixture( + params=[ + "path", + pytest.param("abstract", marks=[skip_unix_abstract_mark]), + ] + ) + def socket_path(self, request: SubRequest) -> Generator[Path, None, None]: # Use stdlib tempdir generation # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path with tempfile.TemporaryDirectory() as path: - yield Path(path) / "socket" + if request.param == "path": + yield Path(path) / "socket" + else: + yield Path(f"\0{path}") / "socket" @pytest.fixture(params=[False, True], ids=["str", "path"]) def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str: return socket_path if request.param else str(socket_path) - @pytest.fixture + @pytest.fixture( + params=[ + pytest.param("path", id="path-peer"), + pytest.param( + "abstract", marks=[skip_unix_abstract_mark], id="abstract-peer" + ), + ] + ) def peer_socket_path(self) -> Generator[Path, None, None]: # Use stdlib tempdir generation # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path @@ -1634,10 +1695,12 @@ async def test_extra_attributes( raw_socket = unix_dg.extra(SocketAttribute.raw_socket) assert raw_socket is not None assert unix_dg.extra(SocketAttribute.family) == AddressFamily.AF_UNIX - assert unix_dg.extra(SocketAttribute.local_address) == str(socket_path) - assert unix_dg.extra(SocketAttribute.remote_address) == str( - peer_socket_path - ) + assert os.fsencode( + cast(os.PathLike, unix_dg.extra(SocketAttribute.local_address)) + ) == os.fsencode(socket_path) + assert os.fsencode( + cast(os.PathLike, unix_dg.extra(SocketAttribute.remote_address)) + ) == os.fsencode(peer_socket_path) pytest.raises( TypedAttributeLookupError, unix_dg.extra, SocketAttribute.local_port ) @@ -1657,11 +1720,11 @@ async def test_send_receive( peer_socket_path_or_str, local_path=socket_path_or_str, ) as unix_dg2: - socket_path = str(socket_path_or_str) + socket_path = os.fsdecode(socket_path_or_str) await unix_dg2.send(b"blah") - request = await unix_dg1.receive() - assert request == (b"blah", socket_path) + data, remote_addr = await unix_dg1.receive() + assert (data, os.fsdecode(remote_addr)) == (b"blah", socket_path) await unix_dg1.sendto(b"halb", socket_path) response = await unix_dg2.receive() @@ -1682,13 +1745,15 @@ async def serve() -> None: async with await create_connected_unix_datagram_socket( peer_socket_path, local_path=socket_path ) as unix_dg2: - path = str(socket_path) + path = os.fsdecode(socket_path) async with create_task_group() as tg: tg.start_soon(serve) await unix_dg1.sendto(b"FOOBAR", path) - assert await unix_dg1.receive() == (b"RABOOF", path) + data, addr = await unix_dg1.receive() + assert (data, os.fsdecode(addr)) == (b"RABOOF", path) await unix_dg1.sendto(b"123456", path) - assert await unix_dg1.receive() == (b"654321", path) + data, addr = await unix_dg1.receive() + assert (data, os.fsdecode(addr)) == (b"654321", path) tg.cancel_scope.cancel() async def test_concurrent_receive(