Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix core config tests for UCX 1.12 #805

Merged
merged 3 commits into from
Nov 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@


@pytest.mark.asyncio
@pytest.mark.skipif(
ucp.get_ucx_version() < (1, 11, 0),
reason="Endpoint error handling is unreliable in UCX releases prior to 1.11.0",
)
@pytest.mark.parametrize("server_close_callback", [True, False])
async def test_close_callback(server_close_callback):
endpoint_error_handling = ucp.get_ucx_version() >= (1, 10, 0)
closed = [False]

def _close_callback():
Expand All @@ -21,17 +24,13 @@ async def server_node(ep):
await ep.close()

async def client_node(port):
ep = await ucp.create_endpoint(
ucp.get_address(), port, endpoint_error_handling=endpoint_error_handling
)
ep = await ucp.create_endpoint(ucp.get_address(), port,)
if server_close_callback is False:
ep.set_close_callback(_close_callback)
await ep.send(bytearray(b"0" * 10))
if server_close_callback is True:
await ep.close()

listener = ucp.create_listener(
server_node, endpoint_error_handling=endpoint_error_handling
)
listener = ucp.create_listener(server_node,)
await client_node(listener.port)
assert closed[0] is True
4 changes: 2 additions & 2 deletions tests/test_topological_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pynvml
import pytest

from ucp._libs.topological_distance import TopologicalDistance
topological_distance = pytest.importorskip("ucp._libs.topological_distance")


def test_topological_distance_dgx():
Expand Down Expand Up @@ -56,7 +56,7 @@ def test_topological_distance_dgx():
else:
pytest.skip("DGX Server not recognized or not supported")

td = TopologicalDistance()
td = topological_distance.TopologicalDistance()

for i in range(dev_count):
closest_network = td.get_cuda_distances_from_device_index(i, "network")
Expand Down
17 changes: 16 additions & 1 deletion ucp/_libs/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,19 @@


def test_get_config():
# Cache user-defined UCX_TLS and unset it to test default value
tls = os.environ.get("UCX_TLS", None)
if tls is not None:
del os.environ["UCX_TLS"]

ctx = ucx_api.UCXContext()
config = ctx.get_config()
assert isinstance(config, dict)
assert config["MEMTYPE_CACHE"] == "n"
assert config["TLS"] == "all"

# Restore user-defined UCX_TLS
if tls is not None:
os.environ["UCX_TLS"] = tls


def test_set_env():
Expand All @@ -29,6 +38,12 @@ def test_init_options():
assert config["SEG_SIZE"] == options["SEG_SIZE"]


@pytest.mark.skipif(
ucx_api.get_ucx_version() >= (1, 12, 0),
reason="Beginning with UCX >= 1.12, it's only possible to validate "
"UCP options but not options from other modules such as UCT. "
"See https://github.com/openucx/ucx/issues/7519.",
)
def test_init_unknown_option():
options = {"UNKNOWN_OPTION": "3M"}
with pytest.raises(ucp.exceptions.UCXConfigError):
Expand Down
4 changes: 4 additions & 0 deletions ucp/_libs/tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ def _client(port, endpoint_error_handling, server_close_callback):
worker.progress()


@pytest.mark.skipif(
ucx_api.get_ucx_version() < (1, 11, 0),
reason="Endpoint error handling is unreliable in UCX releases prior to 1.11.0",
)
@pytest.mark.parametrize("server_close_callback", [True, False])
def test_close_callback(server_close_callback):
endpoint_error_handling = ucx_api.get_ucx_version() >= (1, 10, 0)
Expand Down