Skip to content

Commit

Permalink
Reduce dependencies on numba. (#1761)
Browse files Browse the repository at this point in the history
This PR makes `numba` an optional dependency of RMM.

We are keeping `numba` as a hard dependency in tests, though I explored what it would look like as a soft dependency in e2ff7f1. It turns out that the current RMM test suite relies on `numba` for about 90% of the tests, as a way to copy data from host to device and back (to verify that the allocations are valid and usable).

Closes #1760.

Authors:
  - Bradley Dice (https://github.com/bdice)

Approvers:
  - Matthew Murray (https://github.com/Matt711)
  - Mark Harris (https://github.com/harrism)
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #1761
  • Loading branch information
bdice authored Dec 20, 2024
1 parent 3bf6026 commit ba35f8e
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
4 changes: 2 additions & 2 deletions dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,6 @@ dependencies:
common:
- output_types: [conda, requirements, pyproject]
packages:
- numba>=0.57
- numpy>=1.23,<3.0a0
specific:
- output_types: [conda, requirements, pyproject]
Expand All @@ -295,6 +294,7 @@ dependencies:
common:
- output_types: [conda, requirements, pyproject]
packages:
- numba>=0.57
- pytest
- pytest-cov
specific:
Expand All @@ -309,7 +309,7 @@ dependencies:
- cuda-nvcc
- matrix:
packages:
- output_types: [conda, requirements]
- output_types: [conda, requirements, pyproject]
# Define additional constraints for testing with oldest dependencies.
matrices:
- matrix:
Expand Down
2 changes: 1 addition & 1 deletion python/rmm/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ license = { text = "Apache 2.0" }
requires-python = ">=3.10"
dependencies = [
"cuda-python>=11.8.5,<12.0a0",
"numba>=0.57",
"numpy>=1.23,<3.0a0",
] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`.
classifiers = [
Expand All @@ -47,6 +46,7 @@ classifiers = [

[project.optional-dependencies]
test = [
"numba>=0.57",
"pytest",
"pytest-cov",
] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`.
Expand Down
15 changes: 9 additions & 6 deletions python/rmm/rmm/_cuda/stream.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,15 @@ cdef class Stream:
return self.c_is_default()

def _init_from_numba_stream(self, obj):
from numba import cuda
if isinstance(obj, cuda.cudadrv.driver.Stream):
self._cuda_stream = <cudaStream_t><uintptr_t>(int(obj))
self._owner = obj
else:
raise TypeError(f"Cannot create stream from {type(obj)}")
try:
from numba import cuda
if isinstance(obj, cuda.cudadrv.driver.Stream):
self._cuda_stream = <cudaStream_t><uintptr_t>(int(obj))
self._owner = obj
return
except ImportError:
pass
raise TypeError(f"Cannot create stream from {type(obj)}")

def _init_from_cupy_stream(self, obj):
try:
Expand Down

0 comments on commit ba35f8e

Please sign in to comment.