Move GHistIndex into DMatrix. (#7064)
This commit is contained in:
parent
1c8fdf2218
commit
1cd20efe68
@ -38,6 +38,7 @@
|
||||
#include "../src/data/sparse_page_raw_format.cc"
|
||||
#include "../src/data/ellpack_page.cc"
|
||||
#include "../src/data/ellpack_page_source.cc"
|
||||
#include "../src/data/gradient_index.cc"
|
||||
|
||||
// prediction
|
||||
#include "../src/predictor/predictor.cc"
|
||||
|
||||
@ -385,6 +385,8 @@ class EllpackPage {
|
||||
std::unique_ptr<EllpackPageImpl> impl_;
|
||||
};
|
||||
|
||||
class GHistIndexMatrix;
|
||||
|
||||
template<typename T>
|
||||
class BatchIteratorImpl {
|
||||
public:
|
||||
@ -553,6 +555,7 @@ class DMatrix {
|
||||
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
|
||||
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
|
||||
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
|
||||
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) = 0;
|
||||
|
||||
virtual bool EllpackExists() const = 0;
|
||||
virtual bool SparsePageExists() const = 0;
|
||||
@ -587,6 +590,11 @@ template<>
|
||||
inline BatchSet<EllpackPage> DMatrix::GetBatches(const BatchParam& param) {
|
||||
return GetEllpackBatches(param);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(const BatchParam& param) {
|
||||
return GetGradientIndex(param);
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
namespace dmlc {
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "hist_util.h"
|
||||
#include "../data/gradient_index.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
@ -262,9 +263,10 @@ class ColumnMatrix {
|
||||
return res;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline void SetIndexAllDense(T* index, const GHistIndexMatrix& gmat, const size_t nrow,
|
||||
const size_t nfeature, const bool noMissingValues) {
|
||||
template <typename T>
|
||||
inline void SetIndexAllDense(T *index, const GHistIndexMatrix &gmat,
|
||||
const size_t nrow, const size_t nfeature,
|
||||
const bool noMissingValues) {
|
||||
T* local_index = reinterpret_cast<T*>(&index_[0]);
|
||||
|
||||
/* missing values make sense only for column with type kDenseColumn,
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
#include "column_matrix.h"
|
||||
#include "quantile.h"
|
||||
#include "./../tree/updater_quantile_hist.h"
|
||||
#include "../data/gradient_index.h"
|
||||
|
||||
#if defined(XGBOOST_MM_PREFETCH_PRESENT)
|
||||
#include <xmmintrin.h>
|
||||
@ -29,164 +30,10 @@
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
void GHistIndexMatrix::ResizeIndex(const size_t n_index,
|
||||
const bool isDense) {
|
||||
if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense) {
|
||||
index.SetBinTypeSize(kUint8BinsTypeSize);
|
||||
index.Resize((sizeof(uint8_t)) * n_index);
|
||||
} else if ((max_num_bins - 1 > static_cast<int>(std::numeric_limits<uint8_t>::max()) &&
|
||||
max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) && isDense) {
|
||||
index.SetBinTypeSize(kUint16BinsTypeSize);
|
||||
index.Resize((sizeof(uint16_t)) * n_index);
|
||||
} else {
|
||||
index.SetBinTypeSize(kUint32BinsTypeSize);
|
||||
index.Resize((sizeof(uint32_t)) * n_index);
|
||||
}
|
||||
}
|
||||
|
||||
HistogramCuts::HistogramCuts() {
|
||||
cut_ptrs_.HostVector().emplace_back(0);
|
||||
}
|
||||
|
||||
void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins) {
|
||||
cut = SketchOnDMatrix(p_fmat, max_bins);
|
||||
|
||||
max_num_bins = max_bins;
|
||||
const int32_t nthread = omp_get_max_threads();
|
||||
const uint32_t nbins = cut.Ptrs().back();
|
||||
hit_count.resize(nbins, 0);
|
||||
hit_count_tloc_.resize(nthread * nbins, 0);
|
||||
|
||||
this->p_fmat = p_fmat;
|
||||
size_t new_size = 1;
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
new_size += batch.Size();
|
||||
}
|
||||
|
||||
row_ptr.resize(new_size);
|
||||
row_ptr[0] = 0;
|
||||
|
||||
size_t rbegin = 0;
|
||||
size_t prev_sum = 0;
|
||||
const bool isDense = p_fmat->IsDense();
|
||||
this->isDense_ = isDense;
|
||||
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
// The number of threads is pegged to the batch size. If the OMP
|
||||
// block is parallelized on anything other than the batch/block size,
|
||||
// it should be reassigned
|
||||
const size_t batch_threads = std::max(
|
||||
size_t(1),
|
||||
std::min(batch.Size(), static_cast<size_t>(omp_get_max_threads())));
|
||||
auto page = batch.GetView();
|
||||
MemStackAllocator<size_t, 128> partial_sums(batch_threads);
|
||||
size_t* p_part = partial_sums.Get();
|
||||
|
||||
size_t block_size = batch.Size() / batch_threads;
|
||||
|
||||
dmlc::OMPException exc;
|
||||
#pragma omp parallel num_threads(batch_threads)
|
||||
{
|
||||
#pragma omp for
|
||||
for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
|
||||
exc.Run([&]() {
|
||||
size_t ibegin = block_size * tid;
|
||||
size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1)));
|
||||
|
||||
size_t sum = 0;
|
||||
for (size_t i = ibegin; i < iend; ++i) {
|
||||
sum += page[i].size();
|
||||
row_ptr[rbegin + 1 + i] = sum;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#pragma omp single
|
||||
{
|
||||
exc.Run([&]() {
|
||||
p_part[0] = prev_sum;
|
||||
for (size_t i = 1; i < batch_threads; ++i) {
|
||||
p_part[i] = p_part[i - 1] + row_ptr[rbegin + i*block_size];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#pragma omp for
|
||||
for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
|
||||
exc.Run([&]() {
|
||||
size_t ibegin = block_size * tid;
|
||||
size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1)));
|
||||
|
||||
for (size_t i = ibegin; i < iend; ++i) {
|
||||
row_ptr[rbegin + 1 + i] += p_part[tid];
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
exc.Rethrow();
|
||||
|
||||
const size_t n_offsets = cut.Ptrs().size() - 1;
|
||||
const size_t n_index = row_ptr[rbegin + batch.Size()];
|
||||
ResizeIndex(n_index, isDense);
|
||||
|
||||
CHECK_GT(cut.Values().size(), 0U);
|
||||
|
||||
uint32_t* offsets = nullptr;
|
||||
if (isDense) {
|
||||
index.ResizeOffset(n_offsets);
|
||||
offsets = index.Offset();
|
||||
for (size_t i = 0; i < n_offsets; ++i) {
|
||||
offsets[i] = cut.Ptrs()[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (isDense) {
|
||||
BinTypeSize curent_bin_size = index.GetBinTypeSize();
|
||||
if (curent_bin_size == kUint8BinsTypeSize) {
|
||||
common::Span<uint8_t> index_data_span = {index.data<uint8_t>(),
|
||||
n_index};
|
||||
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
|
||||
[offsets](auto idx, auto j) {
|
||||
return static_cast<uint8_t>(idx - offsets[j]);
|
||||
});
|
||||
|
||||
} else if (curent_bin_size == kUint16BinsTypeSize) {
|
||||
common::Span<uint16_t> index_data_span = {index.data<uint16_t>(),
|
||||
n_index};
|
||||
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
|
||||
[offsets](auto idx, auto j) {
|
||||
return static_cast<uint16_t>(idx - offsets[j]);
|
||||
});
|
||||
} else {
|
||||
CHECK_EQ(curent_bin_size, kUint32BinsTypeSize);
|
||||
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(),
|
||||
n_index};
|
||||
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
|
||||
[offsets](auto idx, auto j) {
|
||||
return static_cast<uint32_t>(idx - offsets[j]);
|
||||
});
|
||||
}
|
||||
|
||||
/* For sparse DMatrix we have to store index of feature for each bin
|
||||
in index field to chose right offset. So offset is nullptr and index is not reduced */
|
||||
} else {
|
||||
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
|
||||
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
|
||||
[](auto idx, auto) { return idx; });
|
||||
}
|
||||
|
||||
ParallelFor(bst_omp_uint(nbins), nthread, [&](bst_omp_uint idx) {
|
||||
for (int32_t tid = 0; tid < nthread; ++tid) {
|
||||
hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
|
||||
hit_count_tloc_[tid * nbins + idx] = 0; // reset for next batch
|
||||
}
|
||||
});
|
||||
|
||||
prev_sum = row_ptr[rbegin + batch.Size()];
|
||||
rbegin += batch.Size();
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief fill a histogram by zeros in range [begin, end)
|
||||
*/
|
||||
@ -382,25 +229,25 @@ void GHistBuilder<GradientSumT>::BuildHist(
|
||||
BuildHistDispatch<GradientSumT, false, any_missing>(gpair, span2, gmat, hist);
|
||||
}
|
||||
}
|
||||
template
|
||||
void GHistBuilder<float>::BuildHist<true>(const std::vector<GradientPair>& gpair,
|
||||
template void
|
||||
GHistBuilder<float>::BuildHist<true>(const std::vector<GradientPair> &gpair,
|
||||
const RowSetCollection::Elem row_indices,
|
||||
const GHistIndexMatrix& gmat,
|
||||
const GHistIndexMatrix &gmat,
|
||||
GHistRow<float> hist);
|
||||
template
|
||||
void GHistBuilder<float>::BuildHist<false>(const std::vector<GradientPair>& gpair,
|
||||
template void
|
||||
GHistBuilder<float>::BuildHist<false>(const std::vector<GradientPair> &gpair,
|
||||
const RowSetCollection::Elem row_indices,
|
||||
const GHistIndexMatrix& gmat,
|
||||
const GHistIndexMatrix &gmat,
|
||||
GHistRow<float> hist);
|
||||
template
|
||||
void GHistBuilder<double>::BuildHist<true>(const std::vector<GradientPair>& gpair,
|
||||
template void
|
||||
GHistBuilder<double>::BuildHist<true>(const std::vector<GradientPair> &gpair,
|
||||
const RowSetCollection::Elem row_indices,
|
||||
const GHistIndexMatrix& gmat,
|
||||
const GHistIndexMatrix &gmat,
|
||||
GHistRow<double> hist);
|
||||
template
|
||||
void GHistBuilder<double>::BuildHist<false>(const std::vector<GradientPair>& gpair,
|
||||
template void
|
||||
GHistBuilder<double>::BuildHist<false>(const std::vector<GradientPair> &gpair,
|
||||
const RowSetCollection::Elem row_indices,
|
||||
const GHistIndexMatrix& gmat,
|
||||
const GHistIndexMatrix &gmat,
|
||||
GHistRow<double> hist);
|
||||
|
||||
template<typename GradientSumT>
|
||||
|
||||
@ -25,6 +25,8 @@
|
||||
#include "../include/rabit/rabit.h"
|
||||
|
||||
namespace xgboost {
|
||||
class GHistIndexMatrix;
|
||||
|
||||
namespace common {
|
||||
/*!
|
||||
* \brief A single row in global histogram index.
|
||||
@ -226,74 +228,6 @@ struct Index {
|
||||
Func func_;
|
||||
};
|
||||
|
||||
|
||||
/*!
|
||||
* \brief preprocessed global index matrix, in CSR format
|
||||
*
|
||||
* Transform floating values to integer index in histogram This is a global histogram
|
||||
* index for CPU histogram. On GPU ellpack page is used.
|
||||
*/
|
||||
struct GHistIndexMatrix {
|
||||
/*! \brief row pointer to rows by element position */
|
||||
std::vector<size_t> row_ptr;
|
||||
/*! \brief The index data */
|
||||
Index index;
|
||||
/*! \brief hit count of each index */
|
||||
std::vector<size_t> hit_count;
|
||||
/*! \brief The corresponding cuts */
|
||||
HistogramCuts cut;
|
||||
DMatrix* p_fmat;
|
||||
size_t max_num_bins;
|
||||
// Create a global histogram matrix, given cut
|
||||
void Init(DMatrix* p_fmat, int max_num_bins);
|
||||
|
||||
// specific method for sparse data as no possibility to reduce allocated memory
|
||||
template <typename BinIdxType, typename GetOffset>
|
||||
void SetIndexData(common::Span<BinIdxType> index_data_span,
|
||||
size_t batch_threads, const SparsePage &batch,
|
||||
size_t rbegin, size_t nbins, GetOffset get_offset) {
|
||||
const xgboost::Entry *data_ptr = batch.data.HostVector().data();
|
||||
const std::vector<bst_row_t> &offset_vec = batch.offset.HostVector();
|
||||
const size_t batch_size = batch.Size();
|
||||
CHECK_LT(batch_size, offset_vec.size());
|
||||
BinIdxType* index_data = index_data_span.data();
|
||||
ParallelFor(omp_ulong(batch_size), batch_threads, [&](omp_ulong i) {
|
||||
const int tid = omp_get_thread_num();
|
||||
size_t ibegin = row_ptr[rbegin + i];
|
||||
size_t iend = row_ptr[rbegin + i + 1];
|
||||
const size_t size = offset_vec[i + 1] - offset_vec[i];
|
||||
SparsePage::Inst inst = {data_ptr + offset_vec[i], size};
|
||||
CHECK_EQ(ibegin + inst.size(), iend);
|
||||
for (bst_uint j = 0; j < inst.size(); ++j) {
|
||||
uint32_t idx = cut.SearchBin(inst[j]);
|
||||
index_data[ibegin + j] = get_offset(idx, j);
|
||||
++hit_count_tloc_[tid * nbins + idx];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void ResizeIndex(const size_t n_index,
|
||||
const bool isDense);
|
||||
|
||||
inline void GetFeatureCounts(size_t* counts) const {
|
||||
auto nfeature = cut.Ptrs().size() - 1;
|
||||
for (unsigned fid = 0; fid < nfeature; ++fid) {
|
||||
auto ibegin = cut.Ptrs()[fid];
|
||||
auto iend = cut.Ptrs()[fid + 1];
|
||||
for (auto i = ibegin; i < iend; ++i) {
|
||||
counts[fid] += hit_count[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
inline bool IsDense() const {
|
||||
return isDense_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<size_t> hit_count_tloc_;
|
||||
bool isDense_;
|
||||
};
|
||||
|
||||
template <typename GradientIndex>
|
||||
int32_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(size_t begin, size_t end,
|
||||
GradientIndex const &data,
|
||||
@ -647,6 +581,42 @@ class GHistBuilder {
|
||||
/*! \brief number of all bins over all features */
|
||||
uint32_t nbins_ { 0 };
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief A C-style array with in-stack allocation. As long as the array is smaller than
|
||||
* MaxStackSize, it will be allocated inside the stack. Otherwise, it will be
|
||||
* heap-allocated.
|
||||
*/
|
||||
template<typename T, size_t MaxStackSize>
|
||||
class MemStackAllocator {
|
||||
public:
|
||||
explicit MemStackAllocator(size_t required_size): required_size_(required_size) {
|
||||
}
|
||||
|
||||
T* Get() {
|
||||
if (!ptr_) {
|
||||
if (MaxStackSize >= required_size_) {
|
||||
ptr_ = stack_mem_;
|
||||
} else {
|
||||
ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
|
||||
do_free_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
return ptr_;
|
||||
}
|
||||
|
||||
~MemStackAllocator() {
|
||||
if (do_free_) free(ptr_);
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
T* ptr_ = nullptr;
|
||||
bool do_free_ = false;
|
||||
size_t required_size_;
|
||||
T stack_mem_[MaxStackSize];
|
||||
};
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_HIST_UTIL_H_
|
||||
|
||||
165
src/data/gradient_index.cc
Normal file
165
src/data/gradient_index.cc
Normal file
@ -0,0 +1,165 @@
|
||||
/*!
|
||||
* Copyright 2017-2021 by Contributors
|
||||
* \brief Data type for fast histogram aggregation.
|
||||
*/
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include "gradient_index.h"
|
||||
#include "../common/hist_util.h"
|
||||
|
||||
namespace xgboost {
|
||||
void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins) {
|
||||
cut = common::SketchOnDMatrix(p_fmat, max_bins);
|
||||
|
||||
max_num_bins = max_bins;
|
||||
const int32_t nthread = omp_get_max_threads();
|
||||
const uint32_t nbins = cut.Ptrs().back();
|
||||
hit_count.resize(nbins, 0);
|
||||
hit_count_tloc_.resize(nthread * nbins, 0);
|
||||
|
||||
this->p_fmat = p_fmat;
|
||||
size_t new_size = 1;
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
new_size += batch.Size();
|
||||
}
|
||||
|
||||
row_ptr.resize(new_size);
|
||||
row_ptr[0] = 0;
|
||||
|
||||
size_t rbegin = 0;
|
||||
size_t prev_sum = 0;
|
||||
const bool isDense = p_fmat->IsDense();
|
||||
this->isDense_ = isDense;
|
||||
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
// The number of threads is pegged to the batch size. If the OMP
|
||||
// block is parallelized on anything other than the batch/block size,
|
||||
// it should be reassigned
|
||||
const size_t batch_threads = std::max(
|
||||
size_t(1),
|
||||
std::min(batch.Size(), static_cast<size_t>(omp_get_max_threads())));
|
||||
auto page = batch.GetView();
|
||||
common::MemStackAllocator<size_t, 128> partial_sums(batch_threads);
|
||||
size_t* p_part = partial_sums.Get();
|
||||
|
||||
size_t block_size = batch.Size() / batch_threads;
|
||||
|
||||
dmlc::OMPException exc;
|
||||
#pragma omp parallel num_threads(batch_threads)
|
||||
{
|
||||
#pragma omp for
|
||||
for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
|
||||
exc.Run([&]() {
|
||||
size_t ibegin = block_size * tid;
|
||||
size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1)));
|
||||
|
||||
size_t sum = 0;
|
||||
for (size_t i = ibegin; i < iend; ++i) {
|
||||
sum += page[i].size();
|
||||
row_ptr[rbegin + 1 + i] = sum;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#pragma omp single
|
||||
{
|
||||
exc.Run([&]() {
|
||||
p_part[0] = prev_sum;
|
||||
for (size_t i = 1; i < batch_threads; ++i) {
|
||||
p_part[i] = p_part[i - 1] + row_ptr[rbegin + i*block_size];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#pragma omp for
|
||||
for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
|
||||
exc.Run([&]() {
|
||||
size_t ibegin = block_size * tid;
|
||||
size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1)));
|
||||
|
||||
for (size_t i = ibegin; i < iend; ++i) {
|
||||
row_ptr[rbegin + 1 + i] += p_part[tid];
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
exc.Rethrow();
|
||||
|
||||
const size_t n_offsets = cut.Ptrs().size() - 1;
|
||||
const size_t n_index = row_ptr[rbegin + batch.Size()];
|
||||
ResizeIndex(n_index, isDense);
|
||||
|
||||
CHECK_GT(cut.Values().size(), 0U);
|
||||
|
||||
uint32_t* offsets = nullptr;
|
||||
if (isDense) {
|
||||
index.ResizeOffset(n_offsets);
|
||||
offsets = index.Offset();
|
||||
for (size_t i = 0; i < n_offsets; ++i) {
|
||||
offsets[i] = cut.Ptrs()[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (isDense) {
|
||||
common::BinTypeSize curent_bin_size = index.GetBinTypeSize();
|
||||
if (curent_bin_size == common::kUint8BinsTypeSize) {
|
||||
common::Span<uint8_t> index_data_span = {index.data<uint8_t>(),
|
||||
n_index};
|
||||
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
|
||||
[offsets](auto idx, auto j) {
|
||||
return static_cast<uint8_t>(idx - offsets[j]);
|
||||
});
|
||||
|
||||
} else if (curent_bin_size == common::kUint16BinsTypeSize) {
|
||||
common::Span<uint16_t> index_data_span = {index.data<uint16_t>(),
|
||||
n_index};
|
||||
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
|
||||
[offsets](auto idx, auto j) {
|
||||
return static_cast<uint16_t>(idx - offsets[j]);
|
||||
});
|
||||
} else {
|
||||
CHECK_EQ(curent_bin_size, common::kUint32BinsTypeSize);
|
||||
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(),
|
||||
n_index};
|
||||
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
|
||||
[offsets](auto idx, auto j) {
|
||||
return static_cast<uint32_t>(idx - offsets[j]);
|
||||
});
|
||||
}
|
||||
|
||||
/* For sparse DMatrix we have to store index of feature for each bin
|
||||
in index field to chose right offset. So offset is nullptr and index is not reduced */
|
||||
} else {
|
||||
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
|
||||
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
|
||||
[](auto idx, auto) { return idx; });
|
||||
}
|
||||
|
||||
common::ParallelFor(bst_omp_uint(nbins), nthread, [&](bst_omp_uint idx) {
|
||||
for (int32_t tid = 0; tid < nthread; ++tid) {
|
||||
hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
|
||||
hit_count_tloc_[tid * nbins + idx] = 0; // reset for next batch
|
||||
}
|
||||
});
|
||||
|
||||
prev_sum = row_ptr[rbegin + batch.Size()];
|
||||
rbegin += batch.Size();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void GHistIndexMatrix::ResizeIndex(const size_t n_index,
|
||||
const bool isDense) {
|
||||
if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense) {
|
||||
index.SetBinTypeSize(common::kUint8BinsTypeSize);
|
||||
index.Resize((sizeof(uint8_t)) * n_index);
|
||||
} else if ((max_num_bins - 1 > static_cast<int>(std::numeric_limits<uint8_t>::max()) &&
|
||||
max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) && isDense) {
|
||||
index.SetBinTypeSize(common::kUint16BinsTypeSize);
|
||||
index.Resize((sizeof(uint16_t)) * n_index);
|
||||
} else {
|
||||
index.SetBinTypeSize(common::kUint32BinsTypeSize);
|
||||
index.Resize((sizeof(uint32_t)) * n_index);
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
86
src/data/gradient_index.h
Normal file
86
src/data/gradient_index.h
Normal file
@ -0,0 +1,86 @@
|
||||
/*!
|
||||
* Copyright 2017-2021 by Contributors
|
||||
* \brief Data type for fast histogram aggregation.
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_GRADIENT_INDEX_H_
|
||||
#define XGBOOST_DATA_GRADIENT_INDEX_H_
|
||||
#include <vector>
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/threading_utils.h"
|
||||
|
||||
namespace xgboost {
|
||||
/*!
|
||||
* \brief preprocessed global index matrix, in CSR format
|
||||
*
|
||||
* Transform floating values to integer index in histogram This is a global histogram
|
||||
* index for CPU histogram. On GPU ellpack page is used.
|
||||
*/
|
||||
class GHistIndexMatrix {
|
||||
public:
|
||||
/*! \brief row pointer to rows by element position */
|
||||
std::vector<size_t> row_ptr;
|
||||
/*! \brief The index data */
|
||||
common::Index index;
|
||||
/*! \brief hit count of each index */
|
||||
std::vector<size_t> hit_count;
|
||||
/*! \brief The corresponding cuts */
|
||||
common::HistogramCuts cut;
|
||||
DMatrix* p_fmat;
|
||||
size_t max_num_bins;
|
||||
|
||||
GHistIndexMatrix(DMatrix* x, int32_t max_bin) {
|
||||
this->Init(x, max_bin);
|
||||
}
|
||||
// Create a global histogram matrix, given cut
|
||||
void Init(DMatrix* p_fmat, int max_num_bins);
|
||||
|
||||
// specific method for sparse data as no possibility to reduce allocated memory
|
||||
template <typename BinIdxType, typename GetOffset>
|
||||
void SetIndexData(common::Span<BinIdxType> index_data_span,
|
||||
size_t batch_threads, const SparsePage &batch,
|
||||
size_t rbegin, size_t nbins, GetOffset get_offset) {
|
||||
const xgboost::Entry *data_ptr = batch.data.HostVector().data();
|
||||
const std::vector<bst_row_t> &offset_vec = batch.offset.HostVector();
|
||||
const size_t batch_size = batch.Size();
|
||||
CHECK_LT(batch_size, offset_vec.size());
|
||||
BinIdxType* index_data = index_data_span.data();
|
||||
common::ParallelFor(omp_ulong(batch_size), batch_threads, [&](omp_ulong i) {
|
||||
const int tid = omp_get_thread_num();
|
||||
size_t ibegin = row_ptr[rbegin + i];
|
||||
size_t iend = row_ptr[rbegin + i + 1];
|
||||
const size_t size = offset_vec[i + 1] - offset_vec[i];
|
||||
SparsePage::Inst inst = {data_ptr + offset_vec[i], size};
|
||||
CHECK_EQ(ibegin + inst.size(), iend);
|
||||
for (bst_uint j = 0; j < inst.size(); ++j) {
|
||||
uint32_t idx = cut.SearchBin(inst[j]);
|
||||
index_data[ibegin + j] = get_offset(idx, j);
|
||||
++hit_count_tloc_[tid * nbins + idx];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void ResizeIndex(const size_t n_index,
|
||||
const bool isDense);
|
||||
|
||||
inline void GetFeatureCounts(size_t* counts) const {
|
||||
auto nfeature = cut.Ptrs().size() - 1;
|
||||
for (unsigned fid = 0; fid < nfeature; ++fid) {
|
||||
auto ibegin = cut.Ptrs()[fid];
|
||||
auto iend = cut.Ptrs()[fid + 1];
|
||||
for (auto i = ibegin; i < iend; ++i) {
|
||||
counts[fid] += hit_count[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
inline bool IsDense() const {
|
||||
return isDense_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<size_t> hit_count_tloc_;
|
||||
bool isDense_;
|
||||
};
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_GRADIENT_INDEX_H_
|
||||
@ -58,6 +58,10 @@ class IterativeDeviceDMatrix : public DMatrix {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(nullptr));
|
||||
}
|
||||
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam&) override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<GHistIndexMatrix>(BatchIterator<GHistIndexMatrix>(nullptr));
|
||||
}
|
||||
|
||||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
|
||||
|
||||
|
||||
@ -97,6 +97,10 @@ class DMatrixProxy : public DMatrix {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(nullptr));
|
||||
}
|
||||
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam&) override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<GHistIndexMatrix>(BatchIterator<GHistIndexMatrix>(nullptr));
|
||||
}
|
||||
|
||||
dmlc::any Adapter() const {
|
||||
return batch_;
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
#include "../common/random.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "adapter.h"
|
||||
#include "gradient_index.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
@ -89,6 +90,20 @@ BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param)
|
||||
return BatchSet<EllpackPage>(begin_iter);
|
||||
}
|
||||
|
||||
BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& param) {
|
||||
if (!(batch_param_ != BatchParam{})) {
|
||||
CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
|
||||
}
|
||||
if (!gradient_index_ || (batch_param_ != param && param != BatchParam{})) {
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin));
|
||||
batch_param_ = param;
|
||||
}
|
||||
auto begin_iter = BatchIterator<GHistIndexMatrix>(
|
||||
new SimpleBatchIteratorImpl<GHistIndexMatrix>(gradient_index_.get()));
|
||||
return BatchSet<GHistIndexMatrix>(begin_iter);
|
||||
}
|
||||
|
||||
template <typename AdapterT>
|
||||
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
||||
std::vector<uint64_t> qids;
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "gradient_index.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
@ -43,12 +44,14 @@ class SimpleDMatrix : public DMatrix {
|
||||
BatchSet<CSCPage> GetColumnBatches() override;
|
||||
BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
|
||||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
|
||||
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) override;
|
||||
|
||||
MetaInfo info_;
|
||||
SparsePage sparse_page_; // Primary storage type
|
||||
std::unique_ptr<CSCPage> column_page_;
|
||||
std::unique_ptr<SortedCSCPage> sorted_column_page_;
|
||||
std::unique_ptr<EllpackPage> ellpack_page_;
|
||||
std::unique_ptr<GHistIndexMatrix> gradient_index_;
|
||||
BatchParam batch_param_;
|
||||
|
||||
bool EllpackExists() const override {
|
||||
|
||||
@ -47,6 +47,10 @@ class SparsePageDMatrix : public DMatrix {
|
||||
BatchSet<CSCPage> GetColumnBatches() override;
|
||||
BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
|
||||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
|
||||
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam&) override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<GHistIndexMatrix>(BatchIterator<GHistIndexMatrix>(nullptr));
|
||||
}
|
||||
|
||||
// source data pointers.
|
||||
std::unique_ptr<SparsePageSource> row_source_;
|
||||
|
||||
@ -69,18 +69,22 @@ template<typename GradientSumT>
|
||||
void QuantileHistMaker::CallBuilderUpdate(const std::unique_ptr<Builder<GradientSumT>>& builder,
|
||||
HostDeviceVector<GradientPair> *gpair,
|
||||
DMatrix *dmat,
|
||||
GHistIndexMatrix const& gmat,
|
||||
const std::vector<RegTree *> &trees) {
|
||||
for (auto tree : trees) {
|
||||
builder->Update(gmat_, column_matrix_, gpair, dmat, tree);
|
||||
builder->Update(gmat, column_matrix_, gpair, dmat, tree);
|
||||
}
|
||||
}
|
||||
void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
|
||||
DMatrix *dmat,
|
||||
const std::vector<RegTree *> &trees) {
|
||||
auto const &gmat =
|
||||
*(dmat->GetBatches<GHistIndexMatrix>(
|
||||
BatchParam{GenericParameter::kCpuId, param_.max_bin})
|
||||
.begin());
|
||||
if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) {
|
||||
updater_monitor_.Start("GmatInitialization");
|
||||
gmat_.Init(dmat, static_cast<uint32_t>(param_.max_bin));
|
||||
column_matrix_.Init(gmat_, param_.sparse_threshold);
|
||||
column_matrix_.Init(gmat, param_.sparse_threshold);
|
||||
updater_monitor_.Stop("GmatInitialization");
|
||||
// A proper solution is puting cut matrix in DMatrix, see:
|
||||
// https://github.com/dmlc/xgboost/issues/5143
|
||||
@ -96,12 +100,12 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
|
||||
if (!float_builder_) {
|
||||
SetBuilder(n_trees, &float_builder_, dmat);
|
||||
}
|
||||
CallBuilderUpdate(float_builder_, gpair, dmat, trees);
|
||||
CallBuilderUpdate(float_builder_, gpair, dmat, gmat, trees);
|
||||
} else {
|
||||
if (!double_builder_) {
|
||||
SetBuilder(n_trees, &double_builder_, dmat);
|
||||
}
|
||||
CallBuilderUpdate(double_builder_, gpair, dmat, trees);
|
||||
CallBuilderUpdate(double_builder_, gpair, dmat, gmat, trees);
|
||||
}
|
||||
|
||||
param_.learning_rate = lr;
|
||||
@ -678,7 +682,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
|
||||
// We should check that the partitioning was done correctly
|
||||
// and each row of the dataset fell into exactly one of the categories
|
||||
}
|
||||
MemStackAllocator<bool, 128> buff(this->nthread_);
|
||||
common::MemStackAllocator<bool, 128> buff(this->nthread_);
|
||||
bool* p_buff = buff.Get();
|
||||
std::fill(p_buff, p_buff + this->nthread_, false);
|
||||
|
||||
|
||||
@ -75,43 +75,9 @@ struct RandomReplace {
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief A C-style array with in-stack allocation. As long as the array is smaller than MaxStackSize, it will be allocated inside the stack. Otherwise, it will be heap-allocated.
|
||||
*/
|
||||
template<typename T, size_t MaxStackSize>
|
||||
class MemStackAllocator {
|
||||
public:
|
||||
explicit MemStackAllocator(size_t required_size): required_size_(required_size) {
|
||||
}
|
||||
|
||||
T* Get() {
|
||||
if (!ptr_) {
|
||||
if (MaxStackSize >= required_size_) {
|
||||
ptr_ = stack_mem_;
|
||||
} else {
|
||||
ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
|
||||
do_free_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
return ptr_;
|
||||
}
|
||||
|
||||
~MemStackAllocator() {
|
||||
if (do_free_) free(ptr_);
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
T* ptr_ = nullptr;
|
||||
bool do_free_ = false;
|
||||
size_t required_size_;
|
||||
T stack_mem_[MaxStackSize];
|
||||
};
|
||||
|
||||
namespace tree {
|
||||
|
||||
using xgboost::common::GHistIndexMatrix;
|
||||
using xgboost::GHistIndexMatrix;
|
||||
using xgboost::common::GHistIndexRow;
|
||||
using xgboost::common::HistCollection;
|
||||
using xgboost::common::RowSetCollection;
|
||||
@ -243,8 +209,6 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
CPUHistMakerTrainParam hist_maker_param_;
|
||||
// training parameter
|
||||
TrainParam param_;
|
||||
// quantized data matrix
|
||||
GHistIndexMatrix gmat_;
|
||||
// column accessor
|
||||
ColumnMatrix column_matrix_;
|
||||
DMatrix const* p_last_dmat_ {nullptr};
|
||||
@ -466,6 +430,7 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
void CallBuilderUpdate(const std::unique_ptr<Builder<GradientSumT>>& builder,
|
||||
HostDeviceVector<GradientPair> *gpair,
|
||||
DMatrix *dmat,
|
||||
GHistIndexMatrix const& gmat,
|
||||
const std::vector<RegTree *> &trees);
|
||||
|
||||
protected:
|
||||
|
||||
@ -14,8 +14,7 @@ TEST(DenseColumn, Test) {
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
||||
for (size_t max_num_bin : max_num_bins) {
|
||||
auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix();
|
||||
GHistIndexMatrix gmat;
|
||||
gmat.Init(dmat.get(), max_num_bin);
|
||||
GHistIndexMatrix gmat(dmat.get(), max_num_bin);
|
||||
ColumnMatrix column_matrix;
|
||||
column_matrix.Init(gmat, 0.2);
|
||||
|
||||
@ -62,8 +61,7 @@ TEST(SparseColumn, Test) {
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
||||
for (size_t max_num_bin : max_num_bins) {
|
||||
auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix();
|
||||
GHistIndexMatrix gmat;
|
||||
gmat.Init(dmat.get(), max_num_bin);
|
||||
GHistIndexMatrix gmat(dmat.get(), max_num_bin);
|
||||
ColumnMatrix column_matrix;
|
||||
column_matrix.Init(gmat, 0.5);
|
||||
switch (column_matrix.GetTypeSize()) {
|
||||
@ -103,8 +101,7 @@ TEST(DenseColumnWithMissing, Test) {
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2 };
|
||||
for (size_t max_num_bin : max_num_bins) {
|
||||
auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix();
|
||||
GHistIndexMatrix gmat;
|
||||
gmat.Init(dmat.get(), max_num_bin);
|
||||
GHistIndexMatrix gmat(dmat.get(), max_num_bin);
|
||||
ColumnMatrix column_matrix;
|
||||
column_matrix.Init(gmat, 0.2);
|
||||
switch (column_matrix.GetTypeSize()) {
|
||||
@ -135,8 +132,7 @@ void TestGHistIndexMatrixCreation(size_t nthreads) {
|
||||
/* This should create multiple sparse pages */
|
||||
std::unique_ptr<DMatrix> dmat{ CreateSparsePageDMatrix(kEntries, kPageSize, filename) };
|
||||
omp_set_num_threads(nthreads);
|
||||
GHistIndexMatrix gmat;
|
||||
gmat.Init(dmat.get(), 256);
|
||||
GHistIndexMatrix gmat(dmat.get(), 256);
|
||||
}
|
||||
|
||||
TEST(HistIndexCreationWithExternalMemory, Test) {
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
#include <utility>
|
||||
|
||||
#include "../../../src/common/hist_util.h"
|
||||
#include "../../../src/data/gradient_index.h"
|
||||
#include "../helpers.h"
|
||||
#include "test_hist_util.h"
|
||||
|
||||
@ -255,8 +256,7 @@ TEST(HistUtil, IndexBinBound) {
|
||||
for (auto max_bin : bin_sizes) {
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
|
||||
common::GHistIndexMatrix hmat;
|
||||
hmat.Init(p_fmat.get(), max_bin);
|
||||
GHistIndexMatrix hmat(p_fmat.get(), max_bin);
|
||||
EXPECT_EQ(hmat.index.Size(), kRows*kCols);
|
||||
EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.GetBinTypeSize());
|
||||
}
|
||||
@ -264,7 +264,7 @@ TEST(HistUtil, IndexBinBound) {
|
||||
|
||||
template <typename T>
|
||||
void CheckIndexData(T* data_ptr, uint32_t* offsets,
|
||||
const common::GHistIndexMatrix& hmat, size_t n_cols) {
|
||||
const GHistIndexMatrix& hmat, size_t n_cols) {
|
||||
for (size_t i = 0; i < hmat.index.Size(); ++i) {
|
||||
EXPECT_EQ(data_ptr[i] + offsets[i % n_cols], hmat.index[i]);
|
||||
}
|
||||
@ -279,8 +279,7 @@ TEST(HistUtil, IndexBinData) {
|
||||
|
||||
for (auto max_bin : kBinSizes) {
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
common::GHistIndexMatrix hmat;
|
||||
hmat.Init(p_fmat.get(), max_bin);
|
||||
GHistIndexMatrix hmat(p_fmat.get(), max_bin);
|
||||
uint32_t* offsets = hmat.index.Offset();
|
||||
EXPECT_EQ(hmat.index.Size(), kRows*kCols);
|
||||
switch (max_bin) {
|
||||
|
||||
@ -344,8 +344,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
|
||||
// dense, no missing values
|
||||
|
||||
common::GHistIndexMatrix gmat;
|
||||
gmat.Init(dmat.get(), kMaxBins);
|
||||
GHistIndexMatrix gmat(dmat.get(), kMaxBins);
|
||||
|
||||
RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
|
||||
this->hist_.AddHistRow(0);
|
||||
@ -434,8 +433,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
// kNRows samples with kNCols features
|
||||
auto dmat = RandomDataGenerator(kNRows, kNCols, sparsity).Seed(3).GenerateDMatrix();
|
||||
|
||||
common::GHistIndexMatrix gmat;
|
||||
gmat.Init(dmat.get(), kMaxBins);
|
||||
GHistIndexMatrix gmat(dmat.get(), kMaxBins);
|
||||
ColumnMatrix cm;
|
||||
|
||||
// treat everything as dense, as this is what we intend to test here
|
||||
@ -546,8 +544,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
|
||||
void TestInitData() {
|
||||
size_t constexpr kMaxBins = 4;
|
||||
common::GHistIndexMatrix gmat;
|
||||
gmat.Init(dmat_.get(), kMaxBins);
|
||||
GHistIndexMatrix gmat(dmat_.get(), kMaxBins);
|
||||
|
||||
RegTree tree = RegTree();
|
||||
tree.param.UpdateAllowUnknown(cfg_);
|
||||
@ -564,8 +561,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
|
||||
void TestInitDataSampling() {
|
||||
size_t constexpr kMaxBins = 4;
|
||||
common::GHistIndexMatrix gmat;
|
||||
gmat.Init(dmat_.get(), kMaxBins);
|
||||
GHistIndexMatrix gmat(dmat_.get(), kMaxBins);
|
||||
|
||||
RegTree tree = RegTree();
|
||||
tree.param.UpdateAllowUnknown(cfg_);
|
||||
@ -582,8 +578,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
|
||||
void TestAddHistRows() {
|
||||
size_t constexpr kMaxBins = 4;
|
||||
common::GHistIndexMatrix gmat;
|
||||
gmat.Init(dmat_.get(), kMaxBins);
|
||||
GHistIndexMatrix gmat(dmat_.get(), kMaxBins);
|
||||
|
||||
RegTree tree = RegTree();
|
||||
tree.param.UpdateAllowUnknown(cfg_);
|
||||
@ -599,8 +594,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
|
||||
void TestSyncHistograms() {
|
||||
size_t constexpr kMaxBins = 4;
|
||||
common::GHistIndexMatrix gmat;
|
||||
gmat.Init(dmat_.get(), kMaxBins);
|
||||
GHistIndexMatrix gmat(dmat_.get(), kMaxBins);
|
||||
|
||||
RegTree tree = RegTree();
|
||||
tree.param.UpdateAllowUnknown(cfg_);
|
||||
@ -620,8 +614,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
tree.param.UpdateAllowUnknown(cfg_);
|
||||
|
||||
size_t constexpr kMaxBins = 4;
|
||||
common::GHistIndexMatrix gmat;
|
||||
gmat.Init(dmat_.get(), kMaxBins);
|
||||
GHistIndexMatrix gmat(dmat_.get(), kMaxBins);
|
||||
if (double_builder_) {
|
||||
double_builder_->TestBuildHist(0, gmat, *dmat_, tree);
|
||||
} else {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user