From 8629f514ea6f2d86a834203c9d9130e8bb46a2a1 Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 20 Jul 2021 20:19:44 +0800 Subject: [PATCH] Fix test. --- tests/python/test_data_iterator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/test_data_iterator.py b/tests/python/test_data_iterator.py index e2f247c0fc73..f6ce46c21f62 100644 --- a/tests/python/test_data_iterator.py +++ b/tests/python/test_data_iterator.py @@ -61,7 +61,9 @@ def run_data_iterator( ) -> None: n_rounds = 2 - it = IteratorForTest(*make_batches(n_samples_per_batch, n_features, n_batches)) + it = IteratorForTest( + *make_batches(n_samples_per_batch, n_features, n_batches, cupy) + ) if n_batches == 0: with pytest.raises(ValueError, match="1 batch"): Xy = xgb.DMatrix(it)