diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index b2d87857..72903aea 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -37,6 +37,8 @@ This library adheres to `Semantic Versioning 2.0 `_. - Fixed quitting the debugger in a pytest test session while in an active task group failing the test instead of exiting the test session (because the exit exception arrives in an exception group) +- Fixed support for Linux abstract namespaces in UNIX sockets that was broken in v4.2 + (#781 _; PR by @tapetersen) **4.4.0** diff --git a/src/anyio/_core/_sockets.py b/src/anyio/_core/_sockets.py index 5e09cdbf..6070c647 100644 --- a/src/anyio/_core/_sockets.py +++ b/src/anyio/_core/_sockets.py @@ -680,19 +680,26 @@ async def setup_unix_local_socket( :param socktype: socket.SOCK_STREAM or socket.SOCK_DGRAM """ - path_str: str | bytes | None + path_str: str | None if path is not None: - path_str = os.fspath(path) - - # Copied from pathlib... - try: - stat_result = os.stat(path) - except OSError as e: - if e.errno not in (errno.ENOENT, errno.ENOTDIR, errno.EBADF, errno.ELOOP): - raise - else: - if stat.S_ISSOCK(stat_result.st_mode): - os.unlink(path) + path_str = os.fsdecode(path) + + # Linux abstract namespace sockets aren't backed by a concrete file so skip stat call + if not path_str.startswith("\0"): + # Copied from pathlib... + try: + stat_result = os.stat(path) + except OSError as e: + if e.errno not in ( + errno.ENOENT, + errno.ENOTDIR, + errno.EBADF, + errno.ELOOP, + ): + raise + else: + if stat.S_ISSOCK(stat_result.st_mode): + os.unlink(path) else: path_str = None diff --git a/tests/test_sockets.py b/tests/test_sockets.py index 43738eec..832ae6bc 100644 --- a/tests/test_sockets.py +++ b/tests/test_sockets.py @@ -83,6 +83,10 @@ has_ipv6 = True skip_ipv6_mark = pytest.mark.skipif(not has_ipv6, reason="IPv6 is not available") +skip_unix_abstract_mark = pytest.mark.skipif( + not sys.platform.startswith("linux"), + reason="Abstract namespace sockets is a Linux only feature", +) @pytest.fixture @@ -735,12 +739,20 @@ async def test_bind_link_local(self) -> None: sys.platform == "win32", reason="UNIX sockets are not available on Windows" ) class TestUNIXStream: - @pytest.fixture - def socket_path(self) -> Generator[Path, None, None]: + @pytest.fixture( + params=[ + "path", + pytest.param("abstract", marks=[skip_unix_abstract_mark]), + ] + ) + def socket_path(self, request: SubRequest) -> Generator[Path, None, None]: # Use stdlib tempdir generation # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path with tempfile.TemporaryDirectory() as path: - yield Path(path) / "socket" + if request.param == "path": + yield Path(path) / "socket" + else: + yield Path(f"\0{path}") / "socket" @pytest.fixture(params=[False, True], ids=["str", "path"]) def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str: @@ -764,7 +776,15 @@ async def test_extra_attributes( assert ( stream.extra(SocketAttribute.local_address) == raw_socket.getsockname() ) - assert stream.extra(SocketAttribute.remote_address) == str(socket_path) + remote_addr = stream.extra(SocketAttribute.remote_address) + if isinstance(remote_addr, str): + assert stream.extra(SocketAttribute.remote_address) == str(socket_path) + else: + assert isinstance(remote_addr, bytes) + assert stream.extra(SocketAttribute.remote_address) == bytes( + socket_path + ) + pytest.raises( TypedAttributeLookupError, stream.extra, SocketAttribute.local_port ) @@ -1031,8 +1051,12 @@ async def test_send_after_close( await stream.send(b"foo") async def test_cannot_connect(self, socket_path: Path) -> None: - with pytest.raises(FileNotFoundError): - await connect_unix(socket_path) + if str(socket_path).startswith("\0"): + with pytest.raises(ConnectionRefusedError): + await connect_unix(socket_path) + else: + with pytest.raises(FileNotFoundError): + await connect_unix(socket_path) async def test_connecting_using_bytes( self, server_sock: socket.socket, socket_path: Path @@ -1057,12 +1081,20 @@ async def test_connecting_with_non_utf8(self, socket_path: Path) -> None: sys.platform == "win32", reason="UNIX sockets are not available on Windows" ) class TestUNIXListener: - @pytest.fixture - def socket_path(self) -> Generator[Path, None, None]: + @pytest.fixture( + params=[ + "path", + pytest.param("abstract", marks=[skip_unix_abstract_mark]), + ] + ) + def socket_path(self, request: SubRequest) -> Generator[Path, None, None]: # Use stdlib tempdir generation # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path with tempfile.TemporaryDirectory() as path: - yield Path(path) / "socket" + if request.param == "path": + yield Path(path) / "socket" + else: + yield Path(f"\0{path}") / "socket" @pytest.fixture(params=[False, True], ids=["str", "path"]) def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str: @@ -1461,12 +1493,20 @@ async def test_send_after_close(self, family: AnyIPAddressFamily) -> None: sys.platform == "win32", reason="UNIX sockets are not available on Windows" ) class TestUNIXDatagramSocket: - @pytest.fixture - def socket_path(self) -> Generator[Path, None, None]: + @pytest.fixture( + params=[ + "path", + pytest.param("abstract", marks=[skip_unix_abstract_mark]), + ] + ) + def socket_path(self, request: SubRequest) -> Generator[Path, None, None]: # Use stdlib tempdir generation # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path with tempfile.TemporaryDirectory() as path: - yield Path(path) / "socket" + if request.param == "path": + yield Path(path) / "socket" + else: + yield Path(f"\0{path}") / "socket" @pytest.fixture(params=[False, True], ids=["str", "path"]) def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str: @@ -1506,12 +1546,18 @@ async def test_send_receive(self, socket_path_or_str: Path | str) -> None: await sock.sendto(b"blah", path) request, addr = await sock.receive() assert request == b"blah" - assert addr == path + if isinstance(addr, bytes): + assert addr == path.encode() + else: + assert addr == path await sock.sendto(b"halb", path) response, addr = await sock.receive() assert response == b"halb" - assert addr == path + if isinstance(addr, bytes): + assert addr == path.encode() + else: + assert addr == path async def test_iterate(self, peer_socket_path: Path, socket_path: Path) -> None: async def serve() -> None: @@ -1589,18 +1635,33 @@ async def test_local_path_invalid_ascii(self, socket_path: Path) -> None: sys.platform == "win32", reason="UNIX sockets are not available on Windows" ) class TestConnectedUNIXDatagramSocket: - @pytest.fixture - def socket_path(self) -> Generator[Path, None, None]: + @pytest.fixture( + params=[ + "path", + pytest.param("abstract", marks=[skip_unix_abstract_mark]), + ] + ) + def socket_path(self, request: SubRequest) -> Generator[Path, None, None]: # Use stdlib tempdir generation # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path with tempfile.TemporaryDirectory() as path: - yield Path(path) / "socket" + if request.param == "path": + yield Path(path) / "socket" + else: + yield Path(f"\0{path}") / "socket" @pytest.fixture(params=[False, True], ids=["str", "path"]) def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str: return socket_path if request.param else str(socket_path) - @pytest.fixture + @pytest.fixture( + params=[ + pytest.param("path", id="path-peer"), + pytest.param( + "abstract", marks=[skip_unix_abstract_mark], id="abstract-peer" + ), + ] + ) def peer_socket_path(self) -> Generator[Path, None, None]: # Use stdlib tempdir generation # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path @@ -1634,10 +1695,12 @@ async def test_extra_attributes( raw_socket = unix_dg.extra(SocketAttribute.raw_socket) assert raw_socket is not None assert unix_dg.extra(SocketAttribute.family) == AddressFamily.AF_UNIX - assert unix_dg.extra(SocketAttribute.local_address) == str(socket_path) - assert unix_dg.extra(SocketAttribute.remote_address) == str( - peer_socket_path - ) + assert os.fsencode( + cast(os.PathLike, unix_dg.extra(SocketAttribute.local_address)) + ) == os.fsencode(socket_path) + assert os.fsencode( + cast(os.PathLike, unix_dg.extra(SocketAttribute.remote_address)) + ) == os.fsencode(peer_socket_path) pytest.raises( TypedAttributeLookupError, unix_dg.extra, SocketAttribute.local_port ) @@ -1657,11 +1720,11 @@ async def test_send_receive( peer_socket_path_or_str, local_path=socket_path_or_str, ) as unix_dg2: - socket_path = str(socket_path_or_str) + socket_path = os.fsdecode(socket_path_or_str) await unix_dg2.send(b"blah") - request = await unix_dg1.receive() - assert request == (b"blah", socket_path) + data, remote_addr = await unix_dg1.receive() + assert (data, os.fsdecode(remote_addr)) == (b"blah", socket_path) await unix_dg1.sendto(b"halb", socket_path) response = await unix_dg2.receive() @@ -1682,13 +1745,15 @@ async def serve() -> None: async with await create_connected_unix_datagram_socket( peer_socket_path, local_path=socket_path ) as unix_dg2: - path = str(socket_path) + path = os.fsdecode(socket_path) async with create_task_group() as tg: tg.start_soon(serve) await unix_dg1.sendto(b"FOOBAR", path) - assert await unix_dg1.receive() == (b"RABOOF", path) + data, addr = await unix_dg1.receive() + assert (data, os.fsdecode(addr)) == (b"RABOOF", path) await unix_dg1.sendto(b"123456", path) - assert await unix_dg1.receive() == (b"654321", path) + data, addr = await unix_dg1.receive() + assert (data, os.fsdecode(addr)) == (b"654321", path) tg.cancel_scope.cancel() async def test_concurrent_receive(