Small cleanup to gradient index and hist. (#7668)

* Code comments.
* Add a const accessor to the index.
* Remove some redundant raw-pointer members from the `Index` class.
* Simplify the `MemStackAllocator`.
Jiaming Yuan 2022-02-23 11:37:21 +08:00 committed by GitHub
parent 49c74a5369
commit 6762c45494
12 changed files with 149 additions and 148 deletions

View File

@@ -266,9 +266,9 @@ class ColumnMatrix {
}
template <typename T>
inline void SetIndexAllDense(T *index, const GHistIndexMatrix &gmat,
const size_t nrow, const size_t nfeature,
const bool noMissingValues, int32_t n_threads) {
inline void SetIndexAllDense(T const* index, const GHistIndexMatrix& gmat, const size_t nrow,
const size_t nfeature, const bool noMissingValues,
int32_t n_threads) {
T* local_index = reinterpret_cast<T*>(&index_[0]);
/* missing values make sense only for columns of type kDenseColumn,
@@ -313,7 +313,7 @@ class ColumnMatrix {
}
template<typename T>
inline void SetIndex(uint32_t* index, const GHistIndexMatrix& gmat,
inline void SetIndex(uint32_t const* index, const GHistIndexMatrix& gmat,
const size_t nfeature) {
std::vector<size_t> num_nonzeros;
num_nonzeros.resize(nfeature);

View File

@@ -197,19 +197,27 @@ enum BinTypeSize : uint32_t {
kUint32BinsTypeSize = 4
};
/**
* \brief Optionally compressed gradient index. The compression works only with dense
* data.
*
* The main body of the construction code is in gradient_index.cc; this struct is only a
* storage class.
*/
struct Index {
Index() {
SetBinTypeSize(binTypeSize_);
}
Index() { SetBinTypeSize(binTypeSize_); }
Index(const Index& i) = delete;
Index& operator=(Index i) = delete;
Index(Index&& i) = delete;
Index& operator=(Index&& i) = delete;
uint32_t operator[](size_t i) const {
if (offset_ptr_ != nullptr) {
return func_(data_ptr_, i) + offset_ptr_[i%p_];
if (!bin_offset_.empty()) {
// dense, compressed
auto fidx = i % bin_offset_.size();
// restore the index by adding back its feature offset.
return func_(data_.data(), i) + bin_offset_[fidx];
} else {
return func_(data_ptr_, i);
return func_(data_.data(), i);
}
}
void SetBinTypeSize(BinTypeSize binTypeSize) {
@@ -225,8 +233,7 @@ struct Index {
func_ = &GetValueFromUint32;
break;
default:
CHECK(binTypeSize == kUint8BinsTypeSize ||
binTypeSize == kUint16BinsTypeSize ||
CHECK(binTypeSize == kUint8BinsTypeSize || binTypeSize == kUint16BinsTypeSize ||
binTypeSize == kUint32BinsTypeSize);
}
}
@@ -234,26 +241,24 @@ struct Index {
return binTypeSize_;
}
template <typename T>
T* data() const { // NOLINT
return static_cast<T*>(data_ptr_);
T const* data() const { // NOLINT
return reinterpret_cast<T const*>(data_.data());
}
uint32_t* Offset() const {
return offset_ptr_;
template <typename T>
T* data() { // NOLINT
return reinterpret_cast<T*>(data_.data());
}
size_t OffsetSize() const {
return offset_.size();
uint32_t const* Offset() const { return bin_offset_.data(); }
size_t OffsetSize() const { return bin_offset_.size(); }
size_t Size() const { return data_.size() / (binTypeSize_); }
void Resize(const size_t n_bytes) {
data_.resize(n_bytes);
}
size_t Size() const {
return data_.size() / (binTypeSize_);
}
void Resize(const size_t nBytesData) {
data_.resize(nBytesData);
data_ptr_ = reinterpret_cast<void*>(data_.data());
}
void ResizeOffset(const size_t nDisps) {
offset_.resize(nDisps);
offset_ptr_ = offset_.data();
p_ = nDisps;
// Set the offset used in compression; cut_ptrs is the CSC indptr in HistogramCuts.
void SetBinOffset(std::vector<uint32_t> const& cut_ptrs) {
bin_offset_.resize(cut_ptrs.size() - 1); // resize to number of features.
std::copy_n(cut_ptrs.begin(), bin_offset_.size(), bin_offset_.begin());
}
std::vector<uint8_t>::const_iterator begin() const { // NOLINT
return data_.begin();
@@ -270,24 +275,23 @@ struct Index {
}
private:
static uint32_t GetValueFromUint8(void *t, size_t i) {
return reinterpret_cast<uint8_t*>(t)[i];
// Functions to decompress the index.
static uint32_t GetValueFromUint8(uint8_t const* t, size_t i) { return t[i]; }
static uint32_t GetValueFromUint16(uint8_t const* t, size_t i) {
return reinterpret_cast<uint16_t const*>(t)[i];
}
static uint32_t GetValueFromUint16(void* t, size_t i) {
return reinterpret_cast<uint16_t*>(t)[i];
}
static uint32_t GetValueFromUint32(void* t, size_t i) {
return reinterpret_cast<uint32_t*>(t)[i];
static uint32_t GetValueFromUint32(uint8_t const* t, size_t i) {
return reinterpret_cast<uint32_t const*>(t)[i];
}
using Func = uint32_t (*)(void*, size_t);
using Func = uint32_t (*)(uint8_t const*, size_t);
std::vector<uint8_t> data_;
std::vector<uint32_t> offset_; // size of this field is equal to number of features
void* data_ptr_;
// Starting position of each feature inside the cut values (the indptr of the CSC cut matrix
// HistogramCuts, without the last entry). Used for bin compression.
std::vector<uint32_t> bin_offset_;
BinTypeSize binTypeSize_ {kUint8BinsTypeSize};
size_t p_ {1};
uint32_t* offset_ptr_ {nullptr};
Func func_;
};
@@ -304,9 +308,11 @@ int32_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(size_t begin, size_t end,
}
previous_middle = middle;
// index into all the bins
auto gidx = data[middle];
if (gidx >= fidx_begin && gidx < fidx_end) {
// Found the intersection.
return static_cast<int32_t>(gidx);
} else if (gidx < fidx_begin) {
begin = middle;
@@ -636,42 +642,6 @@ class GHistBuilder {
/*! \brief number of all bins over all features */
uint32_t nbins_ { 0 };
};
/*!
* \brief A C-style array with in-stack allocation. As long as the array is smaller than
* MaxStackSize, it will be allocated inside the stack. Otherwise, it will be
* heap-allocated.
*/
template<typename T, size_t MaxStackSize>
class MemStackAllocator {
public:
explicit MemStackAllocator(size_t required_size): required_size_(required_size) {
}
T* Get() {
if (!ptr_) {
if (MaxStackSize >= required_size_) {
ptr_ = stack_mem_;
} else {
ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
do_free_ = true;
}
}
return ptr_;
}
~MemStackAllocator() {
if (do_free_) free(ptr_);
}
private:
T* ptr_ = nullptr;
bool do_free_ = false;
size_t required_size_;
T stack_mem_[MaxStackSize];
};
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_HIST_UTIL_H_
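A minimal standalone sketch of the compression scheme that `bin_offset_` enables (hypothetical cut pointers; the real `Index` dispatches through `func_` on the stored byte width): in a dense matrix, entry i belongs to feature i % n_features, so only the within-feature bin index needs to be stored, and the feature's starting offset is added back on read, as in `operator[]` above.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Assume two features: feature 0 owns global bins [0, 256), feature 1 owns [256, 300).
  std::vector<uint32_t> cut_ptrs{0, 256, 300};
  // SetBinOffset keeps all but the last entry: one starting offset per feature.
  std::vector<uint32_t> bin_offset(cut_ptrs.begin(), cut_ptrs.end() - 1);

  // Global bin indices for one dense row (feature 0 first, then feature 1).
  std::vector<uint32_t> global{13, 258};

  // Compression: subtract the feature offset so each value fits into a uint8_t.
  std::vector<uint8_t> compressed;
  for (std::size_t i = 0; i < global.size(); ++i) {
    compressed.push_back(static_cast<uint8_t>(global[i] - bin_offset[i % bin_offset.size()]));
  }

  // Decompression mirrors Index::operator[]: add the feature offset back.
  for (std::size_t i = 0; i < compressed.size(); ++i) {
    std::cout << compressed[i] + bin_offset[i % bin_offset.size()] << "\n";  // prints 13, then 258
  }
  return 0;
}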

View File

@@ -246,6 +246,43 @@ inline int32_t OmpGetNumThreads(int32_t n_threads) {
n_threads = std::max(n_threads, 1);
return n_threads;
}
/*!
* \brief A C-style array with in-stack allocation. As long as the array is smaller than
* MaxStackSize, it will be allocated on the stack. Otherwise, it will be
* heap-allocated.
*/
template <typename T, size_t MaxStackSize>
class MemStackAllocator {
public:
explicit MemStackAllocator(size_t required_size) : required_size_(required_size) {
if (MaxStackSize >= required_size_) {
ptr_ = stack_mem_;
} else {
ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
}
if (!ptr_) {
throw std::bad_alloc{};
}
}
~MemStackAllocator() {
if (required_size_ > MaxStackSize) {
free(ptr_);
}
}
T& operator[](size_t i) { return ptr_[i]; }
T const& operator[](size_t i) const { return ptr_[i]; }
// FIXME(jiamingy): Remove this once we merge partitioner cleanup for hist.
auto Get() { return ptr_; }
private:
T* ptr_ = nullptr;
size_t required_size_;
T stack_mem_[MaxStackSize];
};
} // namespace common
} // namespace xgboost
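A short usage sketch of the simplified allocator (hypothetical call site): the buffer is now allocated in the constructor and a failed heap allocation throws std::bad_alloc, so callers can index it directly instead of fetching a raw pointer through Get().

#include <cstddef>
// Assumes the MemStackAllocator shown above is visible, e.g. via the
// xgboost common/threading_utils.h header.

void ExamplePartialSums(std::size_t n_threads) {
  // Lives on the stack when n_threads <= 128 and on the heap otherwise;
  // either way the memory is released when the object goes out of scope.
  xgboost::common::MemStackAllocator<std::size_t, 128> partial_sums(n_threads);
  partial_sums[0] = 0;
  for (std::size_t i = 1; i < n_threads; ++i) {
    partial_sums[i] = partial_sums[i - 1] + i;  // stand-in for per-thread sums
  }
}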

View File

@@ -10,6 +10,7 @@
#include "../common/column_matrix.h"
#include "../common/hist_util.h"
#include "../common/threading_utils.h"
namespace xgboost {
@@ -34,7 +35,6 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
std::max(static_cast<size_t>(1), std::min(batch.Size(), static_cast<size_t>(n_threads)));
auto page = batch.GetView();
common::MemStackAllocator<size_t, 128> partial_sums(batch_threads);
size_t *p_part = partial_sums.Get();
size_t block_size = batch.Size() / batch_threads;
@@ -48,10 +48,10 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
size_t iend = (tid == (batch_threads - 1) ? batch.Size()
: (block_size * (tid + 1)));
size_t sum = 0;
for (size_t i = ibegin; i < iend; ++i) {
sum += page[i].size();
row_ptr[rbegin + 1 + i] = sum;
size_t running_sum = 0;
for (size_t ridx = ibegin; ridx < iend; ++ridx) {
running_sum += page[ridx].size();
row_ptr[rbegin + 1 + ridx] = running_sum;
}
});
}
@@ -59,9 +59,9 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
#pragma omp single
{
exc.Run([&]() {
p_part[0] = prev_sum;
partial_sums[0] = prev_sum;
for (size_t i = 1; i < batch_threads; ++i) {
p_part[i] = p_part[i - 1] + row_ptr[rbegin + i * block_size];
partial_sums[i] = partial_sums[i - 1] + row_ptr[rbegin + i * block_size];
}
});
}
@@ -74,55 +74,52 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
: (block_size * (tid + 1)));
for (size_t i = ibegin; i < iend; ++i) {
row_ptr[rbegin + 1 + i] += p_part[tid];
row_ptr[rbegin + 1 + i] += partial_sums[tid];
}
});
}
}
exc.Rethrow();
const size_t n_offsets = cut.Ptrs().size() - 1;
const size_t n_index = row_ptr[rbegin + batch.Size()];
const size_t n_index = row_ptr[rbegin + batch.Size()]; // number of entries in this page
ResizeIndex(n_index, isDense_);
CHECK_GT(cut.Values().size(), 0U);
uint32_t *offsets = nullptr;
if (isDense_) {
index.ResizeOffset(n_offsets);
offsets = index.Offset();
for (size_t i = 0; i < n_offsets; ++i) {
offsets[i] = cut.Ptrs()[i];
}
index.SetBinOffset(cut.Ptrs());
}
uint32_t const *offsets = index.Offset();
if (isDense_) {
// Inside the lambda functions, bin_idx is the index into the cut values across all
// features. By subtracting each feature's starting offset from it, we can reduce it
// to a smaller value and compress it into a narrower type.
common::BinTypeSize current_bin_size = index.GetBinTypeSize();
if (current_bin_size == common::kUint8BinsTypeSize) {
common::Span<uint8_t> index_data_span = {index.data<uint8_t>(), n_index};
SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
[offsets](auto idx, auto j) {
return static_cast<uint8_t>(idx - offsets[j]);
[offsets](auto bin_idx, auto fidx) {
return static_cast<uint8_t>(bin_idx - offsets[fidx]);
});
} else if (current_bin_size == common::kUint16BinsTypeSize) {
common::Span<uint16_t> index_data_span = {index.data<uint16_t>(), n_index};
SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
[offsets](auto idx, auto j) {
return static_cast<uint16_t>(idx - offsets[j]);
[offsets](auto bin_idx, auto fidx) {
return static_cast<uint16_t>(bin_idx - offsets[fidx]);
});
} else {
CHECK_EQ(current_bin_size, common::kUint32BinsTypeSize);
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
[offsets](auto idx, auto j) {
return static_cast<uint32_t>(idx - offsets[j]);
[offsets](auto bin_idx, auto fidx) {
return static_cast<uint32_t>(bin_idx - offsets[fidx]);
});
}
} else {
/* For a sparse DMatrix we have to store the global bin index for each entry in the
index field so that the right offset can be chosen. Hence the offset list stays
empty and the index is not compressed. */
} else {
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
[](auto idx, auto) { return idx; });
@@ -194,11 +191,13 @@ void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType co
void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense) {
// compress dense index to uint8
index.SetBinTypeSize(common::kUint8BinsTypeSize);
index.Resize((sizeof(uint8_t)) * n_index);
} else if ((max_num_bins - 1 > static_cast<int>(std::numeric_limits<uint8_t>::max()) &&
max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) &&
isDense) {
// compress dense index to uint16
index.SetBinTypeSize(common::kUint16BinsTypeSize);
index.Resize((sizeof(uint16_t)) * n_index);
} else {
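The width selection above amounts to a small decision rule; here is a sketch of the same logic under the usual assumptions (max_num_bins counts bins per feature, so the largest compressed value is max_num_bins - 1):

#include <cstddef>
#include <cstdint>
#include <limits>

// Sketch of the byte-width choice made by GHistIndexMatrix::ResizeIndex.
std::size_t BinIdxBytes(int max_num_bins, bool is_dense) {
  if (is_dense && max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) {
    return sizeof(uint8_t);   // up to 256 bins: compress to uint8
  }
  if (is_dense && max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) {
    return sizeof(uint16_t);  // up to 65536 bins: compress to uint16
  }
  return sizeof(uint32_t);    // sparse data or very fine cuts: no compression
}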

View File

@@ -21,6 +21,13 @@ namespace xgboost {
* index for CPU histogram. On GPU ellpack page is used.
*/
class GHistIndexMatrix {
/**
* \brief Push a page into the index matrix; this function is only necessary because hist has
* partial support for external memory.
*
* \param rbegin The beginning row index of the current page (equals the total number of rows in previous pages).
* \param prev_sum Total number of entries in previous pages.
*/
void PushBatch(SparsePage const& batch, common::Span<FeatureType const> ft, size_t rbegin,
size_t prev_sum, uint32_t nbins, int32_t n_threads);
@@ -64,12 +71,12 @@ class GHistIndexMatrix {
BinIdxType* index_data = index_data_span.data();
auto const& ptrs = cut.Ptrs();
auto const& values = cut.Values();
common::ParallelFor(batch_size, batch_threads, [&](omp_ulong i) {
common::ParallelFor(batch_size, batch_threads, [&](omp_ulong ridx) {
const int tid = omp_get_thread_num();
size_t ibegin = row_ptr[rbegin + i];
size_t iend = row_ptr[rbegin + i + 1];
const size_t size = offset_vec[i + 1] - offset_vec[i];
SparsePage::Inst inst = {data_ptr + offset_vec[i], size};
size_t ibegin = row_ptr[rbegin + ridx];    // index of the first entry of the current row
size_t iend = row_ptr[rbegin + ridx + 1];  // index of the first entry of the next row
const size_t size = offset_vec[ridx + 1] - offset_vec[ridx];
SparsePage::Inst inst = {data_ptr + offset_vec[ridx], size};
CHECK_EQ(ibegin + inst.size(), iend);
for (bst_uint j = 0; j < inst.size(); ++j) {
auto e = inst[j];
@@ -103,6 +110,10 @@ class GHistIndexMatrix {
return isDense_;
}
void SetDense(bool is_dense) { isDense_ = is_dense; }
/**
* \brief Get the starting position of a row in the index, given its global row index.
*/
size_t RowIdx(size_t ridx) const { return row_ptr[ridx - base_rowid]; }
bst_row_t Size() const {
return row_ptr.empty() ? 0 : row_ptr.size() - 1;

View File

@@ -16,14 +16,6 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
}
// indptr
fi->Read(&page->row_ptr);
// offset
using OffsetT = std::iterator_traits<decltype(page->index.Offset())>::value_type;
std::vector<OffsetT> offset;
if (!fi->Read(&offset)) {
return false;
}
page->index.ResizeOffset(offset.size());
std::copy(offset.begin(), offset.end(), page->index.Offset());
// data
std::vector<uint8_t> data;
if (!fi->Read(&data)) {
@@ -55,6 +47,9 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
return false;
}
page->SetDense(is_dense);
if (is_dense) {
page->index.SetBinOffset(page->cut.Ptrs());
}
return true;
}
@@ -65,13 +60,6 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
fo->Write(page.row_ptr);
bytes += page.row_ptr.size() * sizeof(decltype(page.row_ptr)::value_type) +
sizeof(uint64_t);
// offset
using OffsetT = std::iterator_traits<decltype(page.index.Offset())>::value_type;
std::vector<OffsetT> offset(page.index.OffsetSize());
std::copy(page.index.Offset(),
page.index.Offset() + page.index.OffsetSize(), offset.begin());
fo->Write(offset);
bytes += page.index.OffsetSize() * sizeof(OffsetT) + sizeof(uint64_t);
// data
std::vector<uint8_t> data(page.index.begin(), page.index.end());
fo->Write(data);

View File

@@ -35,14 +35,12 @@ class ApproxRowPartitioner {
std::vector<uint32_t> const &cut_ptrs,
std::vector<float> const &cut_values) {
int32_t gidx = -1;
auto const &row_ptr = index.row_ptr;
auto get_rid = [&](size_t ridx) { return row_ptr[ridx - index.base_rowid]; };
if (index.IsDense()) {
gidx = index.index[get_rid(ridx) + fidx];
// RowIdx returns the starting pos of this row
gidx = index.index[index.RowIdx(ridx) + fidx];
} else {
auto begin = get_rid(ridx);
auto end = get_rid(ridx + 1);
auto begin = index.RowIdx(ridx);
auto end = index.RowIdx(ridx + 1);
auto f_begin = cut_ptrs[fidx];
auto f_end = cut_ptrs[fidx + 1];
gidx = common::BinarySearchBin(begin, end, index.index, f_begin, f_end);
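What the two branches compute can be sketched with a simplified helper (hypothetical function; the real code calls common::BinarySearchBin rather than the linear scan below): dense pages store one bin per feature, so the lookup is positional, while sparse pages store only the present features' global bins, which are searched for one inside the feature's range [cut_ptrs[fidx], cut_ptrs[fidx + 1]).

#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified sketch of the dense/sparse lookup above. `bins` holds global bin
// indices; row_begin and row_end come from RowIdx(ridx) and RowIdx(ridx + 1).
int32_t FindBinForFeature(std::vector<uint32_t> const& bins, std::size_t row_begin,
                          std::size_t row_end, uint32_t f_begin, uint32_t f_end,
                          bool is_dense, std::size_t fidx) {
  if (is_dense) {
    // One entry per feature: the bin for feature fidx sits at a fixed column.
    return static_cast<int32_t>(bins[row_begin + fidx]);
  }
  // Sparse: find a bin owned by feature fidx, if the feature is present.
  for (std::size_t i = row_begin; i < row_end; ++i) {
    if (bins[i] >= f_begin && bins[i] < f_end) {
      return static_cast<int32_t>(bins[i]);
    }
  }
  return -1;  // feature is missing in this row
}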

View File

@@ -135,7 +135,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
{
auto nid = RegTree::kRoot;
GHistRowT hist = this->histogram_builder_->Histogram()[nid];
auto hist = this->histogram_builder_->Histogram()[nid];
GradientPairT grad_stat;
if (data_layout_ == DataLayout::kDenseDataZeroBased ||
data_layout_ == DataLayout::kDenseDataOneBased) {
@@ -149,7 +149,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
grad_stat.Add(et.GetGrad(), et.GetHess());
}
} else {
const RowSetCollection::Elem e = row_set_collection_[nid];
const common::RowSetCollection::Elem e = row_set_collection_[nid];
for (const size_t *it = e.begin; it < e.end; ++it) {
grad_stat.Add(gpair_h[*it].GetGrad(), gpair_h[*it].GetHess());
}
@@ -229,7 +229,7 @@ template<typename GradientSumT>
template <bool any_missing>
void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
const GHistIndexMatrix& gmat,
const ColumnMatrix& column_matrix,
const common::ColumnMatrix& column_matrix,
DMatrix* p_fmat,
RegTree* p_tree,
const std::vector<GradientPair>& gpair_h) {

View File

@@ -147,7 +147,7 @@ class QuantileHistMaker: public TreeUpdater {
// training parameter
TrainParam param_;
// column accessor
ColumnMatrix column_matrix_;
common::ColumnMatrix column_matrix_;
DMatrix const* p_last_dmat_ {nullptr};
bool is_gmat_initialized_ {false};
@@ -155,7 +155,6 @@ class QuantileHistMaker: public TreeUpdater {
template<typename GradientSumT>
struct Builder {
public:
using GHistRowT = GHistRow<GradientSumT>;
using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;
// constructor
explicit Builder(const size_t n_trees, const TrainParam& param,
@@ -164,7 +163,6 @@ class QuantileHistMaker: public TreeUpdater {
: n_trees_(n_trees),
param_(param),
pruner_(std::move(pruner)),
p_last_tree_(nullptr),
p_last_fmat_(fmat),
histogram_builder_{new HistogramBuilder<GradientSumT, CPUExpandEntry>},
task_{task},
@@ -172,7 +170,7 @@ class QuantileHistMaker: public TreeUpdater {
builder_monitor_.Init("Quantile::Builder");
}
// update one tree, growing
void Update(const GHistIndexMatrix& gmat, const ColumnMatrix& column_matrix,
void Update(const GHistIndexMatrix& gmat, const common::ColumnMatrix& column_matrix,
HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree);
bool UpdatePredictionCache(const DMatrix* data,

View File

@@ -306,8 +306,8 @@ TEST(HistUtil, IndexBinBound) {
}
template <typename T>
void CheckIndexData(T* data_ptr, uint32_t* offsets,
const GHistIndexMatrix& hmat, size_t n_cols) {
void CheckIndexData(T const* data_ptr, uint32_t const* offsets, const GHistIndexMatrix& hmat,
size_t n_cols) {
for (size_t i = 0; i < hmat.index.Size(); ++i) {
EXPECT_EQ(data_ptr[i] + offsets[i % n_cols], hmat.index[i]);
}
@@ -323,7 +323,7 @@ TEST(HistUtil, IndexBinData) {
for (auto max_bin : kBinSizes) {
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
GHistIndexMatrix hmat(p_fmat.get(), max_bin, 0.5, false, common::OmpGetNumThreads(0));
uint32_t* offsets = hmat.index.Offset();
uint32_t const* offsets = hmat.index.Offset();
EXPECT_EQ(hmat.index.Size(), kRows*kCols);
switch (max_bin) {
case kBinSizes[0]:

View File

@@ -6,15 +6,16 @@
#include <limits>
#include "../../../../src/common/categorical.h"
#include "../../../../src/common/row_set.h"
#include "../../../../src/tree/hist/expand_entry.h"
#include "../../../../src/tree/hist/histogram.h"
#include "../../../../src/tree/updater_quantile_hist.h"
#include "../../categorical_helpers.h"
#include "../../helpers.h"
namespace xgboost {
namespace tree {
namespace {
void InitRowPartitionForTest(RowSetCollection *row_set, size_t n_samples, size_t base_rowid = 0) {
void InitRowPartitionForTest(common::RowSetCollection *row_set, size_t n_samples, size_t base_rowid = 0) {
auto &row_indices = *row_set->Data();
row_indices.resize(n_samples);
std::iota(row_indices.begin(), row_indices.end(), base_rowid);
@@ -91,7 +92,7 @@ void TestSyncHist(bool is_distributed) {
uint32_t total_bins = gmat.cut.Ptrs().back();
histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed);
RowSetCollection row_set_collection_;
common::RowSetCollection row_set_collection_;
{
row_set_collection_.Clear();
std::vector<size_t> &row_indices = *row_set_collection_.Data();
@@ -256,7 +257,7 @@ void TestBuildHistogram(bool is_distributed) {
RegTree tree;
RowSetCollection row_set_collection;
common::RowSetCollection row_set_collection;
row_set_collection.Clear();
std::vector<size_t> &row_indices = *row_set_collection.Data();
row_indices.resize(kNRows);
@@ -318,7 +319,7 @@ void TestHistogramCategorical(size_t n_categories) {
auto gpair = GenerateRandomGradients(kRows, 0, 2);
RowSetCollection row_set_collection;
common::RowSetCollection row_set_collection;
row_set_collection.Clear();
std::vector<size_t> &row_indices = *row_set_collection.Data();
row_indices.resize(kRows);
@@ -381,13 +382,13 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx) {
std::vector<CPUExpandEntry> nodes;
nodes.emplace_back(0, tree.GetDepth(0), 0.0f);
GHistRow<double> multi_page;
common::GHistRow<double> multi_page;
HistogramBuilder<double, CPUExpandEntry> multi_build;
{
/**
* Multi page
*/
std::vector<RowSetCollection> rows_set;
std::vector<common::RowSetCollection> rows_set;
for (auto const &page : m->GetBatches<GHistIndexMatrix>(batch_param)) {
CHECK_LT(page.base_rowid, m->Info().num_row_);
auto n_rows_in_node = page.Size();
@@ -417,12 +418,12 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx) {
}
HistogramBuilder<double, CPUExpandEntry> single_build;
GHistRow<double> single_page;
common::GHistRow<double> single_page;
{
/**
* Single page
*/
RowSetCollection row_set_collection;
common::RowSetCollection row_set_collection;
InitRowPartitionForTest(&row_set_collection, n_samples);
single_build.Reset(total_bins, batch_param, common::OmpGetNumThreads(0), 1, false);

View File

@@ -24,7 +24,6 @@ class QuantileHistMock : public QuantileHistMaker {
template <typename GradientSumT>
struct BuilderMock : public QuantileHistMaker::Builder<GradientSumT> {
using RealImpl = QuantileHistMaker::Builder<GradientSumT>;
using GHistRowT = typename RealImpl::GHistRowT;
BuilderMock(const TrainParam &param, std::unique_ptr<TreeUpdater> pruner,
DMatrix const *fmat, GenericParameter const* ctx)