Extract transform iterator. (#8498)
This commit is contained in:
parent
d8544e4d9e
commit
e3bf5565ab
@ -164,74 +164,6 @@ class Range {
|
||||
Iterator end_;
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Transform iterator that takes an index and calls transform operator.
|
||||
*
|
||||
* This is CPU-only right now as taking host device function as operator complicates the
|
||||
* code. For device side one can use `thrust::transform_iterator` instead.
|
||||
*/
|
||||
template <typename Fn>
|
||||
class IndexTransformIter {
|
||||
size_t iter_{0};
|
||||
Fn fn_;
|
||||
|
||||
public:
|
||||
using iterator_category = std::random_access_iterator_tag; // NOLINT
|
||||
using value_type = std::result_of_t<Fn(size_t)>; // NOLINT
|
||||
using difference_type = detail::ptrdiff_t; // NOLINT
|
||||
using reference = std::add_lvalue_reference_t<value_type>; // NOLINT
|
||||
using pointer = std::add_pointer_t<value_type>; // NOLINT
|
||||
|
||||
public:
|
||||
/**
|
||||
* \param op Transform operator, takes a size_t index as input.
|
||||
*/
|
||||
explicit IndexTransformIter(Fn &&op) : fn_{op} {}
|
||||
IndexTransformIter(IndexTransformIter const &) = default;
|
||||
IndexTransformIter& operator=(IndexTransformIter&&) = default;
|
||||
IndexTransformIter& operator=(IndexTransformIter const& that) {
|
||||
iter_ = that.iter_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
value_type operator*() const { return fn_(iter_); }
|
||||
|
||||
auto operator-(IndexTransformIter const &that) const { return iter_ - that.iter_; }
|
||||
bool operator==(IndexTransformIter const &that) const { return iter_ == that.iter_; }
|
||||
bool operator!=(IndexTransformIter const &that) const { return !(*this == that); }
|
||||
|
||||
IndexTransformIter &operator++() {
|
||||
iter_++;
|
||||
return *this;
|
||||
}
|
||||
IndexTransformIter operator++(int) {
|
||||
auto ret = *this;
|
||||
++(*this);
|
||||
return ret;
|
||||
}
|
||||
IndexTransformIter &operator+=(difference_type n) {
|
||||
iter_ += n;
|
||||
return *this;
|
||||
}
|
||||
IndexTransformIter &operator-=(difference_type n) {
|
||||
(*this) += -n;
|
||||
return *this;
|
||||
}
|
||||
IndexTransformIter operator+(difference_type n) const {
|
||||
auto ret = *this;
|
||||
return ret += n;
|
||||
}
|
||||
IndexTransformIter operator-(difference_type n) const {
|
||||
auto ret = *this;
|
||||
return ret -= n;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Fn>
|
||||
auto MakeIndexTransformIter(Fn&& fn) {
|
||||
return IndexTransformIter<Fn>(std::forward<Fn>(fn));
|
||||
}
|
||||
|
||||
int AllVisibleGPUs();
|
||||
|
||||
inline void AssertGPUSupport() {
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
|
||||
#include "common.h"
|
||||
#include "threading_utils.h"
|
||||
#include "transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "xgboost/generic_parameters.h"
|
||||
#include "xgboost/linalg.h"
|
||||
|
||||
|
||||
@ -3,8 +3,8 @@
|
||||
*/
|
||||
#include <thrust/binary_search.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
#include <thrust/iterator/discard_iterator.h>
|
||||
#include <thrust/iterator/constant_iterator.h>
|
||||
#include <thrust/iterator/discard_iterator.h>
|
||||
#include <thrust/transform_scan.h>
|
||||
#include <thrust/unique.h>
|
||||
|
||||
@ -20,6 +20,7 @@
|
||||
#include "hist_util.h"
|
||||
#include "quantile.cuh"
|
||||
#include "quantile.h"
|
||||
#include "transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "xgboost/span.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@ -8,7 +8,8 @@
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "common.h" // AssertGPUSupport
|
||||
#include "common.h" // AssertGPUSupport
|
||||
#include "transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "xgboost/generic_parameters.h"
|
||||
#include "xgboost/linalg.h"
|
||||
|
||||
|
||||
89
src/common/transform_iterator.h
Normal file
89
src/common/transform_iterator.h
Normal file
@ -0,0 +1,89 @@
|
||||
/**
|
||||
* Copyright 2022 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_TRANSFORM_ITERATOR_H_
|
||||
#define XGBOOST_COMMON_TRANSFORM_ITERATOR_H_
|
||||
|
||||
#include <cstddef> // std::size_t
|
||||
#include <iterator> // std::random_access_iterator_tag
|
||||
#include <type_traits> // std::result_of_t, std::add_pointer_t, std::add_lvalue_reference_t
|
||||
#include <utility> // std::forward
|
||||
|
||||
#include "xgboost/span.h" // ptrdiff_t
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
/**
|
||||
* \brief Transform iterator that takes an index and calls transform operator.
|
||||
*
|
||||
* This is CPU-only right now as taking host device function as operator complicates the
|
||||
* code. For device side one can use `thrust::transform_iterator` instead.
|
||||
*/
|
||||
template <typename Fn>
|
||||
class IndexTransformIter {
|
||||
std::size_t iter_{0};
|
||||
Fn fn_;
|
||||
|
||||
public:
|
||||
using iterator_category = std::random_access_iterator_tag; // NOLINT
|
||||
using value_type = std::result_of_t<Fn(std::size_t)>; // NOLINT
|
||||
using difference_type = detail::ptrdiff_t; // NOLINT
|
||||
using reference = std::add_lvalue_reference_t<value_type>; // NOLINT
|
||||
using pointer = std::add_pointer_t<value_type>; // NOLINT
|
||||
|
||||
public:
|
||||
/**
|
||||
* \param op Transform operator, takes a size_t index as input.
|
||||
*/
|
||||
explicit IndexTransformIter(Fn &&op) : fn_{op} {}
|
||||
IndexTransformIter(IndexTransformIter const &) = default;
|
||||
IndexTransformIter &operator=(IndexTransformIter &&) = default;
|
||||
IndexTransformIter &operator=(IndexTransformIter const &that) {
|
||||
iter_ = that.iter_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
value_type operator*() const { return fn_(iter_); }
|
||||
value_type operator[](std::size_t i) const {
|
||||
auto iter = *this + i;
|
||||
return *iter;
|
||||
}
|
||||
|
||||
auto operator-(IndexTransformIter const &that) const { return iter_ - that.iter_; }
|
||||
bool operator==(IndexTransformIter const &that) const { return iter_ == that.iter_; }
|
||||
bool operator!=(IndexTransformIter const &that) const { return !(*this == that); }
|
||||
|
||||
IndexTransformIter &operator++() {
|
||||
iter_++;
|
||||
return *this;
|
||||
}
|
||||
IndexTransformIter operator++(int) {
|
||||
auto ret = *this;
|
||||
++(*this);
|
||||
return ret;
|
||||
}
|
||||
IndexTransformIter &operator+=(difference_type n) {
|
||||
iter_ += n;
|
||||
return *this;
|
||||
}
|
||||
IndexTransformIter &operator-=(difference_type n) {
|
||||
(*this) += -n;
|
||||
return *this;
|
||||
}
|
||||
IndexTransformIter operator+(difference_type n) const {
|
||||
auto ret = *this;
|
||||
return ret += n;
|
||||
}
|
||||
IndexTransformIter operator-(difference_type n) const {
|
||||
auto ret = *this;
|
||||
return ret -= n;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Fn>
|
||||
auto MakeIndexTransformIter(Fn &&fn) {
|
||||
return IndexTransformIter<Fn>(std::forward<Fn>(fn));
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_TRANSFORM_ITERATOR_H_
|
||||
@ -7,6 +7,7 @@
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/hist_util.cuh"
|
||||
#include "../common/random.h"
|
||||
#include "../common/transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "./ellpack_page.cuh"
|
||||
#include "device_adapter.cuh"
|
||||
#include "gradient_index.h"
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/numeric.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/transform_iterator.h" // MakeIndexTransformIter
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@ -78,7 +79,7 @@ GHistIndexMatrix::~GHistIndexMatrix() = default;
|
||||
void GHistIndexMatrix::PushBatch(SparsePage const &batch, common::Span<FeatureType const> ft,
|
||||
int32_t n_threads) {
|
||||
auto page = batch.GetView();
|
||||
auto it = common::MakeIndexTransformIter([&](size_t ridx) { return page[ridx].size(); });
|
||||
auto it = common::MakeIndexTransformIter([&](std::size_t ridx) { return page[ridx].size(); });
|
||||
common::PartialSum(n_threads, it, it + page.Size(), static_cast<size_t>(0), row_ptr.begin());
|
||||
data::SparsePageAdapterBatch adapter_batch{page};
|
||||
auto is_valid = [](auto) { return true; }; // SparsePage always contains valid entries
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/numeric.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "adapter.h"
|
||||
#include "proxy_dmatrix.h"
|
||||
#include "xgboost/base.h"
|
||||
|
||||
20
tests/cpp/common/test_transform_iterator.cc
Normal file
20
tests/cpp/common/test_transform_iterator.cc
Normal file
@ -0,0 +1,20 @@
|
||||
/**
|
||||
* Copyright 2022 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cstddef> // std::size_t
|
||||
|
||||
#include "../../../src/common/transform_iterator.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
TEST(IndexTransformIter, Basic) {
|
||||
auto sqr = [](std::size_t i) { return i * i; };
|
||||
auto iter = MakeIndexTransformIter(sqr);
|
||||
for (std::size_t i = 0; i < 4; ++i) {
|
||||
ASSERT_EQ(iter[i], sqr(i));
|
||||
}
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
Loading…
x
Reference in New Issue
Block a user