Extract transform iterator. (#8498)
This commit is contained in:
parent
d8544e4d9e
commit
e3bf5565ab
@ -164,74 +164,6 @@ class Range {
|
|||||||
Iterator end_;
|
Iterator end_;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
|
||||||
* \brief Transform iterator that takes an index and calls transform operator.
|
|
||||||
*
|
|
||||||
* This is CPU-only right now as taking host device function as operator complicates the
|
|
||||||
* code. For device side one can use `thrust::transform_iterator` instead.
|
|
||||||
*/
|
|
||||||
template <typename Fn>
|
|
||||||
class IndexTransformIter {
|
|
||||||
size_t iter_{0};
|
|
||||||
Fn fn_;
|
|
||||||
|
|
||||||
public:
|
|
||||||
using iterator_category = std::random_access_iterator_tag; // NOLINT
|
|
||||||
using value_type = std::result_of_t<Fn(size_t)>; // NOLINT
|
|
||||||
using difference_type = detail::ptrdiff_t; // NOLINT
|
|
||||||
using reference = std::add_lvalue_reference_t<value_type>; // NOLINT
|
|
||||||
using pointer = std::add_pointer_t<value_type>; // NOLINT
|
|
||||||
|
|
||||||
public:
|
|
||||||
/**
|
|
||||||
* \param op Transform operator, takes a size_t index as input.
|
|
||||||
*/
|
|
||||||
explicit IndexTransformIter(Fn &&op) : fn_{op} {}
|
|
||||||
IndexTransformIter(IndexTransformIter const &) = default;
|
|
||||||
IndexTransformIter& operator=(IndexTransformIter&&) = default;
|
|
||||||
IndexTransformIter& operator=(IndexTransformIter const& that) {
|
|
||||||
iter_ = that.iter_;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
value_type operator*() const { return fn_(iter_); }
|
|
||||||
|
|
||||||
auto operator-(IndexTransformIter const &that) const { return iter_ - that.iter_; }
|
|
||||||
bool operator==(IndexTransformIter const &that) const { return iter_ == that.iter_; }
|
|
||||||
bool operator!=(IndexTransformIter const &that) const { return !(*this == that); }
|
|
||||||
|
|
||||||
IndexTransformIter &operator++() {
|
|
||||||
iter_++;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
IndexTransformIter operator++(int) {
|
|
||||||
auto ret = *this;
|
|
||||||
++(*this);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
IndexTransformIter &operator+=(difference_type n) {
|
|
||||||
iter_ += n;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
IndexTransformIter &operator-=(difference_type n) {
|
|
||||||
(*this) += -n;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
IndexTransformIter operator+(difference_type n) const {
|
|
||||||
auto ret = *this;
|
|
||||||
return ret += n;
|
|
||||||
}
|
|
||||||
IndexTransformIter operator-(difference_type n) const {
|
|
||||||
auto ret = *this;
|
|
||||||
return ret -= n;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename Fn>
|
|
||||||
auto MakeIndexTransformIter(Fn&& fn) {
|
|
||||||
return IndexTransformIter<Fn>(std::forward<Fn>(fn));
|
|
||||||
}
|
|
||||||
|
|
||||||
int AllVisibleGPUs();
|
int AllVisibleGPUs();
|
||||||
|
|
||||||
inline void AssertGPUSupport() {
|
inline void AssertGPUSupport() {
|
||||||
|
|||||||
@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "threading_utils.h"
|
#include "threading_utils.h"
|
||||||
|
#include "transform_iterator.h" // MakeIndexTransformIter
|
||||||
#include "xgboost/generic_parameters.h"
|
#include "xgboost/generic_parameters.h"
|
||||||
#include "xgboost/linalg.h"
|
#include "xgboost/linalg.h"
|
||||||
|
|
||||||
|
|||||||
@ -3,8 +3,8 @@
|
|||||||
*/
|
*/
|
||||||
#include <thrust/binary_search.h>
|
#include <thrust/binary_search.h>
|
||||||
#include <thrust/execution_policy.h>
|
#include <thrust/execution_policy.h>
|
||||||
#include <thrust/iterator/discard_iterator.h>
|
|
||||||
#include <thrust/iterator/constant_iterator.h>
|
#include <thrust/iterator/constant_iterator.h>
|
||||||
|
#include <thrust/iterator/discard_iterator.h>
|
||||||
#include <thrust/transform_scan.h>
|
#include <thrust/transform_scan.h>
|
||||||
#include <thrust/unique.h>
|
#include <thrust/unique.h>
|
||||||
|
|
||||||
@ -20,6 +20,7 @@
|
|||||||
#include "hist_util.h"
|
#include "hist_util.h"
|
||||||
#include "quantile.cuh"
|
#include "quantile.cuh"
|
||||||
#include "quantile.h"
|
#include "quantile.h"
|
||||||
|
#include "transform_iterator.h" // MakeIndexTransformIter
|
||||||
#include "xgboost/span.h"
|
#include "xgboost/span.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|||||||
@ -8,7 +8,8 @@
|
|||||||
#include <limits>
|
#include <limits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "common.h" // AssertGPUSupport
|
#include "common.h" // AssertGPUSupport
|
||||||
|
#include "transform_iterator.h" // MakeIndexTransformIter
|
||||||
#include "xgboost/generic_parameters.h"
|
#include "xgboost/generic_parameters.h"
|
||||||
#include "xgboost/linalg.h"
|
#include "xgboost/linalg.h"
|
||||||
|
|
||||||
|
|||||||
89
src/common/transform_iterator.h
Normal file
89
src/common/transform_iterator.h
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2022 by XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#ifndef XGBOOST_COMMON_TRANSFORM_ITERATOR_H_
|
||||||
|
#define XGBOOST_COMMON_TRANSFORM_ITERATOR_H_
|
||||||
|
|
||||||
|
#include <cstddef> // std::size_t
|
||||||
|
#include <iterator> // std::random_access_iterator_tag
|
||||||
|
#include <type_traits> // std::result_of_t, std::add_pointer_t, std::add_lvalue_reference_t
|
||||||
|
#include <utility> // std::forward
|
||||||
|
|
||||||
|
#include "xgboost/span.h" // ptrdiff_t
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace common {
|
||||||
|
/**
|
||||||
|
* \brief Transform iterator that takes an index and calls transform operator.
|
||||||
|
*
|
||||||
|
* This is CPU-only right now as taking host device function as operator complicates the
|
||||||
|
* code. For device side one can use `thrust::transform_iterator` instead.
|
||||||
|
*/
|
||||||
|
template <typename Fn>
|
||||||
|
class IndexTransformIter {
|
||||||
|
std::size_t iter_{0};
|
||||||
|
Fn fn_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
using iterator_category = std::random_access_iterator_tag; // NOLINT
|
||||||
|
using value_type = std::result_of_t<Fn(std::size_t)>; // NOLINT
|
||||||
|
using difference_type = detail::ptrdiff_t; // NOLINT
|
||||||
|
using reference = std::add_lvalue_reference_t<value_type>; // NOLINT
|
||||||
|
using pointer = std::add_pointer_t<value_type>; // NOLINT
|
||||||
|
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* \param op Transform operator, takes a size_t index as input.
|
||||||
|
*/
|
||||||
|
explicit IndexTransformIter(Fn &&op) : fn_{op} {}
|
||||||
|
IndexTransformIter(IndexTransformIter const &) = default;
|
||||||
|
IndexTransformIter &operator=(IndexTransformIter &&) = default;
|
||||||
|
IndexTransformIter &operator=(IndexTransformIter const &that) {
|
||||||
|
iter_ = that.iter_;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
value_type operator*() const { return fn_(iter_); }
|
||||||
|
value_type operator[](std::size_t i) const {
|
||||||
|
auto iter = *this + i;
|
||||||
|
return *iter;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto operator-(IndexTransformIter const &that) const { return iter_ - that.iter_; }
|
||||||
|
bool operator==(IndexTransformIter const &that) const { return iter_ == that.iter_; }
|
||||||
|
bool operator!=(IndexTransformIter const &that) const { return !(*this == that); }
|
||||||
|
|
||||||
|
IndexTransformIter &operator++() {
|
||||||
|
iter_++;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
IndexTransformIter operator++(int) {
|
||||||
|
auto ret = *this;
|
||||||
|
++(*this);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
IndexTransformIter &operator+=(difference_type n) {
|
||||||
|
iter_ += n;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
IndexTransformIter &operator-=(difference_type n) {
|
||||||
|
(*this) += -n;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
IndexTransformIter operator+(difference_type n) const {
|
||||||
|
auto ret = *this;
|
||||||
|
return ret += n;
|
||||||
|
}
|
||||||
|
IndexTransformIter operator-(difference_type n) const {
|
||||||
|
auto ret = *this;
|
||||||
|
return ret -= n;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Fn>
|
||||||
|
auto MakeIndexTransformIter(Fn &&fn) {
|
||||||
|
return IndexTransformIter<Fn>(std::forward<Fn>(fn));
|
||||||
|
}
|
||||||
|
} // namespace common
|
||||||
|
} // namespace xgboost
|
||||||
|
#endif // XGBOOST_COMMON_TRANSFORM_ITERATOR_H_
|
||||||
@ -7,6 +7,7 @@
|
|||||||
#include "../common/categorical.h"
|
#include "../common/categorical.h"
|
||||||
#include "../common/hist_util.cuh"
|
#include "../common/hist_util.cuh"
|
||||||
#include "../common/random.h"
|
#include "../common/random.h"
|
||||||
|
#include "../common/transform_iterator.h" // MakeIndexTransformIter
|
||||||
#include "./ellpack_page.cuh"
|
#include "./ellpack_page.cuh"
|
||||||
#include "device_adapter.cuh"
|
#include "device_adapter.cuh"
|
||||||
#include "gradient_index.h"
|
#include "gradient_index.h"
|
||||||
|
|||||||
@ -13,6 +13,7 @@
|
|||||||
#include "../common/hist_util.h"
|
#include "../common/hist_util.h"
|
||||||
#include "../common/numeric.h"
|
#include "../common/numeric.h"
|
||||||
#include "../common/threading_utils.h"
|
#include "../common/threading_utils.h"
|
||||||
|
#include "../common/transform_iterator.h" // MakeIndexTransformIter
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
@ -78,7 +79,7 @@ GHistIndexMatrix::~GHistIndexMatrix() = default;
|
|||||||
void GHistIndexMatrix::PushBatch(SparsePage const &batch, common::Span<FeatureType const> ft,
|
void GHistIndexMatrix::PushBatch(SparsePage const &batch, common::Span<FeatureType const> ft,
|
||||||
int32_t n_threads) {
|
int32_t n_threads) {
|
||||||
auto page = batch.GetView();
|
auto page = batch.GetView();
|
||||||
auto it = common::MakeIndexTransformIter([&](size_t ridx) { return page[ridx].size(); });
|
auto it = common::MakeIndexTransformIter([&](std::size_t ridx) { return page[ridx].size(); });
|
||||||
common::PartialSum(n_threads, it, it + page.Size(), static_cast<size_t>(0), row_ptr.begin());
|
common::PartialSum(n_threads, it, it + page.Size(), static_cast<size_t>(0), row_ptr.begin());
|
||||||
data::SparsePageAdapterBatch adapter_batch{page};
|
data::SparsePageAdapterBatch adapter_batch{page};
|
||||||
auto is_valid = [](auto) { return true; }; // SparsePage always contains valid entries
|
auto is_valid = [](auto) { return true; }; // SparsePage always contains valid entries
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
#include "../common/hist_util.h"
|
#include "../common/hist_util.h"
|
||||||
#include "../common/numeric.h"
|
#include "../common/numeric.h"
|
||||||
#include "../common/threading_utils.h"
|
#include "../common/threading_utils.h"
|
||||||
|
#include "../common/transform_iterator.h" // MakeIndexTransformIter
|
||||||
#include "adapter.h"
|
#include "adapter.h"
|
||||||
#include "proxy_dmatrix.h"
|
#include "proxy_dmatrix.h"
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
|
|||||||
20
tests/cpp/common/test_transform_iterator.cc
Normal file
20
tests/cpp/common/test_transform_iterator.cc
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2022 by XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <cstddef> // std::size_t
|
||||||
|
|
||||||
|
#include "../../../src/common/transform_iterator.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace common {
|
||||||
|
TEST(IndexTransformIter, Basic) {
|
||||||
|
auto sqr = [](std::size_t i) { return i * i; };
|
||||||
|
auto iter = MakeIndexTransformIter(sqr);
|
||||||
|
for (std::size_t i = 0; i < 4; ++i) {
|
||||||
|
ASSERT_EQ(iter[i], sqr(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace common
|
||||||
|
} // namespace xgboost
|
||||||
Loading…
x
Reference in New Issue
Block a user