Skip to content

Commit

Permalink
move away from internal entry-points, and clean up the fixture use in…
Browse files Browse the repository at this point in the history
… the testsuite
  • Loading branch information
ksimpson-work committed Nov 4, 2024
1 parent 2ab5d3d commit 770f39f
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 49 deletions.
18 changes: 16 additions & 2 deletions cuda_core/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,24 @@
# is strictly prohibited.

from cuda.core.experimental._device import Device
from cuda.core.experimental import _device
from cuda import cuda
from cuda.core.experimental._utils import handle_return
import pytest

@pytest.fixture(scope="module")
@pytest.fixture(scope="module", autouse=True)
def ensure_no_context():
device = Device()
device.set_current()
with _device._tls_lock:
if hasattr(_device._tls, 'devices'):
del _device._tls.devices

@pytest.fixture(scope="function")
def init_cuda():
device = Device()
device.set_current()

yield
handle_return(cuda.cuCtxPopCurrent())
with _device._tls_lock:
del _device._tls.devices
6 changes: 1 addition & 5 deletions cuda_core/tests/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,7 @@ def test_device_alloc(init_cuda):
assert buffer.size == 1024
assert buffer.device_id == 0

def test_device_set_current():
device = Device()
device.set_current()

def test_device_create_stream():
def test_device_create_stream(init_cuda):
device = Device()
stream = device.create_stream()
assert stream is not None
Expand Down
45 changes: 29 additions & 16 deletions cuda_core/tests/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,47 @@
# this software and related documentation outside the terms of the EULA
# is strictly prohibited.

from cuda import cuda
from cuda.core.experimental._event import EventOptions, Event
from cuda.core.experimental._utils import handle_return
from cuda.core.experimental._device import Device
import pytest

def test_is_timing_disabled():
def test_is_timing_disabled(init_cuda):
options = EventOptions(enable_timing=False)
event = Event._init(options)
stream = Device().create_stream()
event = stream.record(options=options)
assert event.is_timing_disabled == True

options = EventOptions(enable_timing=True)
stream = Device().create_stream()
event = stream.record(options=options)
assert event.is_timing_disabled == False

def test_is_sync_busy_waited():
options = EventOptions(busy_waited_sync=True)
event = Event._init(options)
def test_is_sync_busy_waited(init_cuda):
options = EventOptions(enable_timing=False, busy_waited_sync=True)
stream = Device().create_stream()
event = stream.record(options=options)
assert event.is_sync_busy_waited == True

def test_sync():
options = EventOptions()
event = Event._init(options)
options = EventOptions(enable_timing=False)
stream = Device().create_stream()
event = stream.record(options=options)
assert event.is_sync_busy_waited == False

def test_sync(init_cuda):
options = EventOptions(enable_timing=False)
stream = Device().create_stream()
event = stream.record(options=options)
event.sync()
assert event.is_done == True

def test_is_done():
options = EventOptions()
event = Event._init(options)
def test_is_done(init_cuda):
options = EventOptions(enable_timing=False)
stream = Device().create_stream()
event = stream.record(options=options)
assert event.is_done == True

def test_handle():
options = EventOptions()
event = Event._init(options)
def test_handle(init_cuda):
options = EventOptions(enable_timing=False)
stream = Device().create_stream()
event = stream.record(options=options)
assert isinstance(event.handle, int)
4 changes: 2 additions & 2 deletions cuda_core/tests/test_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from cuda.core.experimental._utils import handle_return
import pytest

def test_launch_config_init():
def test_launch_config_init(init_cuda):
config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=None, shmem_size=0)
assert config.grid == (1, 1, 1)
assert config.block == (1, 1, 1)
Expand Down Expand Up @@ -50,7 +50,7 @@ def test_launch_config_invalid_values():
with pytest.raises(ValueError):
LaunchConfig(grid=(1, 1, 1), block=(0, 1))

def test_launch_config_stream():
def test_launch_config_stream(init_cuda):
stream = Device().create_stream()
config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=stream, shmem_size=0)
assert config.stream == stream
Expand Down
48 changes: 24 additions & 24 deletions cuda_core/tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,59 +15,59 @@ def test_stream_init():
with pytest.raises(NotImplementedError):
Stream()

def test_stream_init_with_options():
stream = Stream._init(options=StreamOptions(nonblocking=True, priority=0))
def test_stream_init_with_options(init_cuda):
stream = Device().create_stream(options=StreamOptions(nonblocking=True, priority=0))
assert stream.is_nonblocking is True
assert stream.priority == 0

def test_stream_handle():
stream = Stream._init(options=StreamOptions())
def test_stream_handle(init_cuda):
stream = Device().create_stream(options=StreamOptions())
assert isinstance(stream.handle, int)

def test_stream_is_nonblocking():
stream = Stream._init(options=StreamOptions(nonblocking=True))
def test_stream_is_nonblocking(init_cuda):
stream = Device().create_stream(options=StreamOptions(nonblocking=True))
assert stream.is_nonblocking is True

def test_stream_priority():
stream = Stream._init(options=StreamOptions(priority=0))
def test_stream_priority(init_cuda):
stream = Device().create_stream(options=StreamOptions(priority=0))
assert stream.priority == 0
stream = Stream._init(options=StreamOptions(priority=-1))
stream = Device().create_stream(options=StreamOptions(priority=-1))
assert stream.priority == -1
with pytest.raises(ValueError):
stream = Stream._init(options=StreamOptions(priority=1))
stream = Device().create_stream(options=StreamOptions(priority=1))

def test_stream_sync():
stream = Stream._init(options=StreamOptions())
def test_stream_sync(init_cuda):
stream = Device().create_stream(options=StreamOptions())
stream.sync() # Should not raise any exceptions

def test_stream_record():
stream = Stream._init(options=StreamOptions())
def test_stream_record(init_cuda):
stream = Device().create_stream(options=StreamOptions())
event = stream.record()
assert isinstance(event, Event)

def test_stream_record_invalid_event():
stream = Stream._init(options=StreamOptions())
def test_stream_record_invalid_event(init_cuda):
stream = Device().create_stream(options=StreamOptions())
with pytest.raises(TypeError):
stream.record(event="invalid_event")

def test_stream_wait_event():
stream = Stream._init(options=StreamOptions())
def test_stream_wait_event(init_cuda):
stream = Device().create_stream(options=StreamOptions())
event = Event._init()
stream.record(event)
stream.wait(event) # Should not raise any exceptions

def test_stream_wait_invalid_event():
stream = Stream._init(options=StreamOptions())
def test_stream_wait_invalid_event(init_cuda):
stream = Device().create_stream(options=StreamOptions())
with pytest.raises(ValueError):
stream.wait(event_or_stream="invalid_event")

def test_stream_device():
stream = Stream._init(options=StreamOptions())
def test_stream_device(init_cuda):
stream = Device().create_stream(options=StreamOptions())
device = stream.device
assert isinstance(device, Device)

def test_stream_context():
stream = Stream._init(options=StreamOptions())
def test_stream_context(init_cuda):
stream = Device().create_stream(options=StreamOptions())
context = stream.context
assert context is not None

Expand Down

0 comments on commit 770f39f

Please sign in to comment.