Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract transform iterator. #8498

Merged
merged 1 commit into from
Dec 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 0 additions & 68 deletions src/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,74 +164,6 @@ class Range {
Iterator end_;
};

/**
* \brief Transform iterator that takes an index and calls transform operator.
*
* This is CPU-only right now as taking host device function as operator complicates the
* code. For device side one can use `thrust::transform_iterator` instead.
*/
template <typename Fn>
class IndexTransformIter {
size_t iter_{0};
Fn fn_;

public:
using iterator_category = std::random_access_iterator_tag; // NOLINT
using value_type = std::result_of_t<Fn(size_t)>; // NOLINT
using difference_type = detail::ptrdiff_t; // NOLINT
using reference = std::add_lvalue_reference_t<value_type>; // NOLINT
using pointer = std::add_pointer_t<value_type>; // NOLINT

public:
/**
* \param op Transform operator, takes a size_t index as input.
*/
explicit IndexTransformIter(Fn &&op) : fn_{op} {}
IndexTransformIter(IndexTransformIter const &) = default;
IndexTransformIter& operator=(IndexTransformIter&&) = default;
IndexTransformIter& operator=(IndexTransformIter const& that) {
iter_ = that.iter_;
return *this;
}

value_type operator*() const { return fn_(iter_); }

auto operator-(IndexTransformIter const &that) const { return iter_ - that.iter_; }
bool operator==(IndexTransformIter const &that) const { return iter_ == that.iter_; }
bool operator!=(IndexTransformIter const &that) const { return !(*this == that); }

IndexTransformIter &operator++() {
iter_++;
return *this;
}
IndexTransformIter operator++(int) {
auto ret = *this;
++(*this);
return ret;
}
IndexTransformIter &operator+=(difference_type n) {
iter_ += n;
return *this;
}
IndexTransformIter &operator-=(difference_type n) {
(*this) += -n;
return *this;
}
IndexTransformIter operator+(difference_type n) const {
auto ret = *this;
return ret += n;
}
IndexTransformIter operator-(difference_type n) const {
auto ret = *this;
return ret -= n;
}
};

template <typename Fn>
auto MakeIndexTransformIter(Fn&& fn) {
return IndexTransformIter<Fn>(std::forward<Fn>(fn));
}

int AllVisibleGPUs();

