From 8af706267622e96adac6de9fa07bb6ec9ca47631 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Sat, 19 Oct 2024 02:11:26 +0800
Subject: [PATCH] Test ellpack categorical feature with missing values. (#10906)

---
 tests/cpp/data/test_ellpack_page.cu | 39 ++++++++++++++++++++++++++---
 1 file changed, 35 insertions(+), 4 deletions(-)

diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu
index 3a8d0b8d0b17..85d3008dc55b 100644
--- a/tests/cpp/data/test_ellpack_page.cu
+++ b/tests/cpp/data/test_ellpack_page.cu
@@ -5,10 +5,10 @@
 
 #include <utility>
 
-#include "../../../src/common/categorical.h"
+#include "../../../src/common/categorical.h"          // for AsCat
+#include "../../../src/common/compressed_iterator.h"  // for CompressedByteT
 #include "../../../src/common/hist_util.h"
-#include "../../../src/common/ref_resource_view.cuh"  // for MakeCudaGrowOnly
-#include "../../../src/data/device_adapter.cuh"       // for CupyAdapter
+#include "../../../src/data/device_adapter.cuh"  // for CupyAdapter
 #include "../../../src/data/ellpack_page.cuh"
 #include "../../../src/data/ellpack_page.h"
 #include "../../../src/data/gradient_index.h"  // for GHistIndexMatrix
@@ -98,7 +98,7 @@ TEST(EllpackPage, FromCategoricalBasic) {
   auto& h_ft = m->Info().feature_types.HostVector();
   h_ft.resize(kCols, FeatureType::kCategorical);
 
-  Context ctx{MakeCUDACtx(0)};
+  auto ctx = MakeCUDACtx(0);
   auto p = BatchParam{max_bins, tree::TrainParam::DftSparseThreshold()};
   auto ellpack = EllpackPage(&ctx, m.get(), p);
   auto accessor = ellpack.Impl()->GetDeviceAccessor(&ctx);
@@ -128,6 +128,37 @@ TEST(EllpackPage, FromCategoricalBasic) {
   }
 }
 
+TEST(EllpackPage, FromCategoricalMissing) {
+  auto ctx = MakeCUDACtx(0);
+
+  std::shared_ptr<common::HistogramCuts> cuts;
+  auto nan = std::numeric_limits<float>::quiet_NaN();
+  // 2 rows and 3 columns. The second column is nan, row_stride is 2.
+  std::vector<float> data{{0.1, nan, 1, 0.2, nan, 0}};
+  auto p_fmat = GetDMatrixFromData(data, 2, 3);
+  p_fmat->Info().feature_types.HostVector() = {FeatureType::kNumerical, FeatureType::kNumerical,
+                                               FeatureType::kCategorical};
+  p_fmat->Info().feature_types.SetDevice(ctx.Device());
+
+  auto p = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
+  for (auto const& page : p_fmat->GetBatches<GHistIndexMatrix>(&ctx, p)) {
+    cuts = std::make_shared<common::HistogramCuts>(page.Cuts());
+  }
+  cuts->SetDevice(ctx.Device());
+  for (auto const& page : p_fmat->GetBatches<EllpackPage>(&ctx, p)) {
+    std::vector<common::CompressedByteT> h_buffer;
+    auto h_acc = page.Impl()->GetHostAccessor(&ctx, &h_buffer,
+                                              p_fmat->Info().feature_types.ConstDeviceSpan());
+    ASSERT_EQ(h_acc.n_rows, 2);
+    ASSERT_EQ(cuts->NumFeatures(), 3);
+    ASSERT_EQ(h_acc.row_stride, 2);
+    ASSERT_EQ(h_acc.gidx_iter[0], 0);
+    ASSERT_EQ(h_acc.gidx_iter[1], 4);  // cat 1
+    ASSERT_EQ(h_acc.gidx_iter[2], 1);
+    ASSERT_EQ(h_acc.gidx_iter[3], 3);  // cat 0
+  }
+}
+
 struct ReadRowFunction {
   EllpackDeviceAccessor matrix;
   int row;