From 059a13cc513e137ed899804d30bb8b7563ebba99 Mon Sep 17 00:00:00 2001 From: Colin Taylor Date: Fri, 22 Dec 2023 14:20:43 -0500 Subject: [PATCH] Modify unix socket tests to use stdlib tempdirs --- tests/test_sockets.py | 99 ++++++++++++++++++++++++++----------------- 1 file changed, 60 insertions(+), 39 deletions(-) diff --git a/tests/test_sockets.py b/tests/test_sockets.py index f34a0381..91652e67 100644 --- a/tests/test_sockets.py +++ b/tests/test_sockets.py @@ -7,6 +7,7 @@ import platform import socket import sys +import tempfile import threading import time from contextlib import suppress @@ -14,7 +15,7 @@ from socket import AddressFamily from ssl import SSLContext, SSLError from threading import Thread -from typing import Any, Iterable, Iterator, NoReturn, TypeVar, cast +from typing import Any, Generator, Iterable, Iterator, NoReturn, TypeVar, cast import psutil import pytest @@ -707,8 +708,11 @@ async def test_bind_link_local(self) -> None: ) class TestUNIXStream: @pytest.fixture - def socket_path(self, tmp_path_factory: TempPathFactory) -> Path: - return tmp_path_factory.mktemp("unix").joinpath("socket") + def socket_path(self) -> Generator[Path, None, None]: + # Use stdlib tempdir generation + # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path + with tempfile.TemporaryDirectory() as path: + yield Path(path) / "socket" @pytest.fixture(params=[False, True], ids=["str", "path"]) def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str: @@ -1026,8 +1030,11 @@ async def test_connecting_with_non_utf8(self, socket_path: Path) -> None: ) class TestUNIXListener: @pytest.fixture - def socket_path(self, tmp_path_factory: TempPathFactory) -> Path: - return tmp_path_factory.mktemp("unix").joinpath("socket") + def socket_path(self) -> Generator[Path, None, None]: + # Use stdlib tempdir generation + # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path + with tempfile.TemporaryDirectory() as path: + yield Path(path) / "socket" @pytest.fixture(params=[False, True], ids=["str", "path"]) def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str: @@ -1140,35 +1147,37 @@ async def handle(stream: SocketStream) -> None: client_addresses: list[str | IPSockAddrType] = [] listeners: list[Listener] = [await create_tcp_listener(local_host="localhost")] - if sys.platform != "win32": - socket_path = tmp_path_factory.mktemp("unix").joinpath("socket") - listeners.append(await create_unix_listener(socket_path)) - - expected_addresses: list[str | IPSockAddrType] = [] - async with MultiListener(listeners) as multi_listener: - async with create_task_group() as tg: - tg.start_soon(multi_listener.serve, handle) - for listener in multi_listener.listeners: - event = Event() - local_address = listener.extra(SocketAttribute.local_address) - if ( - sys.platform != "win32" - and listener.extra(SocketAttribute.family) - == socket.AddressFamily.AF_UNIX - ): - assert isinstance(local_address, str) - stream: SocketStream = await connect_unix(local_address) - else: - assert isinstance(local_address, tuple) - stream = await connect_tcp(*local_address) + with tempfile.TemporaryDirectory() as path: + if sys.platform != "win32": + listeners.append(await create_unix_listener(Path(path) / "socket")) - expected_addresses.append(stream.extra(SocketAttribute.local_address)) - await event.wait() - await stream.aclose() + expected_addresses: list[str | IPSockAddrType] = [] + async with MultiListener(listeners) as multi_listener: + async with create_task_group() as tg: + tg.start_soon(multi_listener.serve, handle) + for listener in multi_listener.listeners: + event = Event() + local_address = listener.extra(SocketAttribute.local_address) + if ( + sys.platform != "win32" + and listener.extra(SocketAttribute.family) + == socket.AddressFamily.AF_UNIX + ): + assert isinstance(local_address, str) + stream: SocketStream = await connect_unix(local_address) + else: + assert isinstance(local_address, tuple) + stream = await connect_tcp(*local_address) + + expected_addresses.append( + stream.extra(SocketAttribute.local_address) + ) + await event.wait() + await stream.aclose() - tg.cancel_scope.cancel() + tg.cancel_scope.cancel() - assert client_addresses == expected_addresses + assert client_addresses == expected_addresses @pytest.mark.usefixtures("check_asyncio_bug") @@ -1423,16 +1432,22 @@ async def test_send_after_close(self, family: AnyIPAddressFamily) -> None: ) class TestUNIXDatagramSocket: @pytest.fixture - def socket_path(self, tmp_path_factory: TempPathFactory) -> Path: - return tmp_path_factory.mktemp("unix").joinpath("socket") + def socket_path(self) -> Generator[Path, None, None]: + # Use stdlib tempdir generation + # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path + with tempfile.TemporaryDirectory() as path: + yield Path(path) / "socket" @pytest.fixture(params=[False, True], ids=["str", "path"]) def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str: return socket_path if request.param else str(socket_path) @pytest.fixture - def peer_socket_path(self, tmp_path_factory: TempPathFactory) -> Path: - return tmp_path_factory.mktemp("unix").joinpath("peer_socket") + def peer_socket_path(self) -> Generator[Path, None, None]: + # Use stdlib tempdir generation + # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path + with tempfile.TemporaryDirectory() as path: + yield Path(path) / "peer_socket" async def test_extra_attributes(self, socket_path: Path) -> None: async with await create_unix_datagram_socket(local_path=socket_path) as unix_dg: @@ -1545,16 +1560,22 @@ async def test_local_path_invalid_ascii(self, socket_path: Path) -> None: ) class TestConnectedUNIXDatagramSocket: @pytest.fixture - def socket_path(self, tmp_path_factory: TempPathFactory) -> Path: - return tmp_path_factory.mktemp("unix").joinpath("socket") + def socket_path(self) -> Generator[Path, None, None]: + # Use stdlib tempdir generation + # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path + with tempfile.TemporaryDirectory() as path: + yield Path(path) / "socket" @pytest.fixture(params=[False, True], ids=["str", "path"]) def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str: return socket_path if request.param else str(socket_path) @pytest.fixture - def peer_socket_path(self, tmp_path_factory: TempPathFactory) -> Path: - return tmp_path_factory.mktemp("unix").joinpath("peer_socket") + def peer_socket_path(self) -> Generator[Path, None, None]: + # Use stdlib tempdir generation + # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path + with tempfile.TemporaryDirectory() as path: + yield Path(path) / "peer_socket" @pytest.fixture(params=[False, True], ids=["peer_str", "peer_path"]) def peer_socket_path_or_str(