Skip to content

Commit

Permalink
Optionally skip cupy on windows. (#10611)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Jul 20, 2024
1 parent 344ddeb commit 0846ad8
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 30 deletions.
41 changes: 11 additions & 30 deletions python-package/xgboost/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
get_cancer,
get_digits,
get_sparse,
make_batches,
memory,
)

Expand Down Expand Up @@ -161,7 +162,16 @@ def no_cudf() -> PytestSkip:


def no_cupy() -> PytestSkip:
return no_mod("cupy")
skip_cupy = no_mod("cupy")
if not skip_cupy["condition"] and system() == "Windows":
import cupy as cp

# Cupy might run into issue on Windows due to missing compiler
try:
cp.array([1, 2, 3]).sum()
except Exception: # pylint: disable=broad-except
skip_cupy["condition"] = True
return skip_cupy


def no_dask_cudf() -> PytestSkip:
Expand Down Expand Up @@ -248,35 +258,6 @@ def as_arrays(
return X, y, w


def make_batches( # pylint: disable=too-many-arguments,too-many-locals
n_samples_per_batch: int,
n_features: int,
n_batches: int,
use_cupy: bool = False,
*,
vary_size: bool = False,
random_state: int = 1994,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
X = []
y = []
w = []
if use_cupy:
import cupy

rng = cupy.random.RandomState(random_state)
else:
rng = np.random.RandomState(random_state)
for i in range(n_batches):
n_samples = n_samples_per_batch + i * 10 if vary_size else n_samples_per_batch
_X = rng.randn(n_samples, n_features)
_y = rng.randn(n_samples)
_w = rng.uniform(low=0, high=1, size=n_samples)
X.append(_X)
y.append(_y)
w.append(_w)
return X, y, w


def make_regression(
n_samples: int, n_features: int, use_cupy: bool
) -> Tuple[ArrayLike, ArrayLike, ArrayLike]:
Expand Down
31 changes: 31 additions & 0 deletions python-package/xgboost/testing/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Callable,
Dict,
Generator,
List,
NamedTuple,
Optional,
Tuple,
Expand Down Expand Up @@ -506,6 +507,36 @@ def get_mq2008(
)


def make_batches( # pylint: disable=too-many-arguments,too-many-locals
n_samples_per_batch: int,
n_features: int,
n_batches: int,
use_cupy: bool = False,
*,
vary_size: bool = False,
random_state: int = 1994,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
"""Make batches of dense data."""
X = []
y = []
w = []
if use_cupy:
import cupy # pylint: disable=import-error

rng = cupy.random.RandomState(random_state)
else:
rng = np.random.RandomState(random_state)
for i in range(n_batches):
n_samples = n_samples_per_batch + i * 10 if vary_size else n_samples_per_batch
_X = rng.randn(n_samples, n_features)
_y = rng.randn(n_samples)
_w = rng.uniform(low=0, high=1, size=n_samples)
X.append(_X)
y.append(_y)
w.append(_w)
return X, y, w


RelData = Tuple[sparse.csr_matrix, npt.NDArray[np.int32], npt.NDArray[np.int32]]


Expand Down

0 comments on commit 0846ad8

Please sign in to comment.