Extract transform iterator. (#8498)

This commit is contained in:
Jiaming Yuan
2022-12-05 21:37:07 +08:00
committed by GitHub
parent d8544e4d9e
commit e3bf5565ab
9 changed files with 118 additions and 71 deletions

View File

@@ -7,6 +7,7 @@
#include "../common/categorical.h"
#include "../common/hist_util.cuh"
#include "../common/random.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "./ellpack_page.cuh"
#include "device_adapter.cuh"
#include "gradient_index.h"

View File

@@ -13,6 +13,7 @@
#include "../common/hist_util.h"
#include "../common/numeric.h"
#include "../common/threading_utils.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter
namespace xgboost {
@@ -78,7 +79,7 @@ GHistIndexMatrix::~GHistIndexMatrix() = default;
void GHistIndexMatrix::PushBatch(SparsePage const &batch, common::Span<FeatureType const> ft,
int32_t n_threads) {
auto page = batch.GetView();
auto it = common::MakeIndexTransformIter([&](size_t ridx) { return page[ridx].size(); });
auto it = common::MakeIndexTransformIter([&](std::size_t ridx) { return page[ridx].size(); });
common::PartialSum(n_threads, it, it + page.Size(), static_cast<size_t>(0), row_ptr.begin());
data::SparsePageAdapterBatch adapter_batch{page};
auto is_valid = [](auto) { return true; }; // SparsePage always contains valid entries

View File

@@ -15,6 +15,7 @@
#include "../common/hist_util.h"
#include "../common/numeric.h"
#include "../common/threading_utils.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "adapter.h"
#include "proxy_dmatrix.h"
#include "xgboost/base.h"