Extract transform iterator. (#8498)

This commit is contained in:
Jiaming Yuan 2022-12-05 21:37:07 +08:00 committed by GitHub
parent d8544e4d9e
commit e3bf5565ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 118 additions and 71 deletions

View File

@ -164,74 +164,6 @@ class Range {
Iterator end_;
};
/**
* \brief Transform iterator that takes an index and calls transform operator.
*
* This is CPU-only right now as taking host device function as operator complicates the
* code. For device side one can use `thrust::transform_iterator` instead.
*/
template <typename Fn>
class IndexTransformIter {
size_t iter_{0};
Fn fn_;
public:
using iterator_category = std::random_access_iterator_tag; // NOLINT
using value_type = std::result_of_t<Fn(size_t)>; // NOLINT
using difference_type = detail::ptrdiff_t; // NOLINT
using reference = std::add_lvalue_reference_t<value_type>; // NOLINT
using pointer = std::add_pointer_t<value_type>; // NOLINT
public:
/**
* \param op Transform operator, takes a size_t index as input.
*/
explicit IndexTransformIter(Fn &&op) : fn_{op} {}
IndexTransformIter(IndexTransformIter const &) = default;
IndexTransformIter& operator=(IndexTransformIter&&) = default;
IndexTransformIter& operator=(IndexTransformIter const& that) {
iter_ = that.iter_;
return *this;
}
value_type operator*() const { return fn_(iter_); }
auto operator-(IndexTransformIter const &that) const { return iter_ - that.iter_; }
bool operator==(IndexTransformIter const &that) const { return iter_ == that.iter_; }
bool operator!=(IndexTransformIter const &that) const { return !(*this == that); }
IndexTransformIter &operator++() {
iter_++;
return *this;
}
IndexTransformIter operator++(int) {
auto ret = *this;
++(*this);
return ret;
}
IndexTransformIter &operator+=(difference_type n) {
iter_ += n;
return *this;
}
IndexTransformIter &operator-=(difference_type n) {
(*this) += -n;
return *this;
}
IndexTransformIter operator+(difference_type n) const {
auto ret = *this;
return ret += n;
}
IndexTransformIter operator-(difference_type n) const {
auto ret = *this;
return ret -= n;
}
};
template <typename Fn>
auto MakeIndexTransformIter(Fn&& fn) {
return IndexTransformIter<Fn>(std::forward<Fn>(fn));
}
int AllVisibleGPUs();
inline void AssertGPUSupport() {

View File

@ -8,6 +8,7 @@
#include "common.h"
#include "threading_utils.h"
#include "transform_iterator.h" // MakeIndexTransformIter
#include "xgboost/generic_parameters.h"
#include "xgboost/linalg.h"

View File

@ -3,8 +3,8 @@
*/
#include <thrust/binary_search.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/transform_scan.h>
#include <thrust/unique.h>
@ -20,6 +20,7 @@
#include "hist_util.h"
#include "quantile.cuh"
#include "quantile.h"
#include "transform_iterator.h" // MakeIndexTransformIter
#include "xgboost/span.h"
namespace xgboost {

View File

@ -8,7 +8,8 @@
#include <limits>
#include <vector>
#include "common.h" // AssertGPUSupport
#include "common.h" // AssertGPUSupport
#include "transform_iterator.h" // MakeIndexTransformIter
#include "xgboost/generic_parameters.h"
#include "xgboost/linalg.h"

View File

@ -0,0 +1,89 @@
/**
* Copyright 2022 by XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_TRANSFORM_ITERATOR_H_
#define XGBOOST_COMMON_TRANSFORM_ITERATOR_H_
#include <cstddef> // std::size_t
#include <iterator> // std::random_access_iterator_tag
#include <type_traits> // std::result_of_t, std::add_pointer_t, std::add_lvalue_reference_t
#include <utility> // std::forward
#include "xgboost/span.h" // ptrdiff_t
namespace xgboost {
namespace common {
/**
* \brief Transform iterator that takes an index and calls transform operator.
*
* This is CPU-only right now as taking host device function as operator complicates the
* code. For device side one can use `thrust::transform_iterator` instead.
*/
template <typename Fn>
class IndexTransformIter {
std::size_t iter_{0};
Fn fn_;
public:
using iterator_category = std::random_access_iterator_tag; // NOLINT
using value_type = std::result_of_t<Fn(std::size_t)>; // NOLINT
using difference_type = detail::ptrdiff_t; // NOLINT
using reference = std::add_lvalue_reference_t<value_type>; // NOLINT
using pointer = std::add_pointer_t<value_type>; // NOLINT
public:
/**
* \param op Transform operator, takes a size_t index as input.
*/
explicit IndexTransformIter(Fn &&op) : fn_{op} {}
IndexTransformIter(IndexTransformIter const &) = default;
IndexTransformIter &operator=(IndexTransformIter &&) = default;
IndexTransformIter &operator=(IndexTransformIter const &that) {
iter_ = that.iter_;
return *this;
}
value_type operator*() const { return fn_(iter_); }
value_type operator[](std::size_t i) const {
auto iter = *this + i;
return *iter;
}
auto operator-(IndexTransformIter const &that) const { return iter_ - that.iter_; }
bool operator==(IndexTransformIter const &that) const { return iter_ == that.iter_; }
bool operator!=(IndexTransformIter const &that) const { return !(*this == that); }
IndexTransformIter &operator++() {
iter_++;
return *this;
}
IndexTransformIter operator++(int) {
auto ret = *this;
++(*this);
return ret;
}
IndexTransformIter &operator+=(difference_type n) {
iter_ += n;
return *this;
}
IndexTransformIter &operator-=(difference_type n) {
(*this) += -n;
return *this;
}
IndexTransformIter operator+(difference_type n) const {
auto ret = *this;
return ret += n;
}
IndexTransformIter operator-(difference_type n) const {
auto ret = *this;
return ret -= n;
}
};
template <typename Fn>
auto MakeIndexTransformIter(Fn &&fn) {
return IndexTransformIter<Fn>(std::forward<Fn>(fn));
}
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_TRANSFORM_ITERATOR_H_

View File

@ -7,6 +7,7 @@
#include "../common/categorical.h"
#include "../common/hist_util.cuh"
#include "../common/random.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "./ellpack_page.cuh"
#include "device_adapter.cuh"
#include "gradient_index.h"

View File

@ -13,6 +13,7 @@
#include "../common/hist_util.h"
#include "../common/numeric.h"
#include "../common/threading_utils.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter
namespace xgboost {
@ -78,7 +79,7 @@ GHistIndexMatrix::~GHistIndexMatrix() = default;
void GHistIndexMatrix::PushBatch(SparsePage const &batch, common::Span<FeatureType const> ft,
int32_t n_threads) {
auto page = batch.GetView();
auto it = common::MakeIndexTransformIter([&](size_t ridx) { return page[ridx].size(); });
auto it = common::MakeIndexTransformIter([&](std::size_t ridx) { return page[ridx].size(); });
common::PartialSum(n_threads, it, it + page.Size(), static_cast<size_t>(0), row_ptr.begin());
data::SparsePageAdapterBatch adapter_batch{page};
auto is_valid = [](auto) { return true; }; // SparsePage always contains valid entries

View File

@ -15,6 +15,7 @@
#include "../common/hist_util.h"
#include "../common/numeric.h"
#include "../common/threading_utils.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "adapter.h"
#include "proxy_dmatrix.h"
#include "xgboost/base.h"

View File

@ -0,0 +1,20 @@
/**
* Copyright 2022 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <cstddef> // std::size_t
#include "../../../src/common/transform_iterator.h"
namespace xgboost {
namespace common {
// Indexing the iterator at position `idx` must give the same value as applying the
// transform operator to `idx` directly.
TEST(IndexTransformIter, Basic) {
  auto square = [](std::size_t v) { return v * v; };
  auto it = MakeIndexTransformIter(square);
  std::size_t idx = 0;
  while (idx < 4) {
    ASSERT_EQ(it[idx], square(idx));
    ++idx;
  }
}
} // namespace common
} // namespace xgboost