inline void AssertGPUSupport() {
Expand Down
1 change: 1 addition & 0 deletions src/common/linalg_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "common.h"
#include "threading_utils.h"
#include "transform_iterator.h" // MakeIndexTransformIter
#include "xgboost/generic_parameters.h"
#include "xgboost/linalg.h"

Expand Down
3 changes: 2 additions & 1 deletion src/common/quantile.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
*/
#include <thrust/binary_search.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/transform_scan.h>
#include <thrust/unique.h>

Expand All @@ -20,6 +20,7 @@
#include "hist_util.h"
#include "quantile.cuh"
#include "quantile.h"
#include "transform_iterator.h" // MakeIndexTransformIter
#include "xgboost/span.h"

namespace xgboost {
Expand Down
3 changes: 2 additions & 1 deletion src/common/stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
#include <limits>
#include <vector>

#include "common.h" // AssertGPUSupport
#include "common.h" // AssertGPUSupport
#include "transform_iterator.h" // MakeIndexTransformIter
#include "xgboost/generic_parameters.h"
#include "xgboost/linalg.h"

Expand Down
89 changes: 89 additions & 0 deletions src/common/transform_iterator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/**
* Copyright 2022 by XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_TRANSFORM_ITERATOR_H_
#define XGBOOST_COMMON_TRANSFORM_ITERATOR_H_

#include <cstddef> // std::size_t
#include <iterator> // std::random_access_iterator_tag
#include <type_traits> // std::result_of_t, std::add_pointer_t, std::add_lvalue_reference_t
#include <utility> // std::forward

#include "xgboost/span.h" // ptrdiff_t

namespace xgboost {
namespace common {
/**
* \brief Transform iterator that takes an index and calls transform operator.
*
* This is CPU-only right now as taking host device function as operator complicates the
* code. For device side one can use `thrust::transform_iterator` instead.
*/
template <typename Fn>
class IndexTransformIter {
std::size_t iter_{0};
Fn fn_;

public:
using iterator_category = std::random_access_iterator_tag; // NOLINT
using value_type = std::result_of_t<Fn(std::size_t)>; // NOLINT
using difference_type = detail::ptrdiff_t; // NOLINT
using reference = std::add_lvalue_reference_t<value_type>; // NOLINT
using pointer = std::add_pointer_t<value_type>; // NOLINT

public:
/**
* \param op Transform operator, takes a size_t index as input.
*/
explicit IndexTransformIter(Fn &&op) : fn_{op} {}
IndexTransformIter(IndexTransformIter const &) = default;
IndexTransformIter &operator=(IndexTransformIter &&) = default;
IndexTransformIter &operator=(IndexTransformIter const &that) {
iter_ = that.iter_;
return *this;
}

value_type operator*() const { return fn_(iter_); }
value_type operator[](std::size_t i) const {
auto iter = *this + i;
return *iter;
}

auto operator-(IndexTransformIter const &that) const { return iter_ - that.iter_; }
bool operator==(IndexTransformIter const &that) const { return iter_ == that.iter_; }
bool operator!=(IndexTransformIter const &that) const { return !(*this == that); }

IndexTransformIter &operator++() {
iter_++;
return *this;
}
IndexTransformIter operator++(int) {
auto ret = *this;
++(*this);
return ret;
}
IndexTransformIter &operator+=(difference_type n) {
iter_ += n;
return *this;
}
IndexTransformIter &operator-=(difference_type n) {
(*this) += -n;
return *this;
}
IndexTransformIter operator+(difference_type n) const {
auto ret = *this;
return ret += n;
}
IndexTransformIter operator-(difference_type n) const {
auto ret = *this;
return ret -= n;
}
};

template <typename Fn>
auto MakeIndexTransformIter(Fn &&fn) {
return IndexTransformIter<Fn>(std::forward<Fn>(fn));
}
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_TRANSFORM_ITERATOR_H_
1 change: 1 addition & 0 deletions src/data/ellpack_page.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "../common/categorical.h"
#include "../common/hist_util.cuh"
#include "../common/random.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "./ellpack_page.cuh"
#include "device_adapter.cuh"
#include "gradient_index.h"
Expand Down
3 changes: 2 additions & 1 deletion src/data/gradient_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "../common/hist_util.h"
#include "../common/numeric.h"
#include "../common/threading_utils.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter

namespace xgboost {

Expand Down Expand Up @@ -78,7 +79,7 @@ GHistIndexMatrix::~GHistIndexMatrix() = default;
void GHistIndexMatrix::PushBatch(SparsePage const &batch, common::Span<FeatureType const> ft,
int32_t n_threads) {
auto page = batch.GetView();
auto it = common::MakeIndexTransformIter([&](size_t ridx) { return page[ridx].size(); });
auto it = common::MakeIndexTransformIter([&](std::size_t ridx) { return page[ridx].size(); });
common::PartialSum(n_threads, it, it + page.Size(), static_cast<size_t>(0), row_ptr.begin());
data::SparsePageAdapterBatch adapter_batch{page};
auto is_valid = [](auto) { return true; }; // SparsePage always contains valid entries
Expand Down
1 change: 1 addition & 0 deletions src/data/gradient_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "../common/hist_util.h"
#include "../common/numeric.h"
#include "../common/threading_utils.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "adapter.h"
#include "proxy_dmatrix.h"
#include "xgboost/base.h"
Expand Down
20 changes: 20 additions & 0 deletions tests/cpp/common/test_transform_iterator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/**
* Copyright 2022 by XGBoost Contributors
*/
#include <gtest/gtest.h>

#include <cstddef> // std::size_t

#include "../../../src/common/transform_iterator.h"

namespace xgboost {
namespace common {
TEST(IndexTransformIter, Basic) {
auto sqr = [](std::size_t i) { return i * i; };
auto iter = MakeIndexTransformIter(sqr);
for (std::size_t i = 0; i < 4; ++i) {
ASSERT_EQ(iter[i], sqr(i));
}
}
} // namespace common
} // namespace xgboost