Skip to content

Commit

Permalink
Accept abstract namespace paths for unix domain sockets
Browse files Browse the repository at this point in the history
Accept paths starting with a null byte in create_unix_listener and
connect_unix_socket to allow creating abstract unix sockets. Fixes agronholm#781
  • Loading branch information
tapetersen committed Sep 4, 2024
1 parent ee8165b commit d85a6a1
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 17 deletions.
30 changes: 22 additions & 8 deletions src/anyio/_core/_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,16 +683,30 @@ async def setup_unix_local_socket(
path_str: str | bytes | None
if path is not None:
path_str = os.fspath(path)
is_abstract = (
path_str.startswith(b"\0")
if isinstance(path_str, bytes)
else path_str.startswith("\0")
)

# Copied from pathlib...
try:
stat_result = os.stat(path)
except OSError as e:
if e.errno not in (errno.ENOENT, errno.ENOTDIR, errno.EBADF, errno.ELOOP):
raise
if is_abstract:
# Unix abstract namespace socket. No file backing so skip stat call
pass
else:
if stat.S_ISSOCK(stat_result.st_mode):
os.unlink(path)
# Copied from pathlib...
try:
stat_result = os.stat(path)
except OSError as e:
if e.errno not in (
errno.ENOENT,
errno.ENOTDIR,
errno.EBADF,
errno.ELOOP,
):
raise
else:
if stat.S_ISSOCK(stat_result.st_mode):
os.unlink(path)
else:
path_str = None

Expand Down
40 changes: 31 additions & 9 deletions tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,12 +735,17 @@ async def test_bind_link_local(self) -> None:
sys.platform == "win32", reason="UNIX sockets are not available on Windows"
)
class TestUNIXStream:
@pytest.fixture
def socket_path(self) -> Generator[Path, None, None]:
@pytest.fixture(params=["path", "abstract"])
def socket_path(
self, request: SubRequest, tmp_path_factory: TempPathFactory
) -> Generator[Path, None, None]:
# Use stdlib tempdir generation
# Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path
with tempfile.TemporaryDirectory() as path:
yield Path(path) / "socket"
if request.param == "path":
yield Path(path) / "socket"
elif request.param == "abstract":
yield Path(f"\0{path}") / "socket"

@pytest.fixture(params=[False, True], ids=["str", "path"])
def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str:
Expand All @@ -764,7 +769,15 @@ async def test_extra_attributes(
assert (
stream.extra(SocketAttribute.local_address) == raw_socket.getsockname()
)
assert stream.extra(SocketAttribute.remote_address) == str(socket_path)
remote_addr = stream.extra(SocketAttribute.remote_address)
if isinstance(remote_addr, str):
assert stream.extra(SocketAttribute.remote_address) == str(socket_path)
else:
assert isinstance(remote_addr, bytes)
assert stream.extra(SocketAttribute.remote_address) == bytes(
socket_path
)

pytest.raises(
TypedAttributeLookupError, stream.extra, SocketAttribute.local_port
)
Expand Down Expand Up @@ -1031,8 +1044,12 @@ async def test_send_after_close(
await stream.send(b"foo")

async def test_cannot_connect(self, socket_path: Path) -> None:
with pytest.raises(FileNotFoundError):
await connect_unix(socket_path)
if str(socket_path).startswith("\0"):
with pytest.raises(ConnectionRefusedError):
await connect_unix(socket_path)
else:
with pytest.raises(FileNotFoundError):
await connect_unix(socket_path)

async def test_connecting_using_bytes(
self, server_sock: socket.socket, socket_path: Path
Expand All @@ -1057,12 +1074,17 @@ async def test_connecting_with_non_utf8(self, socket_path: Path) -> None:
sys.platform == "win32", reason="UNIX sockets are not available on Windows"
)
class TestUNIXListener:
@pytest.fixture
def socket_path(self) -> Generator[Path, None, None]:
@pytest.fixture(params=["path", "abstract"])
def socket_path(
self, request: SubRequest, tmp_path_factory: TempPathFactory
) -> Generator[Path, None, None]:
# Use stdlib tempdir generation
# Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path
with tempfile.TemporaryDirectory() as path:
yield Path(path) / "socket"
if request.param == "path":
yield Path(path) / "socket"
elif request.param == "abstract":
yield Path(f"\0{path}") / "socket"

@pytest.fixture(params=[False, True], ids=["str", "path"])
def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str:
Expand Down

0 comments on commit d85a6a1

Please sign in to comment.