From 1ebee48b90d54290573134df33a20332cf886fa3 Mon Sep 17 00:00:00 2001 From: fis Date: Mon, 25 May 2020 14:24:07 +0800 Subject: [PATCH] Fix IsDense. --- src/data/device_dmatrix.cu | 11 ++++++----- tests/cpp/data/test_device_dmatrix.cu | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/src/data/device_dmatrix.cu b/src/data/device_dmatrix.cu index 6173ec913995..d11d01b16f2a 100644 --- a/src/data/device_dmatrix.cu +++ b/src/data/device_dmatrix.cu @@ -33,16 +33,17 @@ DeviceDMatrix::DeviceDMatrix(AdapterT* adapter, float missing, int nthread, int size_t row_stride = GetRowCounts(batch, row_counts_span, adapter->DeviceIdx(), missing); - ellpack_page_.reset(new EllpackPage()); - *ellpack_page_->Impl() = - EllpackPageImpl(adapter, missing, this->IsDense(), nthread, max_bin, - row_counts_span, row_stride); - dh::XGBCachingDeviceAllocator alloc; info_.num_nonzero_ = thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(), row_counts.end()); info_.num_col_ = adapter->NumColumns(); info_.num_row_ = adapter->NumRows(); + + ellpack_page_.reset(new EllpackPage()); + *ellpack_page_->Impl() = + EllpackPageImpl(adapter, missing, this->IsDense(), nthread, max_bin, + row_counts_span, row_stride); + // Synchronise worker columns rabit::Allreduce(&info_.num_col_, 1); } diff --git a/tests/cpp/data/test_device_dmatrix.cu b/tests/cpp/data/test_device_dmatrix.cu index 1634cea3c46e..db29cc574342 100644 --- a/tests/cpp/data/test_device_dmatrix.cu +++ b/tests/cpp/data/test_device_dmatrix.cu @@ -129,3 +129,22 @@ TEST(DeviceDMatrix, Equivalent) { } } } + +TEST(DeviceDMatrix, IsDense) { + int num_bins = 16; + auto test = [num_bins] (float sparsity) { + HostDeviceVector data; + std::string interface_str = RandomDataGenerator{10, 10, sparsity} + .Device(0).GenerateArrayInterface(&data); + data::CupyAdapter x{interface_str}; + std::unique_ptr device_dmat{ new data::DeviceDMatrix( + &x, std::numeric_limits::quiet_NaN(), 1, num_bins) }; + if (sparsity == 0.0) { + ASSERT_TRUE(device_dmat->IsDense()) << sparsity; + } else { + ASSERT_FALSE(device_dmat->IsDense()); + } + }; + test(0.0); + test(0.1); +}