Skip to content

Commit

Permalink
Replace use of RMM provided CUDA bindings with CUDA Python (rapidsai#…
Browse files Browse the repository at this point in the history
…4499)

Authors:
  - Ashwin Srinath (https://github.com/shwina)
  - Divye Gala (https://github.com/divyegala)

Approvers:
  - Jordan Jacobelli (https://github.com/Ethyling)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: rapidsai#4499
  • Loading branch information
shwina authored Jan 21, 2022
1 parent ac4db43 commit dbbac47
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 11 deletions.
1 change: 1 addition & 0 deletions conda/environments/cuml_dev_cuda11.0.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
dependencies:
- cudatoolkit=11.0
- cuda-python >=11.5,<12.0
- rapids-build-env=22.02.*
- rapids-notebook-env=22.02.*
- rapids-doc-env=22.02.*
Expand Down
1 change: 1 addition & 0 deletions conda/environments/cuml_dev_cuda11.2.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
dependencies:
- cudatoolkit=11.2
- cuda-python >=11.5,<12.0
- rapids-build-env=22.02.*
- rapids-notebook-env=22.02.*
- rapids-doc-env=22.02.*
Expand Down
1 change: 1 addition & 0 deletions conda/environments/cuml_dev_cuda11.4.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
dependencies:
- cudatoolkit=11.4
- cuda-python >=11.5,<12.0
- rapids-build-env=22.02.*
- rapids-notebook-env=22.02.*
- rapids-doc-env=22.02.*
Expand Down
1 change: 1 addition & 0 deletions conda/environments/cuml_dev_cuda11.5.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
dependencies:
- cudatoolkit=11.5
- cuda-python >=11.5,<12.0
- rapids-build-env=22.02.*
- rapids-notebook-env=22.02.*
- rapids-doc-env=22.02.*
Expand Down
2 changes: 2 additions & 0 deletions conda/recipes/cuml/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ requirements:
- cudatoolkit {{ cuda_version }}.*
- ucx-py {{ ucx_py_version }}
- ucx-proc=*=gpu
- cuda-python >=11.5,<12.0
run:
- python x.x
- cudf {{ minor_version }}
Expand All @@ -52,6 +53,7 @@ requirements:
- distributed>=2021.11.1
- joblib >=0.11
- {{ pin_compatible('cudatoolkit', max_pin='x', min_pin='x') }}
- cuda-python >=11.5,<12.0

tests: # [linux64]
requirements: # [linux64]
Expand Down
17 changes: 10 additions & 7 deletions python/cuml/svm/linear.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,12 +33,15 @@ from cuml.raft.common.handle cimport handle_t
from cuml.common import input_to_cuml_array
from libc.stdint cimport uintptr_t
from libcpp cimport bool as cppbool
cimport rmm._lib.lib as rmm
from cuda.ccudart cimport(
cudaMemcpyAsync,
cudaMemcpyKind,
cudaMemcpyDeviceToDevice
)


__all__ = ['LinearSVM', 'LinearSVM_defaults']

cdef extern from * nogil:
ctypedef void* _Stream "cudaStream_t"

cdef extern from "cuml/svm/linear.hpp" namespace "ML::SVM":

Expand Down Expand Up @@ -236,12 +239,12 @@ cdef class LinearSVMWrapper:
raise AttributeError(
f"Expected an array of type {target.dtype}, "
f"but got {source.dtype}")
rmm.cudaMemcpyAsync(
cudaMemcpyAsync(
<void*><uintptr_t>target.ptr,
<void*><uintptr_t>source.ptr,
<size_t>(source.nbytes),
rmm.cudaMemcpyDeviceToDevice,
<_Stream> stream)
cudaMemcpyKind.cudaMemcpyDeviceToDevice,
stream.value())
if synchronize:
self.handle.sync_stream()

Expand Down
4 changes: 2 additions & 2 deletions python/cuml/test/test_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -40,7 +40,7 @@ def test_base_class_usage():


def test_base_class_usage_with_handle():
stream = cuml.cuda.Stream()
stream = cuml.raft.common.cuda.Stream()
handle = cuml.Handle(stream=stream)
base = cuml.Base(handle=handle)
base.handle.sync()
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/test/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -171,7 +171,7 @@ def sqnorm(x):
def get_handle(use_handle, n_streams=0):
if not use_handle:
return None, None
s = cuml.cuda.Stream()
s = cuml.raft.common.cuda.Stream()
h = cuml.Handle(stream=s, n_streams=n_streams)
return h, s

Expand Down

0 comments on commit dbbac47

Please sign in to comment.