From d663839fded3becbb73ef3271e135ce86d79d83e Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 7 Mar 2023 04:35:49 +0800 Subject: [PATCH 1/2] Support F order for the tensor type. - Add F order support for tensor and view. - Use parameter pack for automatic type cast. (avoid excessive static cast for shape). --- include/xgboost/linalg.h | 157 ++++++++++++++++----- src/metric/elementwise_metric.cu | 5 +- src/objective/adaptive.cc | 12 +- src/objective/adaptive.cu | 4 +- src/objective/quantile_obj.cu | 14 +- src/tree/updater_quantile_hist.cc | 3 +- tests/cpp/common/test_linalg.cc | 73 +++++++--- tests/cpp/common/test_linalg.cu | 10 +- tests/cpp/objective/test_regression_obj.cc | 4 +- 9 files changed, 191 insertions(+), 91 deletions(-) diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h index 176002225157..a7a24a2e42b4 100644 --- a/include/xgboost/linalg.h +++ b/include/xgboost/linalg.h @@ -15,11 +15,11 @@ #include #include -#include // std::int32_t -#include // std::size_t +#include // for int32_t +#include // for size_t #include #include -#include +#include // for make_tuple #include #include #include @@ -37,8 +37,7 @@ #endif // defined (__CUDA__) || defined(__NVCC__) #endif // LINALG_HD -namespace xgboost { -namespace linalg { +namespace xgboost::linalg { namespace detail { struct ArrayInterfaceHandler { @@ -86,7 +85,7 @@ template struct RangeTag { I beg; I end; - constexpr size_t Size() const { return end - beg; } + [[nodiscard]] constexpr size_t Size() const { return end - beg; } }; /** @@ -158,14 +157,34 @@ inline LINALG_HD int Popc(uint64_t v) { #endif // compiler } +template +LINALG_HD void IndexToArr(std::size_t (&arr)[D], Head head) { + static_assert(std::is_integral>::value, "Invalid index type."); + arr[D - 1] = head; +} + +/** + * \brief Convert index from parameter pack to C-style array. + */ +template +LINALG_HD void IndexToArr(std::size_t (&arr)[D], Head head, Rest &&...index) { + static_assert(sizeof...(Rest) < D, "Index overflow."); + static_assert(std::is_integral>::value, "Invalid index type."); + arr[D - sizeof...(Rest) - 1] = head; + IndexToArr(arr, std::forward(index)...); +} + template -constexpr auto Arr2Tup(T (&arr)[N], std::index_sequence) { +constexpr auto ArrToTuple(T (&arr)[N], std::index_sequence) { return std::make_tuple(arr[Idx]...); } +/** + * \brief Convert C-styple array to std::tuple. + */ template -constexpr auto Arr2Tup(T (&arr)[N]) { - return Arr2Tup(arr, std::make_index_sequence{}); +constexpr auto ArrToTuple(T (&arr)[N]) { + return ArrToTuple(arr, std::make_index_sequence{}); } // uint division optimization inspired by the CIndexer in cupy. Division operation is @@ -188,7 +207,7 @@ LINALG_HD auto UnravelImpl(I idx, common::Span shape) { } } index[0] = idx; - return Arr2Tup(index); + return ArrToTuple(index); } template @@ -252,6 +271,11 @@ constexpr detail::RangeTag Range(I beg, I end) { return {beg, end}; } +enum Order : std::uint8_t { + kC, // Row major + kF, // Col major +}; + /** * \brief A tensor view with static type and dimension. It implements indexing and slicing. 
* @@ -377,7 +401,11 @@ class TensorView { * \param device Device ordinal */ template - LINALG_HD TensorView(common::Span data, I const (&shape)[D], int32_t device) + LINALG_HD TensorView(common::Span data, I const (&shape)[D], std::int32_t device) + : TensorView{data, shape, device, Order::kC} {} + + template + LINALG_HD TensorView(common::Span data, I const (&shape)[D], std::int32_t device, Order order) : data_{data}, ptr_{data_.data()}, device_{device} { static_assert(D > 0 && D <= kDim, "Invalid shape."); // shape @@ -386,7 +414,19 @@ class TensorView { shape_[i] = 1; } // stride - detail::CalcStride(shape_, stride_); + switch (order) { + case Order::kC: { + detail::CalcStride(shape_, stride_); + break; + } + case Order::kF: { + detail::CalcStride(shape_, stride_); + break; + } + default: { + SPAN_CHECK(false); + } + } // size this->CalcSize(); } @@ -490,17 +530,17 @@ class TensorView { /** * \brief Number of items in the tensor. */ - LINALG_HD size_t Size() const { return size_; } + LINALG_HD [[nodiscard]] std::size_t Size() const { return size_; } /** * \brief Whether this is a contiguous array, both C and F contiguous returns true. */ - LINALG_HD bool Contiguous() const { + LINALG_HD [[nodiscard]] bool Contiguous() const { return data_.size() == this->Size() || this->CContiguous() || this->FContiguous(); } /** * \brief Whether it's a c-contiguous array. */ - LINALG_HD bool CContiguous() const { + LINALG_HD [[nodiscard]] bool CContiguous() const { StrideT stride; static_assert(std::is_same::value); // It's contiguous if the stride can be calculated from shape. @@ -510,7 +550,7 @@ class TensorView { /** * \brief Whether it's a f-contiguous array. */ - LINALG_HD bool FContiguous() const { + LINALG_HD [[nodiscard]] bool FContiguous() const { StrideT stride; static_assert(std::is_same::value); // It's contiguous if the stride can be calculated from shape. @@ -530,16 +570,37 @@ class TensorView { /** * \brief Constructor for automatic type deduction. */ -template ::value> * = nullptr> -auto MakeTensorView(Container &data, I const (&shape)[D], int32_t device) { // NOLINT +auto MakeTensorView(Context const *ctx, Container &data, S &&...shape) { // NOLINT using T = typename Container::value_type; - return TensorView{data, shape, device}; + std::size_t in_shape[sizeof...(S)]; + detail::IndexToArr(in_shape, std::forward(shape)...); + return TensorView{data, in_shape, ctx->gpu_id}; +} + +template +LINALG_HD auto MakeTensorView(std::int32_t device, common::Span data, S &&...shape) { + std::size_t in_shape[sizeof...(S)]; + detail::IndexToArr(in_shape, std::forward(shape)...); + return TensorView{data, in_shape, device}; } -template -LINALG_HD auto MakeTensorView(common::Span data, I const (&shape)[D], int32_t device) { - return TensorView{data, shape, device}; +template +auto MakeTensorView(Context const *ctx, common::Span data, S &&...shape) { + return MakeTensorView(ctx->gpu_id, data, std::forward(shape)...); +} + +template +auto MakeTensorView(Context const *ctx, HostDeviceVector *data, S &&...shape) { + auto span = ctx->IsCPU() ? data->HostSpan() : data->DeviceSpan(); + return MakeTensorView(ctx->gpu_id, span, std::forward(shape)...); +} + +template +auto MakeTensorView(Context const *ctx, HostDeviceVector const *data, S &&...shape) { + auto span = ctx->IsCPU() ? 
data->ConstHostSpan() : data->ConstDeviceSpan(); + return MakeTensorView(ctx->gpu_id, span, std::forward(shape)...); } /** @@ -559,6 +620,13 @@ LINALG_HD auto UnravelIndex(size_t idx, std::size_t const (&shape)[D]) { return UnravelIndex(idx, common::Span(shape)); } +template +LINALG_HD auto UnravelIndex(std::size_t idx, S... shape) { + std::size_t s[sizeof...(S)]; + detail::IndexToArr(s, shape...); + return UnravelIndex(idx, common::Span(s)); +} + /** * \brief A view over a vector, specialization of Tensor * @@ -676,6 +744,7 @@ class Tensor { private: HostDeviceVector data_; ShapeT shape_{0}; + Order order_{Order::kC}; template void Initialize(I const (&shape)[D], std::int32_t device) { @@ -701,11 +770,12 @@ class Tensor { * See \ref TensorView for parameters of this constructor. */ template - explicit Tensor(I const (&shape)[D], int32_t device) - : Tensor{common::Span{shape}, device} {} + explicit Tensor(I const (&shape)[D], std::int32_t device, Order order = kC) + : Tensor{common::Span{shape}, device, order} {} template - explicit Tensor(common::Span shape, int32_t device) { + explicit Tensor(common::Span shape, std::int32_t device, Order order = kC) + : order_{order} { // No device unroll as this is a host only function. std::copy(shape.data(), shape.data() + D, shape_); for (auto i = D; i < kDim; ++i) { @@ -724,7 +794,8 @@ class Tensor { * Initialize from 2 host iterators. */ template - explicit Tensor(It begin, It end, I const (&shape)[D], int32_t device) { + explicit Tensor(It begin, It end, I const (&shape)[D], std::int32_t device, Order order = kC) + : order_{order} { auto &h_vec = data_.HostVector(); h_vec.insert(h_vec.begin(), begin, end); // shape @@ -732,8 +803,9 @@ class Tensor { } template - explicit Tensor(std::initializer_list data, I const (&shape)[D], - int32_t device = Context::kCpuId) { + explicit Tensor(std::initializer_list data, I const (&shape)[D], std::int32_t device, + Order order = kC) + : order_{order} { auto &h_vec = data_.HostVector(); h_vec = data; // shape @@ -763,27 +835,27 @@ class Tensor { if (device >= 0) { data_.SetDevice(device); auto span = data_.DeviceSpan(); - return {span, shape_, device}; + return {span, shape_, device, order_}; } else { auto span = data_.HostSpan(); - return {span, shape_, device}; + return {span, shape_, device, order_}; } } TensorView View(int32_t device) const { if (device >= 0) { data_.SetDevice(device); auto span = data_.ConstDeviceSpan(); - return {span, shape_, device}; + return {span, shape_, device, order_}; } else { auto span = data_.ConstHostSpan(); - return {span, shape_, device}; + return {span, shape_, device, order_}; } } auto HostView() const { return this->View(-1); } auto HostView() { return this->View(-1); } - size_t Size() const { return data_.Size(); } + [[nodiscard]] size_t Size() const { return data_.Size(); } auto Shape() const { return common::Span{shape_}; } auto Shape(size_t i) const { return shape_[i]; } @@ -837,12 +909,26 @@ class Tensor { void Reshape(size_t (&shape)[D]) { this->Reshape(common::Span{shape}); } + /** + * \brief Get a host view on the slice. + */ + template + auto Slice(S &&...slices) const { + return this->HostView().Slice(std::forward(slices)...); + } + /** + * \brief Get a host view on the slice. + */ + template + auto Slice(S &&...slices) { + return this->HostView().Slice(std::forward(slices)...); + } /** * \brief Set device ordinal for this tensor. 
*/ void SetDevice(int32_t device) const { data_.SetDevice(device); } - int32_t DeviceIdx() const { return data_.DeviceIdx(); } + [[nodiscard]] int32_t DeviceIdx() const { return data_.DeviceIdx(); } }; template @@ -900,8 +986,7 @@ void Stack(Tensor *l, Tensor const &r) { shape[0] = l->Shape(0) + r.Shape(0); }); } -} // namespace linalg -} // namespace xgboost +} // namespace xgboost::linalg #if defined(LINALG_HD) #undef LINALG_HD diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index e06be9730e8f..9006bdfca5eb 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -451,9 +451,8 @@ class QuantileError : public MetricNoCache { auto alpha = ctx->IsCPU() ? alpha_.ConstHostSpan() : alpha_.ConstDeviceSpan(); std::size_t n_targets = preds.Size() / info.num_row_ / alpha_.Size(); CHECK_NE(n_targets, 0); - auto y_predt = linalg::MakeTensorView( - ctx->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan(), - {static_cast(info.num_row_), alpha_.Size(), n_targets}, ctx->gpu_id); + auto y_predt = linalg::MakeTensorView(ctx, &preds, static_cast(info.num_row_), + alpha_.Size(), n_targets); info.weights_.SetDevice(ctx->gpu_id); common::OptionalWeights weight{ctx->IsCPU() ? info.weights_.ConstHostSpan() diff --git a/src/objective/adaptive.cc b/src/objective/adaptive.cc index f5f35c8461d4..2d89c94593c4 100644 --- a/src/objective/adaptive.cc +++ b/src/objective/adaptive.cc @@ -23,9 +23,7 @@ #include "xgboost/span.h" // Span #include "xgboost/tree_model.h" // RegTree -namespace xgboost { -namespace obj { -namespace detail { +namespace xgboost::obj::detail { void EncodeTreeLeafHost(Context const* ctx, RegTree const& tree, std::vector const& position, std::vector* p_nptr, std::vector* p_nidx, std::vector* p_ridx) { @@ -98,8 +96,8 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector const& posit auto const& h_node_idx = nidx; auto const& h_node_ptr = nptr; CHECK_LE(h_node_ptr.back(), info.num_row_); - auto h_predt = linalg::MakeTensorView(predt.ConstHostSpan(), - {info.num_row_, predt.Size() / info.num_row_}, ctx->gpu_id); + auto h_predt = linalg::MakeTensorView(ctx, predt.ConstHostSpan(), info.num_row_, + predt.Size() / info.num_row_); // loop over each leaf common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) { @@ -143,6 +141,4 @@ void UpdateTreeLeafDevice(Context const*, common::Span, std::i common::AssertGPUSupport(); } #endif // !defined(XGBOOST_USE_CUDA) -} // namespace detail -} // namespace obj -} // namespace xgboost +} // namespace xgboost::obj::detail diff --git a/src/objective/adaptive.cu b/src/objective/adaptive.cu index 76627cf6d319..662b0330beb7 100644 --- a/src/objective/adaptive.cu +++ b/src/objective/adaptive.cu @@ -157,8 +157,8 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span pos HostDeviceVector quantiles; predt.SetDevice(ctx->gpu_id); - auto d_predt = linalg::MakeTensorView(predt.ConstDeviceSpan(), - {info.num_row_, predt.Size() / info.num_row_}, ctx->gpu_id); + auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), info.num_row_, + predt.Size() / info.num_row_); CHECK_LT(group_idx, d_predt.Shape(1)); auto t_predt = d_predt.Slice(linalg::All(), group_idx); auto d_labels = info.labels.View(ctx->gpu_id).Slice(linalg::All(), IdxY(info, group_idx)); diff --git a/src/objective/quantile_obj.cu b/src/objective/quantile_obj.cu index 3b9204251cf4..0a40758bc86d 100644 --- a/src/objective/quantile_obj.cu +++ b/src/objective/quantile_obj.cu @@ -64,8 +64,7 @@ class QuantileRegression : 
public ObjFunction { out_gpair->SetDevice(ctx_->gpu_id); out_gpair->Resize(n_targets * info.num_row_); auto gpair = - linalg::MakeTensorView(ctx_->IsCPU() ? out_gpair->HostSpan() : out_gpair->DeviceSpan(), - {info.num_row_, n_alphas, n_targets / n_alphas}, ctx_->gpu_id); + linalg::MakeTensorView(ctx_, out_gpair, info.num_row_, n_alphas, n_targets / n_alphas); info.weights_.SetDevice(ctx_->gpu_id); common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan() @@ -80,15 +79,8 @@ class QuantileRegression : public ObjFunction { linalg::ElementWiseKernel( ctx_, gpair, [=] XGBOOST_DEVICE(std::size_t i, GradientPair const&) mutable { - auto idx = linalg::UnravelIndex(static_cast(i), - {static_cast(n_samples), - static_cast(alpha.size()), - static_cast(n_targets / alpha.size())}); - - // std::tie is not available for cuda kernel. - std::size_t sample_id = std::get<0>(idx); - std::size_t quantile_id = std::get<1>(idx); - std::size_t target_id = std::get<2>(idx); + auto [sample_id, quantile_id, target_id] = + linalg::UnravelIndex(i, n_samples, alpha.size(), n_targets / alpha.size()); auto d = predt(i) - labels(sample_id, target_id); auto h = weight[sample_id]; diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index ad30442f66a3..1929efb28837 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -274,8 +274,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree, histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id, collective::IsDistributed(), fmat->IsColumnSplit()); - auto m_gpair = - linalg::MakeTensorView(*gpair, {gpair->size(), static_cast(1)}, ctx_->gpu_id); + auto m_gpair = linalg::MakeTensorView(ctx_, *gpair, gpair->size(), static_cast(1)); SampleGradient(ctx_, *param_, m_gpair); } diff --git a/tests/cpp/common/test_linalg.cc b/tests/cpp/common/test_linalg.cc index bfba591fab6f..b1a90d773e28 100644 --- a/tests/cpp/common/test_linalg.cc +++ b/tests/cpp/common/test_linalg.cc @@ -6,17 +6,18 @@ #include #include -#include +#include // size_t +#include // iota +#include #include "../../../src/common/linalg_op.h" -namespace xgboost { -namespace linalg { +namespace xgboost::linalg { namespace { auto kCpuId = Context::kCpuId; } -auto MakeMatrixFromTest(HostDeviceVector *storage, size_t n_rows, size_t n_cols) { +auto MakeMatrixFromTest(HostDeviceVector *storage, std::size_t n_rows, std::size_t n_cols) { storage->Resize(n_rows * n_cols); auto &h_storage = storage->HostVector(); @@ -48,10 +49,11 @@ TEST(Linalg, VectorView) { } TEST(Linalg, TensorView) { + Context ctx; std::vector data(2 * 3 * 4, 0); std::iota(data.begin(), data.end(), 0); - auto t = MakeTensorView(data, {2, 3, 4}, -1); + auto t = MakeTensorView(&ctx, data, 2, 3, 4); ASSERT_EQ(t.Shape()[0], 2); ASSERT_EQ(t.Shape()[1], 3); ASSERT_EQ(t.Shape()[2], 4); @@ -106,12 +108,12 @@ TEST(Linalg, TensorView) { { // Don't assign the initial dimension, tensor should be able to deduce the correct dim // for Slice. 
- auto t = MakeTensorView(data, {2, 3, 4}, 0); + auto t = MakeTensorView(&ctx, data, 2, 3, 4); auto s = t.Slice(1, 2, All()); static_assert(decltype(s)::kDimension == 1); } { - auto t = MakeTensorView(data, {2, 3, 4}, 0); + auto t = MakeTensorView(&ctx, data, 2, 3, 4); auto s = t.Slice(1, linalg::All(), 1); ASSERT_EQ(s(0), 13); ASSERT_EQ(s(1), 17); @@ -119,7 +121,7 @@ TEST(Linalg, TensorView) { } { // range slice - auto t = MakeTensorView(data, {2, 3, 4}, 0); + auto t = MakeTensorView(&ctx, data, 2, 3, 4); auto s = t.Slice(linalg::All(), linalg::Range(1, 3), 2); static_assert(decltype(s)::kDimension == 2); std::vector sol{6, 10, 18, 22}; @@ -134,7 +136,7 @@ TEST(Linalg, TensorView) { } { // range slice - auto t = MakeTensorView(data, {2, 3, 4}, 0); + auto t = MakeTensorView(&ctx, data, 2, 3, 4); auto s = t.Slice(1, linalg::Range(1, 3), linalg::Range(1, 3)); static_assert(decltype(s)::kDimension == 2); std::vector sol{17, 18, 21, 22}; @@ -149,7 +151,7 @@ TEST(Linalg, TensorView) { } { // same as no slice. - auto t = MakeTensorView(data, {2, 3, 4}, 0); + auto t = MakeTensorView(&ctx, data, 2, 3, 4); auto s = t.Slice(linalg::All(), linalg::Range(0, 3), linalg::Range(0, 4)); static_assert(decltype(s)::kDimension == 3); auto all = t.Slice(linalg::All(), linalg::All(), linalg::All()); @@ -166,7 +168,7 @@ TEST(Linalg, TensorView) { { // copy and move constructor. - auto t = MakeTensorView(data, {2, 3, 4}, kCpuId); + auto t = MakeTensorView(&ctx, data, 2, 3, 4); auto from_copy = t; auto from_move = std::move(t); for (size_t i = 0; i < t.Shape().size(); ++i) { @@ -177,7 +179,7 @@ TEST(Linalg, TensorView) { { // multiple slices - auto t = MakeTensorView(data, {2, 3, 4}, kCpuId); + auto t = MakeTensorView(&ctx, data, 2, 3, 4); auto s_0 = t.Slice(linalg::All(), linalg::Range(0, 2), linalg::Range(1, 4)); ASSERT_FALSE(s_0.CContiguous()); auto s_1 = s_0.Slice(1, 1, linalg::Range(0, 2)); @@ -208,7 +210,7 @@ TEST(Linalg, TensorView) { TEST(Linalg, Tensor) { { - Tensor t{{2, 3, 4}, kCpuId}; + Tensor t{{2, 3, 4}, kCpuId, Order::kC}; auto view = t.View(kCpuId); auto const &as_const = t; @@ -227,7 +229,7 @@ TEST(Linalg, Tensor) { } { // Reshape - Tensor t{{2, 3, 4}, kCpuId}; + Tensor t{{2, 3, 4}, kCpuId, Order::kC}; t.Reshape(4, 3, 2); ASSERT_EQ(t.Size(), 24); ASSERT_EQ(t.Shape(2), 2); @@ -245,7 +247,7 @@ TEST(Linalg, Tensor) { TEST(Linalg, Empty) { { - auto t = TensorView{{}, {0, 3}, kCpuId}; + auto t = TensorView{{}, {0, 3}, kCpuId, Order::kC}; for (int32_t i : {0, 1, 2}) { auto s = t.Slice(All(), i); ASSERT_EQ(s.Size(), 0); @@ -254,7 +256,7 @@ TEST(Linalg, Empty) { } } { - auto t = Tensor{{0, 3}, kCpuId}; + auto t = Tensor{{0, 3}, kCpuId, Order::kC}; ASSERT_EQ(t.Size(), 0); auto view = t.View(kCpuId); @@ -269,7 +271,7 @@ TEST(Linalg, Empty) { TEST(Linalg, ArrayInterface) { auto cpu = kCpuId; - auto t = Tensor{{3, 3}, cpu}; + auto t = Tensor{{3, 3}, cpu, Order::kC}; auto v = t.View(cpu); std::iota(v.Values().begin(), v.Values().end(), 0); auto arr = Json::Load(StringView{ArrayInterfaceStr(v)}); @@ -313,21 +315,48 @@ TEST(Linalg, Popc) { } TEST(Linalg, Stack) { - Tensor l{{2, 3, 4}, kCpuId}; + Tensor l{{2, 3, 4}, kCpuId, Order::kC}; ElementWiseTransformHost(l.View(kCpuId), omp_get_max_threads(), [=](size_t i, float) { return i; }); - Tensor r_0{{2, 3, 4}, kCpuId}; + Tensor r_0{{2, 3, 4}, kCpuId, Order::kC}; ElementWiseTransformHost(r_0.View(kCpuId), omp_get_max_threads(), [=](size_t i, float) { return i; }); Stack(&l, r_0); - Tensor r_1{{0, 3, 4}, kCpuId}; + Tensor r_1{{0, 3, 4}, kCpuId, Order::kC}; Stack(&l, 
r_1); ASSERT_EQ(l.Shape(0), 4); Stack(&r_1, l); ASSERT_EQ(r_1.Shape(0), l.Shape(0)); } -} // namespace linalg -} // namespace xgboost + +TEST(Linalg, FOrder) { + std::size_t constexpr kRows = 16, kCols = 3; + std::vector data(kRows * kCols); + MatrixView mat{data, {kRows, kCols}, Context::kCpuId, Order::kF}; + float k{0}; + for (std::size_t i = 0; i < kRows; ++i) { + for (std::size_t j = 0; j < kCols; ++j) { + mat(i, j) = k; + k++; + } + } + auto column = mat.Slice(linalg::All(), 1); + ASSERT_TRUE(column.FContiguous()); + ASSERT_EQ(column.Stride(0), 1); + ASSERT_TRUE(column.CContiguous()); + k = 1; + for (auto it = linalg::cbegin(column); it != linalg::cend(column); ++it) { + ASSERT_EQ(*it, k); + k += kCols; + } + k = 1; + auto ptr = column.Values().data(); + for (auto it = ptr; it != ptr + kRows; ++it) { + ASSERT_EQ(*it, k); + k += kCols; + } +} +} // namespace xgboost::linalg diff --git a/tests/cpp/common/test_linalg.cu b/tests/cpp/common/test_linalg.cu index ac2b9a5816d0..fe38f0f9b813 100644 --- a/tests/cpp/common/test_linalg.cu +++ b/tests/cpp/common/test_linalg.cu @@ -7,8 +7,7 @@ #include "xgboost/context.h" #include "xgboost/linalg.h" -namespace xgboost { -namespace linalg { +namespace xgboost::linalg { namespace { void TestElementWiseKernel() { Tensor l{{2, 3, 4}, 0}; @@ -55,8 +54,10 @@ void TestElementWiseKernel() { } void TestSlice() { + Context ctx; + ctx.gpu_id = 1; thrust::device_vector data(2 * 3 * 4); - auto t = MakeTensorView(dh::ToSpan(data), {2, 3, 4}, 0); + auto t = MakeTensorView(&ctx, dh::ToSpan(data), 2, 3, 4); dh::LaunchN(1, [=] __device__(size_t) { auto s = t.Slice(linalg::All(), linalg::Range(0, 3), linalg::Range(0, 4)); auto all = t.Slice(linalg::All(), linalg::All(), linalg::All()); @@ -75,5 +76,4 @@ void TestSlice() { TEST(Linalg, GPUElementWise) { TestElementWiseKernel(); } TEST(Linalg, GPUTensorView) { TestSlice(); } -} // namespace linalg -} // namespace xgboost +} // namespace xgboost::linalg diff --git a/tests/cpp/objective/test_regression_obj.cc b/tests/cpp/objective/test_regression_obj.cc index 424f66aaf3dc..4e37eef18e5f 100644 --- a/tests/cpp/objective/test_regression_obj.cc +++ b/tests/cpp/objective/test_regression_obj.cc @@ -433,8 +433,8 @@ TEST(Objective, DeclareUnifiedTest(AbsoluteErrorLeaf)) { auto h_labels = info.labels.HostView().Slice(linalg::All(), t); std::iota(linalg::begin(h_labels), linalg::end(h_labels), 0); - auto h_predt = linalg::MakeTensorView(predt.HostSpan(), {kRows, kTargets}, Context::kCpuId) - .Slice(linalg::All(), t); + auto h_predt = + linalg::MakeTensorView(&ctx, predt.HostSpan(), kRows, kTargets).Slice(linalg::All(), t); for (size_t i = 0; i < h_predt.Size(); ++i) { h_predt(i) = h_labels(i) + i; } From f0a612bba6f5081ec92e0dfd809c5a5f020b58b0 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 7 Mar 2023 06:09:24 +0800 Subject: [PATCH 2/2] workaround gcc-7. --- include/xgboost/linalg.h | 3 ++- src/objective/adaptive.cc | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h index a7a24a2e42b4..3d6bcc962017 100644 --- a/include/xgboost/linalg.h +++ b/include/xgboost/linalg.h @@ -571,7 +571,8 @@ class TensorView { * \brief Constructor for automatic type deduction. 
*/ template ::value> * = nullptr> + std::enable_if_t::value && + !std::is_pointer_v> * = nullptr> auto MakeTensorView(Context const *ctx, Container &data, S &&...shape) { // NOLINT using T = typename Container::value_type; std::size_t in_shape[sizeof...(S)]; diff --git a/src/objective/adaptive.cc b/src/objective/adaptive.cc index 2d89c94593c4..4a67e848bb63 100644 --- a/src/objective/adaptive.cc +++ b/src/objective/adaptive.cc @@ -136,8 +136,7 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector const& posit #if !defined(XGBOOST_USE_CUDA) void UpdateTreeLeafDevice(Context const*, common::Span, std::int32_t, - MetaInfo const&, float learning_rate, HostDeviceVector const&, - float, RegTree*) { + MetaInfo const&, float, HostDeviceVector const&, float, RegTree*) { common::AssertGPUSupport(); } #endif // !defined(XGBOOST_USE_CUDA)
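
For reference, below is a minimal usage sketch of the API these two patches touch: the parameter-pack overloads of linalg::MakeTensorView and linalg::UnravelIndex, and the new Order::kF (column-major) flag on linalg::Tensor. It is illustrative only, not part of the patches; the function name Example, the concrete shapes, and the index values are assumptions, while the calls themselves mirror those appearing in the diffs above (e.g. in tests/cpp/common/test_linalg.cc and src/objective/quantile_obj.cu).

    // Illustrative sketch only; not part of the patch series.
    #include <numeric>  // for iota
    #include <vector>

    #include "xgboost/context.h"
    #include "xgboost/linalg.h"

    namespace xgboost {
    void Example() {  // hypothetical helper, named for illustration
      Context ctx;  // defaults to CPU (gpu_id == Context::kCpuId)

      // Shape passed as a parameter pack: no shape array or static_cast needed.
      std::vector<float> data(2 * 3 * 4);
      std::iota(data.begin(), data.end(), 0.0f);
      auto view = linalg::MakeTensorView(&ctx, data, 2, 3, 4);  // TensorView<float, 3>

      // UnravelIndex also accepts the shape as a parameter pack and works with
      // structured bindings, as used in quantile_obj.cu above.
      auto [i, j, k] = linalg::UnravelIndex(13, 2, 3, 4);  // (1, 0, 1) for a 2x3x4 shape

      // Column-major (F-order) tensor: a single-column slice is contiguous.
      linalg::Tensor<float, 2> mat{{16, 3}, Context::kCpuId, linalg::Order::kF};
      auto column = mat.HostView().Slice(linalg::All(), 1);
      // column.Stride(0) == 1 and column.FContiguous() is true.

      (void)view; (void)i; (void)j; (void)k; (void)column;
    }
    }  // namespace xgboost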