From 770f39f482c810c9572d73a0de1765b021ade45a Mon Sep 17 00:00:00 2001 From: ksimpson Date: Mon, 4 Nov 2024 10:37:59 -0800 Subject: [PATCH] move away from internal entry-points, and clean up the fixture use in the testsuite --- cuda_core/tests/conftest.py | 18 ++++++++++-- cuda_core/tests/test_device.py | 6 +--- cuda_core/tests/test_event.py | 45 +++++++++++++++++++----------- cuda_core/tests/test_launcher.py | 4 +-- cuda_core/tests/test_stream.py | 48 ++++++++++++++++---------------- 5 files changed, 72 insertions(+), 49 deletions(-) diff --git a/cuda_core/tests/conftest.py b/cuda_core/tests/conftest.py index 3ff6ce08..ac4e2f22 100644 --- a/cuda_core/tests/conftest.py +++ b/cuda_core/tests/conftest.py @@ -7,10 +7,24 @@ # is strictly prohibited. from cuda.core.experimental._device import Device +from cuda.core.experimental import _device +from cuda import cuda +from cuda.core.experimental._utils import handle_return import pytest -@pytest.fixture(scope="module") +@pytest.fixture(scope="module", autouse=True) +def ensure_no_context(): + device = Device() + device.set_current() + with _device._tls_lock: + if hasattr(_device._tls, 'devices'): + del _device._tls.devices + +@pytest.fixture(scope="function") def init_cuda(): device = Device() device.set_current() - \ No newline at end of file + yield + handle_return(cuda.cuCtxPopCurrent()) + with _device._tls_lock: + del _device._tls.devices \ No newline at end of file diff --git a/cuda_core/tests/test_device.py b/cuda_core/tests/test_device.py index 653dac06..65aa0696 100644 --- a/cuda_core/tests/test_device.py +++ b/cuda_core/tests/test_device.py @@ -24,11 +24,7 @@ def test_device_alloc(init_cuda): assert buffer.size == 1024 assert buffer.device_id == 0 -def test_device_set_current(): - device = Device() - device.set_current() - -def test_device_create_stream(): +def test_device_create_stream(init_cuda): device = Device() stream = device.create_stream() assert stream is not None diff --git a/cuda_core/tests/test_event.py b/cuda_core/tests/test_event.py index b6cfe647..0649df53 100644 --- a/cuda_core/tests/test_event.py +++ b/cuda_core/tests/test_event.py @@ -6,34 +6,47 @@ # this software and related documentation outside the terms of the EULA # is strictly prohibited. -from cuda import cuda from cuda.core.experimental._event import EventOptions, Event -from cuda.core.experimental._utils import handle_return from cuda.core.experimental._device import Device import pytest -def test_is_timing_disabled(): +def test_is_timing_disabled(init_cuda): options = EventOptions(enable_timing=False) - event = Event._init(options) + stream = Device().create_stream() + event = stream.record(options=options) assert event.is_timing_disabled == True + + options = EventOptions(enable_timing=True) + stream = Device().create_stream() + event = stream.record(options=options) + assert event.is_timing_disabled == False -def test_is_sync_busy_waited(): - options = EventOptions(busy_waited_sync=True) - event = Event._init(options) +def test_is_sync_busy_waited(init_cuda): + options = EventOptions(enable_timing=False, busy_waited_sync=True) + stream = Device().create_stream() + event = stream.record(options=options) assert event.is_sync_busy_waited == True -def test_sync(): - options = EventOptions() - event = Event._init(options) + options = EventOptions(enable_timing=False) + stream = Device().create_stream() + event = stream.record(options=options) + assert event.is_sync_busy_waited == False + +def test_sync(init_cuda): + options = EventOptions(enable_timing=False) + stream = Device().create_stream() + event = stream.record(options=options) event.sync() assert event.is_done == True -def test_is_done(): - options = EventOptions() - event = Event._init(options) +def test_is_done(init_cuda): + options = EventOptions(enable_timing=False) + stream = Device().create_stream() + event = stream.record(options=options) assert event.is_done == True -def test_handle(): - options = EventOptions() - event = Event._init(options) +def test_handle(init_cuda): + options = EventOptions(enable_timing=False) + stream = Device().create_stream() + event = stream.record(options=options) assert isinstance(event.handle, int) diff --git a/cuda_core/tests/test_launcher.py b/cuda_core/tests/test_launcher.py index 92dfc726..45e1b6fa 100644 --- a/cuda_core/tests/test_launcher.py +++ b/cuda_core/tests/test_launcher.py @@ -13,7 +13,7 @@ from cuda.core.experimental._utils import handle_return import pytest -def test_launch_config_init(): +def test_launch_config_init(init_cuda): config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=None, shmem_size=0) assert config.grid == (1, 1, 1) assert config.block == (1, 1, 1) @@ -50,7 +50,7 @@ def test_launch_config_invalid_values(): with pytest.raises(ValueError): LaunchConfig(grid=(1, 1, 1), block=(0, 1)) -def test_launch_config_stream(): +def test_launch_config_stream(init_cuda): stream = Device().create_stream() config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=stream, shmem_size=0) assert config.stream == stream diff --git a/cuda_core/tests/test_stream.py b/cuda_core/tests/test_stream.py index e0a98c18..885ef8e4 100644 --- a/cuda_core/tests/test_stream.py +++ b/cuda_core/tests/test_stream.py @@ -15,59 +15,59 @@ def test_stream_init(): with pytest.raises(NotImplementedError): Stream() -def test_stream_init_with_options(): - stream = Stream._init(options=StreamOptions(nonblocking=True, priority=0)) +def test_stream_init_with_options(init_cuda): + stream = Device().create_stream(options=StreamOptions(nonblocking=True, priority=0)) assert stream.is_nonblocking is True assert stream.priority == 0 -def test_stream_handle(): - stream = Stream._init(options=StreamOptions()) +def test_stream_handle(init_cuda): + stream = Device().create_stream(options=StreamOptions()) assert isinstance(stream.handle, int) -def test_stream_is_nonblocking(): - stream = Stream._init(options=StreamOptions(nonblocking=True)) +def test_stream_is_nonblocking(init_cuda): + stream = Device().create_stream(options=StreamOptions(nonblocking=True)) assert stream.is_nonblocking is True -def test_stream_priority(): - stream = Stream._init(options=StreamOptions(priority=0)) +def test_stream_priority(init_cuda): + stream = Device().create_stream(options=StreamOptions(priority=0)) assert stream.priority == 0 - stream = Stream._init(options=StreamOptions(priority=-1)) + stream = Device().create_stream(options=StreamOptions(priority=-1)) assert stream.priority == -1 with pytest.raises(ValueError): - stream = Stream._init(options=StreamOptions(priority=1)) + stream = Device().create_stream(options=StreamOptions(priority=1)) -def test_stream_sync(): - stream = Stream._init(options=StreamOptions()) +def test_stream_sync(init_cuda): + stream = Device().create_stream(options=StreamOptions()) stream.sync() # Should not raise any exceptions -def test_stream_record(): - stream = Stream._init(options=StreamOptions()) +def test_stream_record(init_cuda): + stream = Device().create_stream(options=StreamOptions()) event = stream.record() assert isinstance(event, Event) -def test_stream_record_invalid_event(): - stream = Stream._init(options=StreamOptions()) +def test_stream_record_invalid_event(init_cuda): + stream = Device().create_stream(options=StreamOptions()) with pytest.raises(TypeError): stream.record(event="invalid_event") -def test_stream_wait_event(): - stream = Stream._init(options=StreamOptions()) +def test_stream_wait_event(init_cuda): + stream = Device().create_stream(options=StreamOptions()) event = Event._init() stream.record(event) stream.wait(event) # Should not raise any exceptions -def test_stream_wait_invalid_event(): - stream = Stream._init(options=StreamOptions()) +def test_stream_wait_invalid_event(init_cuda): + stream = Device().create_stream(options=StreamOptions()) with pytest.raises(ValueError): stream.wait(event_or_stream="invalid_event") -def test_stream_device(): - stream = Stream._init(options=StreamOptions()) +def test_stream_device(init_cuda): + stream = Device().create_stream(options=StreamOptions()) device = stream.device assert isinstance(device, Device) -def test_stream_context(): - stream = Stream._init(options=StreamOptions()) +def test_stream_context(init_cuda): + stream = Device().create_stream(options=StreamOptions()) context = stream.context assert context is not None