From 88711695038d6a388a36ca37483e2ce24e5f570d Mon Sep 17 00:00:00 2001 From: jiamingy Date: Wed, 30 Nov 2022 15:54:34 +0800 Subject: [PATCH] Extract transform iterator. --- src/common/common.h | 68 ---------------- src/common/linalg_op.h | 1 + src/common/quantile.cu | 3 +- src/common/stats.h | 3 +- src/common/transform_iterator.h | 89 +++++++++++++++++++++ src/data/ellpack_page.cu | 1 + src/data/gradient_index.cc | 3 +- src/data/gradient_index.h | 1 + tests/cpp/common/test_transform_iterator.cc | 20 +++++ 9 files changed, 118 insertions(+), 71 deletions(-) create mode 100644 src/common/transform_iterator.h create mode 100644 tests/cpp/common/test_transform_iterator.cc diff --git a/src/common/common.h b/src/common/common.h index b2d7211c6932..438669e5fa3c 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -164,74 +164,6 @@ class Range { Iterator end_; }; -/** - * \brief Transform iterator that takes an index and calls transform operator. - * - * This is CPU-only right now as taking host device function as operator complicates the - * code. For device side one can use `thrust::transform_iterator` instead. - */ -template -class IndexTransformIter { - size_t iter_{0}; - Fn fn_; - - public: - using iterator_category = std::random_access_iterator_tag; // NOLINT - using value_type = std::result_of_t; // NOLINT - using difference_type = detail::ptrdiff_t; // NOLINT - using reference = std::add_lvalue_reference_t; // NOLINT - using pointer = std::add_pointer_t; // NOLINT - - public: - /** - * \param op Transform operator, takes a size_t index as input. - */ - explicit IndexTransformIter(Fn &&op) : fn_{op} {} - IndexTransformIter(IndexTransformIter const &) = default; - IndexTransformIter& operator=(IndexTransformIter&&) = default; - IndexTransformIter& operator=(IndexTransformIter const& that) { - iter_ = that.iter_; - return *this; - } - - value_type operator*() const { return fn_(iter_); } - - auto operator-(IndexTransformIter const &that) const { return iter_ - that.iter_; } - bool operator==(IndexTransformIter const &that) const { return iter_ == that.iter_; } - bool operator!=(IndexTransformIter const &that) const { return !(*this == that); } - - IndexTransformIter &operator++() { - iter_++; - return *this; - } - IndexTransformIter operator++(int) { - auto ret = *this; - ++(*this); - return ret; - } - IndexTransformIter &operator+=(difference_type n) { - iter_ += n; - return *this; - } - IndexTransformIter &operator-=(difference_type n) { - (*this) += -n; - return *this; - } - IndexTransformIter operator+(difference_type n) const { - auto ret = *this; - return ret += n; - } - IndexTransformIter operator-(difference_type n) const { - auto ret = *this; - return ret -= n; - } -}; - -template -auto MakeIndexTransformIter(Fn&& fn) { - return IndexTransformIter(std::forward(fn)); -} - int AllVisibleGPUs(); inline void AssertGPUSupport() { diff --git a/src/common/linalg_op.h b/src/common/linalg_op.h index efb9cf300238..0df7804757d2 100644 --- a/src/common/linalg_op.h +++ b/src/common/linalg_op.h @@ -8,6 +8,7 @@ #include "common.h" #include "threading_utils.h" +#include "transform_iterator.h" // MakeIndexTransformIter #include "xgboost/generic_parameters.h" #include "xgboost/linalg.h" diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 36c980a8ebbc..06939e846891 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -3,8 +3,8 @@ */ #include #include -#include #include +#include #include #include @@ -20,6 +20,7 @@ #include "hist_util.h" #include "quantile.cuh" #include "quantile.h" +#include "transform_iterator.h" // MakeIndexTransformIter #include "xgboost/span.h" namespace xgboost { diff --git a/src/common/stats.h b/src/common/stats.h index c6347c421a9f..566b4be93273 100644 --- a/src/common/stats.h +++ b/src/common/stats.h @@ -8,7 +8,8 @@ #include #include -#include "common.h" // AssertGPUSupport +#include "common.h" // AssertGPUSupport +#include "transform_iterator.h" // MakeIndexTransformIter #include "xgboost/generic_parameters.h" #include "xgboost/linalg.h" diff --git a/src/common/transform_iterator.h b/src/common/transform_iterator.h new file mode 100644 index 000000000000..83fffe05af79 --- /dev/null +++ b/src/common/transform_iterator.h @@ -0,0 +1,89 @@ +/** + * Copyright 2022 by XGBoost Contributors + */ +#ifndef XGBOOST_COMMON_TRANSFORM_ITERATOR_H_ +#define XGBOOST_COMMON_TRANSFORM_ITERATOR_H_ + +#include // std::size_t +#include // std::random_access_iterator_tag +#include // std::result_of_t, std::add_pointer_t, std::add_lvalue_reference_t +#include // std::forward + +#include "xgboost/span.h" // ptrdiff_t + +namespace xgboost { +namespace common { +/** + * \brief Transform iterator that takes an index and calls transform operator. + * + * This is CPU-only right now as taking host device function as operator complicates the + * code. For device side one can use `thrust::transform_iterator` instead. + */ +template +class IndexTransformIter { + std::size_t iter_{0}; + Fn fn_; + + public: + using iterator_category = std::random_access_iterator_tag; // NOLINT + using value_type = std::result_of_t; // NOLINT + using difference_type = detail::ptrdiff_t; // NOLINT + using reference = std::add_lvalue_reference_t; // NOLINT + using pointer = std::add_pointer_t; // NOLINT + + public: + /** + * \param op Transform operator, takes a size_t index as input. + */ + explicit IndexTransformIter(Fn &&op) : fn_{op} {} + IndexTransformIter(IndexTransformIter const &) = default; + IndexTransformIter &operator=(IndexTransformIter &&) = default; + IndexTransformIter &operator=(IndexTransformIter const &that) { + iter_ = that.iter_; + return *this; + } + + value_type operator*() const { return fn_(iter_); } + value_type operator[](std::size_t i) const { + auto iter = *this + i; + return *iter; + } + + auto operator-(IndexTransformIter const &that) const { return iter_ - that.iter_; } + bool operator==(IndexTransformIter const &that) const { return iter_ == that.iter_; } + bool operator!=(IndexTransformIter const &that) const { return !(*this == that); } + + IndexTransformIter &operator++() { + iter_++; + return *this; + } + IndexTransformIter operator++(int) { + auto ret = *this; + ++(*this); + return ret; + } + IndexTransformIter &operator+=(difference_type n) { + iter_ += n; + return *this; + } + IndexTransformIter &operator-=(difference_type n) { + (*this) += -n; + return *this; + } + IndexTransformIter operator+(difference_type n) const { + auto ret = *this; + return ret += n; + } + IndexTransformIter operator-(difference_type n) const { + auto ret = *this; + return ret -= n; + } +}; + +template +auto MakeIndexTransformIter(Fn &&fn) { + return IndexTransformIter(std::forward(fn)); +} +} // namespace common +} // namespace xgboost +#endif // XGBOOST_COMMON_TRANSFORM_ITERATOR_H_ diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 076b8ed4b1be..4176e9ae4f77 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -7,6 +7,7 @@ #include "../common/categorical.h" #include "../common/hist_util.cuh" #include "../common/random.h" +#include "../common/transform_iterator.h" // MakeIndexTransformIter #include "./ellpack_page.cuh" #include "device_adapter.cuh" #include "gradient_index.h" diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc index 2e9d38a1918c..140bcbff9878 100644 --- a/src/data/gradient_index.cc +++ b/src/data/gradient_index.cc @@ -13,6 +13,7 @@ #include "../common/hist_util.h" #include "../common/numeric.h" #include "../common/threading_utils.h" +#include "../common/transform_iterator.h" // MakeIndexTransformIter namespace xgboost { @@ -78,7 +79,7 @@ GHistIndexMatrix::~GHistIndexMatrix() = default; void GHistIndexMatrix::PushBatch(SparsePage const &batch, common::Span ft, int32_t n_threads) { auto page = batch.GetView(); - auto it = common::MakeIndexTransformIter([&](size_t ridx) { return page[ridx].size(); }); + auto it = common::MakeIndexTransformIter([&](std::size_t ridx) { return page[ridx].size(); }); common::PartialSum(n_threads, it, it + page.Size(), static_cast(0), row_ptr.begin()); data::SparsePageAdapterBatch adapter_batch{page}; auto is_valid = [](auto) { return true; }; // SparsePage always contains valid entries diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h index 10d9e770dc13..73a17e359ab9 100644 --- a/src/data/gradient_index.h +++ b/src/data/gradient_index.h @@ -15,6 +15,7 @@ #include "../common/hist_util.h" #include "../common/numeric.h" #include "../common/threading_utils.h" +#include "../common/transform_iterator.h" // MakeIndexTransformIter #include "adapter.h" #include "proxy_dmatrix.h" #include "xgboost/base.h" diff --git a/tests/cpp/common/test_transform_iterator.cc b/tests/cpp/common/test_transform_iterator.cc new file mode 100644 index 000000000000..3ec115f8f3d6 --- /dev/null +++ b/tests/cpp/common/test_transform_iterator.cc @@ -0,0 +1,20 @@ +/** + * Copyright 2022 by XGBoost Contributors + */ +#include + +#include // std::size_t + +#include "../../../src/common/transform_iterator.h" + +namespace xgboost { +namespace common { +TEST(IndexTransformIter, Basic) { + auto sqr = [](std::size_t i) { return i * i; }; + auto iter = MakeIndexTransformIter(sqr); + for (std::size_t i = 0; i < 4; ++i) { + ASSERT_EQ(iter[i], sqr(i)); + } +} +} // namespace common +} // namespace xgboost