diff --git a/python/cuml/cluster/dbscan.pyx b/python/cuml/cluster/dbscan.pyx
index f05fa71e1b..fa92a77bb7 100644
--- a/python/cuml/cluster/dbscan.pyx
+++ b/python/cuml/cluster/dbscan.pyx
@@ -259,7 +259,8 @@ class DBSCAN(Base,

         cdef handle_t* handle_ = self.handle.getHandle()

-        self.labels_ = CumlArray.empty(n_rows, dtype=out_dtype)
+        self.labels_ = CumlArray.empty(n_rows, dtype=out_dtype,
+                                       index=X_m.index)
         cdef uintptr_t labels_ptr = self.labels_.ptr

         cdef uintptr_t core_sample_indices_ptr = NULL
diff --git a/python/cuml/cluster/hdbscan.pyx b/python/cuml/cluster/hdbscan.pyx
index e5c32d41f6..395e1f5e94 100644
--- a/python/cuml/cluster/hdbscan.pyx
+++ b/python/cuml/cluster/hdbscan.pyx
@@ -580,7 +580,7 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
         self.n_connected_components_ = 1
         self.n_leaves_ = n_rows

-        self.labels_ = CumlArray.empty(n_rows, dtype="int32")
+        self.labels_ = CumlArray.empty(n_rows, dtype="int32", index=X_m.index)
         self.children_ = CumlArray.empty((2, n_rows), dtype="int32")
         self.probabilities_ = CumlArray.empty(n_rows, dtype="float32")
         self.sizes_ = CumlArray.empty(n_rows, dtype="int32")
diff --git a/python/cuml/cluster/kmeans.pyx b/python/cuml/cluster/kmeans.pyx
index 52d77545af..1cacea929a 100644
--- a/python/cuml/cluster/kmeans.pyx
+++ b/python/cuml/cluster/kmeans.pyx
@@ -476,7 +476,8 @@ class KMeans(Base,

         cdef uintptr_t cluster_centers_ptr = self.cluster_centers_.ptr

-        self.labels_ = CumlArray.zeros(shape=n_rows, dtype=np.int32)
+        self.labels_ = CumlArray.zeros(shape=n_rows, dtype=np.int32,
+                                       index=X_m.index)
         cdef uintptr_t labels_ptr = self.labels_.ptr

         # Sum of squared distances of samples to their closest cluster center.
diff --git a/python/cuml/common/array.py b/python/cuml/common/array.py
index 6300f61aa5..3fce1ce962 100644
--- a/python/cuml/common/array.py
+++ b/python/cuml/common/array.py
@@ -98,7 +98,12 @@ class CumlArray(Buffer):

     @nvtx.annotate(message="common.CumlArray.__init__", category="utils",
                    domain="cuml_python")
-    def __init__(self, data=None, owner=None, dtype=None, shape=None,
+    def __init__(self,
+                 data=None,
+                 index=None,
+                 owner=None,
+                 dtype=None,
+                 shape=None,
                  order=None):

         # Checks of parameters
@@ -148,6 +153,7 @@ def __init__(self, data=None, owner=None, dtype=None, shape=None,
         else:
             flattened_data = data

+        self._index = index
         super().__init__(data=flattened_data, owner=owner, size=size)

@@ -179,6 +185,16 @@ def __init__(self, data=None, owner=None, dtype=None, shape=None,
                 self.strides = ary_interface['strides']
                 self.order = _strides_to_order(self.strides, self.dtype)

+    # We use the index as a property to allow for validation/processing
+    # in the future if needed
+    @property
+    def index(self):
+        return self._index
+
+    @index.setter
+    def index(self, index):
+        self._index = index
+
     @with_cupy_rmm
     def __getitem__(self, slice):
         return CumlArray(data=cp.asarray(self).__getitem__(slice))
@@ -267,7 +283,8 @@ def to_output(self, output_type='cupy', output_dtype=None):
                 mat = cp.asarray(self, dtype=output_dtype)
                 if len(mat.shape) == 1:
                     mat = mat.reshape(mat.shape[0], 1)
-                return DataFrame(mat)
+                return DataFrame(mat,
+                                 index=self.index)
             else:
                 raise ValueError('cuDF unsupported Array dtype')
@@ -277,7 +294,9 @@ def to_output(self, output_type='cupy', output_dtype=None):
             if len(self.shape) == 1:
                 if self.dtype not in [np.uint8, np.uint16, np.uint32,
                                       np.uint64, np.float16]:
-                    return Series(self, dtype=output_dtype)
+                    return Series(self,
+                                  dtype=output_dtype,
+                                  index=self.index)
                 else:
                     raise ValueError('cuDF unsupported Array dtype')
             elif self.shape[1] > 1:
@@ -307,7 +326,11 @@ def serialize(self):
     @classmethod
     @nvtx.annotate(message="common.CumlArray.empty", category="utils",
                    domain="cuml_python")
-    def empty(cls, shape, dtype, order='F'):
+    def empty(cls,
+              shape,
+              dtype,
+              order='F',
+              index=None):
         """
         Create an empty Array with an allocated but uninitialized
         DeviceBuffer
@@ -321,12 +344,17 @@ def __init__(self, data=None, owner=None, dtype=None, shape=None,
            Whether to create a F-major or C-major array.
        """
-        return CumlArray(cp.empty(shape, dtype, order))
+        return CumlArray(cp.empty(shape, dtype, order), index=index)

     @classmethod
     @nvtx.annotate(message="common.CumlArray.full", category="utils",
                    domain="cuml_python")
-    def full(cls, shape, value, dtype, order='F'):
+    def full(cls,
+             shape,
+             value,
+             dtype,
+             order='F',
+             index=None):
         """
         Create an Array with an allocated DeviceBuffer initialized to value.
@@ -340,12 +368,16 @@ def __init__(self, data=None, owner=None, dtype=None, shape=None,
            Whether to create a F-major or C-major array.
        """
-        return CumlArray(cp.full(shape, value, dtype, order))
+        return CumlArray(cp.full(shape, value, dtype, order), index=index)

     @classmethod
     @nvtx.annotate(message="common.CumlArray.zeros", category="utils",
                    domain="cuml_python")
-    def zeros(cls, shape, dtype='float32', order='F'):
+    def zeros(cls,
+              shape,
+              dtype='float32',
+              order='F',
+              index=None):
         """
         Create an Array with an allocated DeviceBuffer initialized to zeros.
@@ -358,12 +390,17 @@ def __init__(self, data=None, owner=None, dtype=None, shape=None,
        order: string, optional
            Whether to create a F-major or C-major array.
        """
-        return CumlArray.full(value=0, shape=shape, dtype=dtype, order=order)
+        return CumlArray.full(value=0, shape=shape, dtype=dtype, order=order,
+                              index=index)

     @classmethod
     @nvtx.annotate(message="common.CumlArray.ones", category="utils",
                    domain="cuml_python")
-    def ones(cls, shape, dtype='float32', order='F'):
+    def ones(cls,
+             shape,
+             dtype='float32',
+             order='F',
+             index=None):
         """
         Create an Array with an allocated DeviceBuffer initialized to zeros.
@@ -376,7 +413,8 @@ def ones(cls, shape, dtype='float32', order='F'):
        order: string, optional
            Whether to create a F-major or C-major array.
""" - return CumlArray.full(value=1, shape=shape, dtype=dtype, order=order) + return CumlArray.full(value=1, shape=shape, dtype=dtype, order=order, + index=index) def _check_low_level_type(data): diff --git a/python/cuml/common/array_sparse.py b/python/cuml/common/array_sparse.py index 2f24a2776f..c452c6dcd9 100644 --- a/python/cuml/common/array_sparse.py +++ b/python/cuml/common/array_sparse.py @@ -119,6 +119,7 @@ def __init__(self, data=None, self.shape = data.shape self.dtype = self.data.dtype self.nnz = data.nnz + self.index = None @nvtx.annotate(message="common.SparseCumlArray.to_output", category="utils", domain="cuml_python") diff --git a/python/cuml/common/input_utils.py b/python/cuml/common/input_utils.py index a8b2f26603..4043df4c09 100644 --- a/python/cuml/common/input_utils.py +++ b/python/cuml/common/input_utils.py @@ -308,6 +308,8 @@ def check_order(arr_order): safe_dtype=safe_dtype_conversion) check_dtype = False + index = getattr(X, 'index', None) + # format conversion if (isinstance(X, cudf.Series)): @@ -322,9 +324,11 @@ def check_order(arr_order): if isinstance(X, cudf.DataFrame): if order == 'K': - X_m = CumlArray(data=X.as_gpu_matrix(order='F')) + X_m = CumlArray(data=X.as_gpu_matrix(order='F'), + index=index) else: - X_m = CumlArray(data=X.as_gpu_matrix(order=order)) + X_m = CumlArray(data=X.as_gpu_matrix(order=order), + index=index) elif isinstance(X, CumlArray): X_m = X @@ -349,7 +353,6 @@ def check_order(arr_order): if not _check_array_contiguity(X): debug("Non contiguous array or view detected, a " "contiguous copy of the data will be done.") - # X = cp.array(X, order=order, copy=True) make_copy = True # If we have a host array, we copy it first before changing order @@ -359,7 +362,8 @@ def check_order(arr_order): cp_arr = cp.array(X, copy=make_copy, order=order) - X_m = CumlArray(data=cp_arr) + X_m = CumlArray(data=cp_arr, + index=index) if deepcopy: X_m = copy.deepcopy(X_m) @@ -404,9 +408,13 @@ def check_order(arr_order): if (check_order(X_m.order)): X_m = cp.array(X_m, copy=False, order=order) - X_m = CumlArray(data=X_m) + X_m = CumlArray(data=X_m, + index=index) - return cuml_array(array=X_m, n_rows=n_rows, n_cols=n_cols, dtype=X_m.dtype) + return cuml_array(array=X_m, + n_rows=n_rows, + n_cols=n_cols, + dtype=X_m.dtype) @nvtx.annotate(message="common.input_utils.input_to_cupy_array", diff --git a/python/cuml/decomposition/pca.pyx b/python/cuml/decomposition/pca.pyx index 9e41827f73..4a599e29bf 100644 --- a/python/cuml/decomposition/pca.pyx +++ b/python/cuml/decomposition/pca.pyx @@ -693,7 +693,7 @@ class PCA(Base, t_input_data = \ CumlArray.zeros((params.n_rows, params.n_components), - dtype=dtype.type) + dtype=dtype.type, index=X_m.index) cdef uintptr_t _trans_input_ptr = t_input_data.ptr cdef uintptr_t components_ptr = self.components_.ptr diff --git a/python/cuml/decomposition/tsvd.pyx b/python/cuml/decomposition/tsvd.pyx index a43a5f9c29..2e79c61c9a 100644 --- a/python/cuml/decomposition/tsvd.pyx +++ b/python/cuml/decomposition/tsvd.pyx @@ -346,7 +346,7 @@ class TruncatedSVD(Base, self.singular_values_.ptr _trans_input_ = CumlArray.zeros((params.n_rows, params.n_components), - dtype=self.dtype) + dtype=self.dtype, index=X_m.index) cdef uintptr_t t_input_ptr = _trans_input_.ptr if self.n_components> self.n_cols: @@ -389,7 +389,7 @@ class TruncatedSVD(Base, """ - trans_input, n_rows, _, dtype = \ + X_m, n_rows, _, dtype = \ input_to_cuml_array(X, check_dtype=self.dtype, convert_to_dtype=(self.dtype if convert_dtype else None)) @@ -400,9 +400,9 @@ class 
TruncatedSVD(Base, params.n_cols = self.n_cols input_data = CumlArray.zeros((params.n_rows, params.n_cols), - dtype=self.dtype) + dtype=self.dtype, index=X_m.index) - cdef uintptr_t trans_input_ptr = trans_input.ptr + cdef uintptr_t trans_input_ptr = X_m.ptr cdef uintptr_t input_ptr = input_data.ptr cdef uintptr_t components_ptr = self.components_.ptr @@ -436,7 +436,7 @@ class TruncatedSVD(Base, Perform dimensionality reduction on X. """ - input, n_rows, _, dtype = \ + X_m, n_rows, _, dtype = \ input_to_cuml_array(X, check_dtype=self.dtype, convert_to_dtype=(self.dtype if convert_dtype else None), @@ -449,9 +449,9 @@ class TruncatedSVD(Base, t_input_data = \ CumlArray.zeros((params.n_rows, params.n_components), - dtype=self.dtype) + dtype=self.dtype, index=X_m.index) - cdef uintptr_t input_ptr = input.ptr + cdef uintptr_t input_ptr = X_m.ptr cdef uintptr_t trans_input_ptr = t_input_data.ptr cdef uintptr_t components_ptr = self.components_.ptr diff --git a/python/cuml/experimental/linear_model/lars.pyx b/python/cuml/experimental/linear_model/lars.pyx index 848350c741..273dc866c3 100644 --- a/python/cuml/experimental/linear_model/lars.pyx +++ b/python/cuml/experimental/linear_model/lars.pyx @@ -365,7 +365,7 @@ class Lars(Base, RegressorMixin): cdef uintptr_t active_idx_ptr = \ input_to_cuml_array(self.active_).array.ptr - preds = CumlArray.zeros(n_rows, dtype=self.dtype) + preds = CumlArray.zeros(n_rows, dtype=self.dtype, index=X_m.index) if self.dtype == np.float32: larsPredict(handle_[0], X_ptr, n_rows, diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx index 7e977d195f..757a2dd9a0 100644 --- a/python/cuml/fil/fil.pyx +++ b/python/cuml/fil/fil.pyx @@ -329,11 +329,13 @@ cdef class ForestInference_impl(): shape += (2,) else: shape += (self.num_class,) - preds = CumlArray.empty(shape=shape, dtype=np.float32, order='C') - elif (not isinstance(preds, cudf.Series) and - not rmm.is_cuda_array(preds)): - raise ValueError("Invalid type for output preds," - " need GPU array") + preds = CumlArray.empty(shape=shape, dtype=np.float32, order='C', + index=X_m.index) + else: + if not hasattr(preds, "__cuda_array_interface__"): + raise ValueError("Invalid type for output preds," + " need GPU array") + preds.index = X_m.index cdef uintptr_t preds_ptr preds_ptr = preds.ptr diff --git a/python/cuml/linear_model/base.pyx b/python/cuml/linear_model/base.pyx index fe33fbe65f..3a6a797ac2 100644 --- a/python/cuml/linear_model/base.pyx +++ b/python/cuml/linear_model/base.pyx @@ -67,17 +67,15 @@ class LinearPredictMixin: Predicts `y` values for `X`. 
""" - cdef uintptr_t X_ptr X_m, n_rows, n_cols, dtype = \ input_to_cuml_array(X, check_dtype=self.dtype, convert_to_dtype=(self.dtype if convert_dtype else None), check_cols=self.n_cols) - X_ptr = X_m.ptr - + cdef uintptr_t X_ptr = X_m.ptr cdef uintptr_t coef_ptr = self.coef_.ptr - preds = CumlArray.zeros(n_rows, dtype=dtype) + preds = CumlArray.zeros(n_rows, dtype=dtype, index=X_m.index) cdef uintptr_t preds_ptr = preds.ptr cdef handle_t* handle_ = self.handle.getHandle() diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx index 516ea842f6..d4de67435e 100644 --- a/python/cuml/manifold/t_sne.pyx +++ b/python/cuml/manifold/t_sne.pyx @@ -422,7 +422,6 @@ class TSNE(Base, convert_format=False) n, p = self.X_m.shape self.sparse_fit = True - # Handle dense inputs else: self.X_m, n, p, _ = \ @@ -451,7 +450,8 @@ class TSNE(Base, self.embedding_ = CumlArray.zeros( (n, self.n_components), order="F", - dtype=np.float32) + dtype=np.float32, + index=self.X_m.index) cdef uintptr_t embed_ptr = self.embedding_.ptr diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 96a857fb2c..6a6dc4c08d 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -560,7 +560,8 @@ class UMAP(Base, self.embedding_ = CumlArray.zeros((self.n_rows, self.n_components), - order="C", dtype=np.float32) + order="C", dtype=np.float32, + index=self.X_m.index) if self.hash_input: with using_output_type("numpy"): @@ -720,12 +721,14 @@ class UMAP(Base, if is_sparse(X): X_m = SparseCumlArray(X, convert_to_dtype=cupy.float32, convert_format=False) + index = None else: X_m, n_rows, n_cols, dtype = \ input_to_cuml_array(X, order='C', check_dtype=np.float32, convert_to_dtype=(np.float32 if convert_dtype else None)) + index = X_m.index n_rows = X_m.shape[0] n_cols = X_m.shape[1] @@ -745,7 +748,8 @@ class UMAP(Base, embedding = CumlArray.zeros((X_m.shape[0], self.n_components), - order="C", dtype=np.float32) + order="C", dtype=np.float32, + index=index) cdef uintptr_t xformed_ptr = embedding.ptr (knn_indices_m, knn_indices_ctype), (knn_dists_m, knn_dists_ctype) =\ diff --git a/python/cuml/naive_bayes/naive_bayes.py b/python/cuml/naive_bayes/naive_bayes.py index ce33c828df..20fc22fc34 100644 --- a/python/cuml/naive_bayes/naive_bayes.py +++ b/python/cuml/naive_bayes/naive_bayes.py @@ -182,16 +182,21 @@ def predict(self, X) -> CumlArray: # https://github.com/rapidsai/cuml/issues/2216 if scipy_sparse_isspmatrix(X) or cupyx.scipy.sparse.isspmatrix(X): X = _convert_x_sparse(X) + index = None else: - X = input_to_cupy_array(X, order='K', + X = input_to_cuml_array(X, order='K', check_dtype=[cp.float32, cp.float64, - cp.int32]).array + cp.int32]) + index = X.index + # todo: improve index management for cupy based codebases + X = X.array.to_output('cupy') X = self._check_X(X) jll = self._joint_log_likelihood(X) indices = cp.argmax(jll, axis=1).astype(self.classes_.dtype) y_hat = invert_labels(indices, classes=self.classes_) + y_hat = CumlArray(data=y_hat, index=index) return y_hat @generate_docstring( @@ -220,11 +225,15 @@ def predict_log_proba(self, X) -> CumlArray: # https://github.com/rapidsai/cuml/issues/2216 if scipy_sparse_isspmatrix(X) or cupyx.scipy.sparse.isspmatrix(X): X = _convert_x_sparse(X) + index = None else: - X = input_to_cupy_array(X, order='K', + X = input_to_cuml_array(X, order='K', check_dtype=[cp.float32, cp.float64, - cp.int32]).array + cp.int32]) + index = X.index + # todo: improve index management for cupy based codebases + X = X.array.to_output('cupy') 

         X = self._check_X(X)
         jll = self._joint_log_likelihood(X)
@@ -246,6 +255,7 @@ def predict_log_proba(self, X) -> CumlArray:
         if log_prob_x.ndim < 2:
             log_prob_x = log_prob_x.reshape((1, log_prob_x.shape[0]))
         result = jll - log_prob_x.T
+        result = CumlArray(data=result, index=index)
         return result

     @generate_docstring(
diff --git a/python/cuml/neighbors/kneighbors_classifier.pyx b/python/cuml/neighbors/kneighbors_classifier.pyx
index 6a8bf50d09..47ac087ea0 100644
--- a/python/cuml/neighbors/kneighbors_classifier.pyx
+++ b/python/cuml/neighbors/kneighbors_classifier.pyx
@@ -208,7 +208,8 @@ class KNeighborsClassifier(NearestNeighbors,

         out_shape = (n_rows, out_cols) if out_cols > 1 else n_rows

-        classes = CumlArray.zeros(out_shape, dtype=np.int32, order="C")
+        classes = CumlArray.zeros(out_shape, dtype=np.int32, order="C",
+                                  index=knn_indices.index)

         cdef vector[int*] *y_vec = new vector[int*]()
@@ -277,7 +278,8 @@ class KNeighborsClassifier(NearestNeighbors,
             classes = CumlArray.zeros((n_rows,
                                        len(cp.unique(cp.asarray(col)))),
                                       dtype=np.float32,
-                                      order="C")
+                                      order="C",
+                                      index=knn_indices.index)
             out_classes.append(classes)
             classes_ptr = classes.ptr
             out_vec.push_back(classes_ptr)
diff --git a/python/cuml/neighbors/kneighbors_regressor.pyx b/python/cuml/neighbors/kneighbors_regressor.pyx
index da8e0ec70c..b33a9ea79f 100644
--- a/python/cuml/neighbors/kneighbors_regressor.pyx
+++ b/python/cuml/neighbors/kneighbors_regressor.pyx
@@ -205,7 +205,8 @@ class KNeighborsRegressor(NearestNeighbors,
         res_cols = 1 if len(self.y.shape) == 1 else self.y.shape[1]
         res_shape = n_rows if res_cols == 1 else (n_rows, res_cols)
         results = CumlArray.zeros(res_shape, dtype=np.float32,
-                                  order="C")
+                                  order="C",
+                                  index=knn_indices.index)

         cdef uintptr_t results_ptr = results.ptr
         cdef uintptr_t y_ptr
diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx
index d63ff3ba00..eee5902941 100644
--- a/python/cuml/neighbors/nearest_neighbors.pyx
+++ b/python/cuml/neighbors/nearest_neighbors.pyx
@@ -391,6 +391,7 @@ class NearestNeighbors(Base,
                                     convert_to_dtype=(np.float32
                                                       if convert_dtype
                                                       else None))
+        self._output_index = self.X_m.index

         if self.metric not in \
                 valid_metrics[self.working_algorithm_]:
@@ -654,6 +655,7 @@ class NearestNeighbors(Base,
         # expanded metrics. This code correct numerical instabilities
         # that could arise.
         if metric_is_l2_based:
+            index = I_ndarr.index
             X = input_to_cupy_array(X).array
             I_cparr = I_ndarr.to_output('cupy')
@@ -670,7 +672,9 @@ class NearestNeighbors(Base,
             I_cparr = cp.take_along_axis(I_cparr, correct_order, axis=1)

             D_ndarr = cuml.common.input_to_cuml_array(D_cparr).array
+            D_ndarr.index = index
             I_ndarr = cuml.common.input_to_cuml_array(I_cparr).array
+            I_ndarr.index = index

         I_ndarr = I_ndarr.to_output(out_type)
         D_ndarr = D_ndarr.to_output(out_type)
@@ -702,9 +706,11 @@ class NearestNeighbors(Base,

         # Need to establish result matrices for indices (Nxk)
         # and for distances (Nxk)
-        I_ndarr = CumlArray.zeros((N, n_neighbors), dtype=np.int64, order="C")
+        I_ndarr = CumlArray.zeros((N, n_neighbors), dtype=np.int64, order="C",
+                                  index=X_m.index)
         D_ndarr = CumlArray.zeros((N, n_neighbors),
-                                  dtype=np.float32, order="C")
+                                  dtype=np.float32, order="C",
+                                  index=X_m.index)

         cdef uintptr_t I_ptr = I_ndarr.ptr
         cdef uintptr_t D_ptr = D_ndarr.ptr
diff --git a/python/cuml/random_projection/random_projection.pyx b/python/cuml/random_projection/random_projection.pyx
index 0ab6d9a99a..ec6bbadecb 100644
--- a/python/cuml/random_projection/random_projection.pyx
+++ b/python/cuml/random_projection/random_projection.pyx
@@ -291,7 +291,8 @@ cdef class BaseRandomProjection():

         X_new = CumlArray.empty((n_samples, self.params.n_components),
                                 dtype=self.dtype,
-                                order='F')
+                                order='F',
+                                index=X_m.index)

         cdef uintptr_t output_ptr = X_new.ptr
diff --git a/python/cuml/solvers/cd.pyx b/python/cuml/solvers/cd.pyx
index bf206aef8c..86d34dbd21 100644
--- a/python/cuml/solvers/cd.pyx
+++ b/python/cuml/solvers/cd.pyx
@@ -310,7 +310,8 @@ class CD(Base,
         cdef uintptr_t X_ptr = X_m.ptr
         cdef uintptr_t coef_ptr = self.coef_.ptr

-        preds = CumlArray.zeros(n_rows, dtype=self.dtype)
+        preds = CumlArray.zeros(n_rows, dtype=self.dtype,
+                                index=X_m.index)
         cdef uintptr_t preds_ptr = preds.ptr

         cdef handle_t* handle_ = self.handle.getHandle()
diff --git a/python/cuml/solvers/qn.pyx b/python/cuml/solvers/qn.pyx
index 769a1fd11a..50db45c933 100644
--- a/python/cuml/solvers/qn.pyx
+++ b/python/cuml/solvers/qn.pyx
@@ -789,7 +789,8 @@ class QN(Base,
             order='K'
         )

-        preds = CumlArray.zeros(shape=n_rows, dtype=self.dtype)
+        preds = CumlArray.zeros(shape=n_rows, dtype=self.dtype,
+                                index=X_m.index)

         cdef uintptr_t coef_ptr = self._coef_.ptr
         cdef uintptr_t pred_ptr = preds.ptr
diff --git a/python/cuml/solvers/sgd.pyx b/python/cuml/solvers/sgd.pyx
index 3738ab7240..aeea5f5b72 100644
--- a/python/cuml/solvers/sgd.pyx
+++ b/python/cuml/solvers/sgd.pyx
@@ -417,7 +417,7 @@ class SGD(Base,
         cdef uintptr_t X_ptr = X_m.ptr
         cdef uintptr_t coef_ptr = self.coef_.ptr

-        preds = CumlArray.zeros(n_rows, dtype=self.dtype)
+        preds = CumlArray.zeros(n_rows, dtype=self.dtype, index=X_m.index)
         cdef uintptr_t preds_ptr = preds.ptr

         cdef handle_t* handle_ = self.handle.getHandle()
@@ -465,7 +465,7 @@ class SGD(Base,
         cdef uintptr_t X_ptr = X_m.ptr
         cdef uintptr_t coef_ptr = self.coef_.ptr

-        preds = CumlArray.zeros(n_rows, dtype=dtype)
+        preds = CumlArray.zeros(n_rows, dtype=dtype, index=X_m.index)
         cdef uintptr_t preds_ptr = preds.ptr

         cdef handle_t* handle_ = self.handle.getHandle()
diff --git a/python/cuml/svm/svm_base.pyx b/python/cuml/svm/svm_base.pyx
index 22c1e461e2..d0b8f9a0b5 100644
--- a/python/cuml/svm/svm_base.pyx
+++ b/python/cuml/svm/svm_base.pyx
@@ -554,7 +554,7 @@ class SVMBase(Base,

         cdef uintptr_t X_ptr = X_m.ptr

-        preds = CumlArray.zeros(n_rows, dtype=self.dtype)
+        preds = CumlArray.zeros(n_rows, dtype=self.dtype, index=X_m.index)
         cdef uintptr_t preds_ptr = preds.ptr
         cdef handle_t* handle_ = self.handle.getHandle()
         cdef SvmModel[float]* model_f
diff --git a/python/cuml/test/test_input_utils.py b/python/cuml/test/test_input_utils.py
index bb66fa299d..1b4192c0b3 100644
--- a/python/cuml/test/test_input_utils.py
+++ b/python/cuml/test/test_input_utils.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2019-2020, NVIDIA CORPORATION.
+# Copyright (c) 2019-2021, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,7 +19,6 @@
 import cudf
 import cupy as cp
 import numpy as np
-from pandas import DataFrame as pdDF

 from cuml.common import input_to_cuml_array, CumlArray
 from cuml.common import input_to_host_array
@@ -28,6 +27,8 @@
 from cuml.common.input_utils import convert_dtype
 from cuml.common.memory_utils import _check_array_contiguity
 from numba import cuda as nbcuda
+from pandas import DataFrame as pdDF
+from pandas import Series as pdSeries


 ###############################################################################
@@ -288,6 +289,33 @@ def test_non_contiguous_to_contiguous_input(dtype, input_type, order,
     np.testing.assert_equal(real_data, cumlary.to_output('numpy'))


+@pytest.mark.parametrize('input_type', ['cudf', 'pandas'])
+@pytest.mark.parametrize('num_rows', test_num_rows)
+@pytest.mark.parametrize('num_cols', test_num_cols)
+@pytest.mark.parametrize('order', ['C', 'F'])
+def test_indexed_inputs(input_type, num_rows, num_cols, order):
+    if num_cols == 1:
+        input_type += '-series'
+
+    index = np.arange(num_rows, 2 * num_rows)
+
+    input_data, real_data = get_input(input_type, num_rows, num_cols,
+                                      np.float32, index=index)
+
+    X, n_rows, n_cols, res_dtype = input_to_cuml_array(input_data,
+                                                       order=order)
+
+    # testing the index in the cuml array
+    np.testing.assert_equal(X.index.to_numpy(), index)
+
+    # testing the index in the converted outputs
+    cudf_output = X.to_output('cudf')
+    np.testing.assert_equal(cudf_output.index.to_numpy(), index)
+
+    pandas_output = X.to_output('pandas')
+    np.testing.assert_equal(pandas_output.index.to_numpy(), index)
+
+
 ###############################################################################
 #                           Utility Functions                                 #
 ###############################################################################
@@ -318,7 +346,8 @@ def get_ptr(x):

     assert get_ptr(a) == get_ptr(b)


-def get_input(type, nrows, ncols, dtype, order='C', out_dtype=False):
+def get_input(type, nrows, ncols, dtype, order='C', out_dtype=False,
+              index=None):
     rand_mat = (cp.random.rand(nrows, ncols) * 10)
     rand_mat = cp.array(rand_mat, dtype=dtype, order=order)
@@ -332,10 +361,16 @@ def get_input(type, nrows, ncols, dtype, order='C', out_dtype=False):
         result = nbcuda.as_cuda_array(rand_mat)

     if type == 'cudf':
-        result = cudf.DataFrame(rand_mat)
+        result = cudf.DataFrame(rand_mat, index=index)
+
+    if type == 'cudf-series':
+        result = cudf.Series(rand_mat, index=index)

     if type == 'pandas':
-        result = pdDF(cp.asnumpy(rand_mat))
+        result = pdDF(cp.asnumpy(rand_mat), index=index)
+
+    if type == 'pandas-series':
+        result = pdSeries(cp.asnumpy(rand_mat).reshape(nrows,), index=index)

     if type == 'cuml':
         result = CumlArray(data=rand_mat)
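
Note (illustrative, not part of the patch): with the index captured in input_to_cuml_array and carried on CumlArray, estimator outputs converted back to cuDF reuse the row index of the input instead of a fresh RangeIndex. A minimal sketch of the resulting behavior, assuming a GPU environment with cuml/cudf installed; LinearRegression is used here only because its predict path goes through the LinearPredictMixin change above, and the data values and index are made up:

    import cudf
    import numpy as np
    from cuml.linear_model import LinearRegression

    # Input carries a custom (non-default) row index.
    X = cudf.DataFrame({"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0]},
                       index=[10, 11, 12])
    y = cudf.Series([2.0, 4.0, 6.0], index=[10, 11, 12])

    model = LinearRegression(output_type="cudf").fit(X, y)
    preds = model.predict(X)

    # The returned cuDF Series keeps X's index, because the result CumlArray
    # was created with index=X_m.index and to_output() now forwards it.
    np.testing.assert_array_equal(preds.index.to_numpy(), X.index.to_numpy())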