Move GHistIndex into DMatrix. (#7064)

Jiaming Yuan 2021-07-01 00:44:49 +08:00 committed by GitHub
parent 1c8fdf2218
commit 1cd20efe68
17 changed files with 386 additions and 320 deletions


@@ -38,6 +38,7 @@
#include "../src/data/sparse_page_raw_format.cc"
#include "../src/data/ellpack_page.cc"
#include "../src/data/ellpack_page_source.cc"
#include "../src/data/gradient_index.cc"
// prediction
#include "../src/predictor/predictor.cc"


@@ -385,6 +385,8 @@ class EllpackPage {
std::unique_ptr<EllpackPageImpl> impl_;
};
class GHistIndexMatrix;
template<typename T>
class BatchIteratorImpl {
public:
@@ -553,6 +555,7 @@ class DMatrix {
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) = 0;
virtual bool EllpackExists() const = 0;
virtual bool SparsePageExists() const = 0;
@@ -587,6 +590,11 @@ template<>
inline BatchSet<EllpackPage> DMatrix::GetBatches(const BatchParam& param) {
return GetEllpackBatches(param);
}
template<>
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(const BatchParam& param) {
return GetGradientIndex(param);
}
} // namespace xgboost
namespace dmlc {
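With the new virtual GetGradientIndex and this GetBatches specialization, the quantized matrix is requested through the same batch interface as SparsePage and EllpackPage. A minimal sketch of the call pattern (the max_bin value is illustrative; it mirrors how the hist updater fetches the matrix later in this commit):

// Sketch: requesting the gradient index through the batch interface.
BatchParam param{GenericParameter::kCpuId, /*max_bin=*/256};
for (GHistIndexMatrix const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(param)) {
  // An in-memory DMatrix yields a single cached page; external-memory and
  // proxy DMatrix implementations currently LOG(FATAL).
  CHECK_EQ(gidx.cut.Ptrs().size() - 1, p_fmat->Info().num_col_);
}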


@@ -12,6 +12,7 @@
#include <vector>
#include <memory>
#include "hist_util.h"
#include "../data/gradient_index.h"
namespace xgboost {
namespace common {
@@ -263,8 +264,9 @@ class ColumnMatrix {
}
template <typename T>
inline void SetIndexAllDense(T* index, const GHistIndexMatrix& gmat, const size_t nrow,
const size_t nfeature, const bool noMissingValues) {
inline void SetIndexAllDense(T *index, const GHistIndexMatrix &gmat,
const size_t nrow, const size_t nfeature,
const bool noMissingValues) {
T* local_index = reinterpret_cast<T*>(&index_[0]);
/* missing values make sense only for column with type kDenseColumn,


@@ -16,6 +16,7 @@
#include "column_matrix.h"
#include "quantile.h"
#include "./../tree/updater_quantile_hist.h"
#include "../data/gradient_index.h"
#if defined(XGBOOST_MM_PREFETCH_PRESENT)
#include <xmmintrin.h>
@@ -29,164 +30,10 @@
namespace xgboost {
namespace common {
void GHistIndexMatrix::ResizeIndex(const size_t n_index,
const bool isDense) {
if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense) {
index.SetBinTypeSize(kUint8BinsTypeSize);
index.Resize((sizeof(uint8_t)) * n_index);
} else if ((max_num_bins - 1 > static_cast<int>(std::numeric_limits<uint8_t>::max()) &&
max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) && isDense) {
index.SetBinTypeSize(kUint16BinsTypeSize);
index.Resize((sizeof(uint16_t)) * n_index);
} else {
index.SetBinTypeSize(kUint32BinsTypeSize);
index.Resize((sizeof(uint32_t)) * n_index);
}
}
HistogramCuts::HistogramCuts() {
cut_ptrs_.HostVector().emplace_back(0);
}
void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins) {
cut = SketchOnDMatrix(p_fmat, max_bins);
max_num_bins = max_bins;
const int32_t nthread = omp_get_max_threads();
const uint32_t nbins = cut.Ptrs().back();
hit_count.resize(nbins, 0);
hit_count_tloc_.resize(nthread * nbins, 0);
this->p_fmat = p_fmat;
size_t new_size = 1;
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
new_size += batch.Size();
}
row_ptr.resize(new_size);
row_ptr[0] = 0;
size_t rbegin = 0;
size_t prev_sum = 0;
const bool isDense = p_fmat->IsDense();
this->isDense_ = isDense;
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
// The number of threads is pegged to the batch size. If the OMP
// block is parallelized on anything other than the batch/block size,
// it should be reassigned
const size_t batch_threads = std::max(
size_t(1),
std::min(batch.Size(), static_cast<size_t>(omp_get_max_threads())));
auto page = batch.GetView();
MemStackAllocator<size_t, 128> partial_sums(batch_threads);
size_t* p_part = partial_sums.Get();
size_t block_size = batch.Size() / batch_threads;
dmlc::OMPException exc;
#pragma omp parallel num_threads(batch_threads)
{
#pragma omp for
for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
exc.Run([&]() {
size_t ibegin = block_size * tid;
size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1)));
size_t sum = 0;
for (size_t i = ibegin; i < iend; ++i) {
sum += page[i].size();
row_ptr[rbegin + 1 + i] = sum;
}
});
}
#pragma omp single
{
exc.Run([&]() {
p_part[0] = prev_sum;
for (size_t i = 1; i < batch_threads; ++i) {
p_part[i] = p_part[i - 1] + row_ptr[rbegin + i*block_size];
}
});
}
#pragma omp for
for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
exc.Run([&]() {
size_t ibegin = block_size * tid;
size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1)));
for (size_t i = ibegin; i < iend; ++i) {
row_ptr[rbegin + 1 + i] += p_part[tid];
}
});
}
}
exc.Rethrow();
const size_t n_offsets = cut.Ptrs().size() - 1;
const size_t n_index = row_ptr[rbegin + batch.Size()];
ResizeIndex(n_index, isDense);
CHECK_GT(cut.Values().size(), 0U);
uint32_t* offsets = nullptr;
if (isDense) {
index.ResizeOffset(n_offsets);
offsets = index.Offset();
for (size_t i = 0; i < n_offsets; ++i) {
offsets[i] = cut.Ptrs()[i];
}
}
if (isDense) {
BinTypeSize current_bin_size = index.GetBinTypeSize();
if (current_bin_size == kUint8BinsTypeSize) {
common::Span<uint8_t> index_data_span = {index.data<uint8_t>(),
n_index};
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
[offsets](auto idx, auto j) {
return static_cast<uint8_t>(idx - offsets[j]);
});
} else if (current_bin_size == kUint16BinsTypeSize) {
common::Span<uint16_t> index_data_span = {index.data<uint16_t>(),
n_index};
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
[offsets](auto idx, auto j) {
return static_cast<uint16_t>(idx - offsets[j]);
});
} else {
CHECK_EQ(current_bin_size, kUint32BinsTypeSize);
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(),
n_index};
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
[offsets](auto idx, auto j) {
return static_cast<uint32_t>(idx - offsets[j]);
});
}
/* For a sparse DMatrix we have to store the feature index for each bin
in the index field to choose the right offset. So offsets is nullptr and the index is not reduced. */
} else {
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
[](auto idx, auto) { return idx; });
}
ParallelFor(bst_omp_uint(nbins), nthread, [&](bst_omp_uint idx) {
for (int32_t tid = 0; tid < nthread; ++tid) {
hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
hit_count_tloc_[tid * nbins + idx] = 0; // reset for next batch
}
});
prev_sum = row_ptr[rbegin + batch.Size()];
rbegin += batch.Size();
}
}
/*!
* \brief fill a histogram with zeros in the range [begin, end)
*/
@@ -382,23 +229,23 @@ void GHistBuilder<GradientSumT>::BuildHist(
BuildHistDispatch<GradientSumT, false, any_missing>(gpair, span2, gmat, hist);
}
}
template
void GHistBuilder<float>::BuildHist<true>(const std::vector<GradientPair>& gpair,
template void
GHistBuilder<float>::BuildHist<true>(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat,
GHistRow<float> hist);
template
void GHistBuilder<float>::BuildHist<false>(const std::vector<GradientPair>& gpair,
template void
GHistBuilder<float>::BuildHist<false>(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat,
GHistRow<float> hist);
template
void GHistBuilder<double>::BuildHist<true>(const std::vector<GradientPair>& gpair,
template void
GHistBuilder<double>::BuildHist<true>(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat,
GHistRow<double> hist);
template
void GHistBuilder<double>::BuildHist<false>(const std::vector<GradientPair>& gpair,
template void
GHistBuilder<double>::BuildHist<false>(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat,
GHistRow<double> hist);


@@ -25,6 +25,8 @@
#include "../include/rabit/rabit.h"
namespace xgboost {
class GHistIndexMatrix;
namespace common {
/*!
* \brief A single row in global histogram index.
@@ -226,74 +228,6 @@ struct Index {
Func func_;
};
/*!
* \brief preprocessed global index matrix, in CSR format
*
* Transforms floating-point values into integer bin indices. This is the global
* histogram index for the CPU hist method; on GPU the ellpack page is used instead.
*/
struct GHistIndexMatrix {
/*! \brief row pointer to rows by element position */
std::vector<size_t> row_ptr;
/*! \brief The index data */
Index index;
/*! \brief hit count of each index */
std::vector<size_t> hit_count;
/*! \brief The corresponding cuts */
HistogramCuts cut;
DMatrix* p_fmat;
size_t max_num_bins;
// Create a global histogram matrix, given cut
void Init(DMatrix* p_fmat, int max_num_bins);
// specific method for sparse data, as there is no way to reduce the allocated memory
template <typename BinIdxType, typename GetOffset>
void SetIndexData(common::Span<BinIdxType> index_data_span,
size_t batch_threads, const SparsePage &batch,
size_t rbegin, size_t nbins, GetOffset get_offset) {
const xgboost::Entry *data_ptr = batch.data.HostVector().data();
const std::vector<bst_row_t> &offset_vec = batch.offset.HostVector();
const size_t batch_size = batch.Size();
CHECK_LT(batch_size, offset_vec.size());
BinIdxType* index_data = index_data_span.data();
ParallelFor(omp_ulong(batch_size), batch_threads, [&](omp_ulong i) {
const int tid = omp_get_thread_num();
size_t ibegin = row_ptr[rbegin + i];
size_t iend = row_ptr[rbegin + i + 1];
const size_t size = offset_vec[i + 1] - offset_vec[i];
SparsePage::Inst inst = {data_ptr + offset_vec[i], size};
CHECK_EQ(ibegin + inst.size(), iend);
for (bst_uint j = 0; j < inst.size(); ++j) {
uint32_t idx = cut.SearchBin(inst[j]);
index_data[ibegin + j] = get_offset(idx, j);
++hit_count_tloc_[tid * nbins + idx];
}
});
}
void ResizeIndex(const size_t n_index,
const bool isDense);
inline void GetFeatureCounts(size_t* counts) const {
auto nfeature = cut.Ptrs().size() - 1;
for (unsigned fid = 0; fid < nfeature; ++fid) {
auto ibegin = cut.Ptrs()[fid];
auto iend = cut.Ptrs()[fid + 1];
for (auto i = ibegin; i < iend; ++i) {
counts[fid] += hit_count[i];
}
}
}
inline bool IsDense() const {
return isDense_;
}
private:
std::vector<size_t> hit_count_tloc_;
bool isDense_;
};
template <typename GradientIndex>
int32_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(size_t begin, size_t end,
GradientIndex const &data,
@@ -647,6 +581,42 @@ class GHistBuilder {
/*! \brief number of all bins over all features */
uint32_t nbins_ { 0 };
};
/*!
* \brief A C-style array with stack-first allocation. As long as the requested size
*   does not exceed MaxStackSize, the array is allocated on the stack; otherwise it
*   is heap-allocated.
*/
template<typename T, size_t MaxStackSize>
class MemStackAllocator {
public:
explicit MemStackAllocator(size_t required_size): required_size_(required_size) {
}
T* Get() {
if (!ptr_) {
if (MaxStackSize >= required_size_) {
ptr_ = stack_mem_;
} else {
ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
do_free_ = true;
}
}
return ptr_;
}
~MemStackAllocator() {
if (do_free_) free(ptr_);
}
private:
T* ptr_ = nullptr;
bool do_free_ = false;
size_t required_size_;
T stack_mem_[MaxStackSize];
};
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_HIST_UTIL_H_
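MemStackAllocator moves here so that both the gradient index and the hist updater can share it. A short usage sketch, matching the per-thread partial-sum buffer in gradient_index.cc below:

// Sketch: stack-first scratch buffer; the heap is used only when the
// requested size exceeds MaxStackSize (here, more than 128 threads).
common::MemStackAllocator<size_t, 128> partial_sums(batch_threads);
size_t *p_part = partial_sums.Get();
std::fill(p_part, p_part + batch_threads, 0);
// The destructor frees the buffer automatically if it was heap-allocated.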

src/data/gradient_index.cc (new file, +165)

@@ -0,0 +1,165 @@
/*!
* Copyright 2017-2021 by Contributors
* \brief Data type for fast histogram aggregation.
*/
#include <algorithm>
#include <limits>
#include "gradient_index.h"
#include "../common/hist_util.h"
namespace xgboost {
void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins) {
cut = common::SketchOnDMatrix(p_fmat, max_bins);
max_num_bins = max_bins;
const int32_t nthread = omp_get_max_threads();
const uint32_t nbins = cut.Ptrs().back();
hit_count.resize(nbins, 0);
hit_count_tloc_.resize(nthread * nbins, 0);
this->p_fmat = p_fmat;
size_t new_size = 1;
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
new_size += batch.Size();
}
row_ptr.resize(new_size);
row_ptr[0] = 0;
size_t rbegin = 0;
size_t prev_sum = 0;
const bool isDense = p_fmat->IsDense();
this->isDense_ = isDense;
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
// The number of threads is pegged to the batch size. If the OMP
// block is parallelized on anything other than the batch/block size,
// it should be reassigned
const size_t batch_threads = std::max(
size_t(1),
std::min(batch.Size(), static_cast<size_t>(omp_get_max_threads())));
auto page = batch.GetView();
common::MemStackAllocator<size_t, 128> partial_sums(batch_threads);
size_t* p_part = partial_sums.Get();
size_t block_size = batch.Size() / batch_threads;
dmlc::OMPException exc;
#pragma omp parallel num_threads(batch_threads)
{
#pragma omp for
for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
exc.Run([&]() {
size_t ibegin = block_size * tid;
size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1)));
size_t sum = 0;
for (size_t i = ibegin; i < iend; ++i) {
sum += page[i].size();
row_ptr[rbegin + 1 + i] = sum;
}
});
}
#pragma omp single
{
exc.Run([&]() {
p_part[0] = prev_sum;
for (size_t i = 1; i < batch_threads; ++i) {
p_part[i] = p_part[i - 1] + row_ptr[rbegin + i*block_size];
}
});
}
#pragma omp for
for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
exc.Run([&]() {
size_t ibegin = block_size * tid;
size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1)));
for (size_t i = ibegin; i < iend; ++i) {
row_ptr[rbegin + 1 + i] += p_part[tid];
}
});
}
}
exc.Rethrow();
const size_t n_offsets = cut.Ptrs().size() - 1;
const size_t n_index = row_ptr[rbegin + batch.Size()];
ResizeIndex(n_index, isDense);
CHECK_GT(cut.Values().size(), 0U);
uint32_t* offsets = nullptr;
if (isDense) {
index.ResizeOffset(n_offsets);
offsets = index.Offset();
for (size_t i = 0; i < n_offsets; ++i) {
offsets[i] = cut.Ptrs()[i];
}
}
if (isDense) {
common::BinTypeSize current_bin_size = index.GetBinTypeSize();
if (current_bin_size == common::kUint8BinsTypeSize) {
common::Span<uint8_t> index_data_span = {index.data<uint8_t>(),
n_index};
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
[offsets](auto idx, auto j) {
return static_cast<uint8_t>(idx - offsets[j]);
});
} else if (current_bin_size == common::kUint16BinsTypeSize) {
common::Span<uint16_t> index_data_span = {index.data<uint16_t>(),
n_index};
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
[offsets](auto idx, auto j) {
return static_cast<uint16_t>(idx - offsets[j]);
});
} else {
CHECK_EQ(current_bin_size, common::kUint32BinsTypeSize);
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(),
n_index};
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
[offsets](auto idx, auto j) {
return static_cast<uint32_t>(idx - offsets[j]);
});
}
/* For a sparse DMatrix we have to store the feature index for each bin
in the index field to choose the right offset. So offsets is nullptr and the index is not reduced. */
} else {
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
[](auto idx, auto) { return idx; });
}
common::ParallelFor(bst_omp_uint(nbins), nthread, [&](bst_omp_uint idx) {
for (int32_t tid = 0; tid < nthread; ++tid) {
hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
hit_count_tloc_[tid * nbins + idx] = 0; // reset for next batch
}
});
prev_sum = row_ptr[rbegin + batch.Size()];
rbegin += batch.Size();
}
}
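Init fills row_ptr with a block-wise parallel prefix sum: a first parallel pass computes local running sums per thread block, a single thread converts the block totals into per-block offsets, and a second parallel pass adds each block's offset back. The same pattern as a standalone, sequential sketch (the helper name is hypothetical):

#include <cstddef>
#include <vector>

// Sketch: block-wise prefix sum over row sizes; assumes 1 <= n_blocks <= row_sizes.size().
std::vector<size_t> RowPointers(std::vector<size_t> const &row_sizes, size_t n_blocks) {
  std::vector<size_t> out(row_sizes.size() + 1, 0);
  size_t const block = row_sizes.size() / n_blocks;
  for (size_t t = 0; t < n_blocks; ++t) {    // pass 1: local running sums (parallel in Init)
    size_t end = (t + 1 == n_blocks) ? row_sizes.size() : block * (t + 1);
    size_t sum = 0;
    for (size_t i = block * t; i < end; ++i) {
      sum += row_sizes[i];
      out[i + 1] = sum;
    }
  }
  std::vector<size_t> offsets(n_blocks, 0);  // single-threaded: accumulate block offsets
  for (size_t t = 1; t < n_blocks; ++t) {
    offsets[t] = offsets[t - 1] + out[t * block];
  }
  for (size_t t = 0; t < n_blocks; ++t) {    // pass 2: add the offsets back (parallel in Init)
    size_t end = (t + 1 == n_blocks) ? row_sizes.size() : block * (t + 1);
    for (size_t i = block * t; i < end; ++i) {
      out[i + 1] += offsets[t];
    }
  }
  return out;
}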
void GHistIndexMatrix::ResizeIndex(const size_t n_index,
const bool isDense) {
if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense) {
index.SetBinTypeSize(common::kUint8BinsTypeSize);
index.Resize((sizeof(uint8_t)) * n_index);
} else if ((max_num_bins - 1 > static_cast<int>(std::numeric_limits<uint8_t>::max()) &&
max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) && isDense) {
index.SetBinTypeSize(common::kUint16BinsTypeSize);
index.Resize((sizeof(uint16_t)) * n_index);
} else {
index.SetBinTypeSize(common::kUint32BinsTypeSize);
index.Resize((sizeof(uint32_t)) * n_index);
}
}
} // namespace xgboost
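ResizeIndex picks the narrowest element type that can represent max_num_bins - 1. Only dense matrices compress: sparse rows must keep 32-bit global bin ids because no per-feature offset array is stored for them. The decision distilled into a sketch (the helper name is hypothetical):

#include <cstdint>
#include <limits>

// Sketch: per-element bin width selection, mirroring the branches above.
common::BinTypeSize ChooseBinTypeSize(int max_num_bins, bool is_dense) {
  if (!is_dense) {
    return common::kUint32BinsTypeSize;  // sparse rows store global bin ids
  }
  if (max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) {
    return common::kUint8BinsTypeSize;   // up to 256 bins fit in one byte
  }
  if (max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) {
    return common::kUint16BinsTypeSize;
  }
  return common::kUint32BinsTypeSize;
}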

src/data/gradient_index.h (new file, +86)

@@ -0,0 +1,86 @@
/*!
* Copyright 2017-2021 by Contributors
* \brief Data type for fast histogram aggregation.
*/
#ifndef XGBOOST_DATA_GRADIENT_INDEX_H_
#define XGBOOST_DATA_GRADIENT_INDEX_H_
#include <vector>
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "../common/hist_util.h"
#include "../common/threading_utils.h"
namespace xgboost {
/*!
* \brief preprocessed global index matrix, in CSR format
*
* Transforms floating-point values into integer bin indices. This is the global
* histogram index for the CPU hist method; on GPU the ellpack page is used instead.
*/
class GHistIndexMatrix {
public:
/*! \brief row pointer to rows by element position */
std::vector<size_t> row_ptr;
/*! \brief The index data */
common::Index index;
/*! \brief hit count of each index */
std::vector<size_t> hit_count;
/*! \brief The corresponding cuts */
common::HistogramCuts cut;
DMatrix* p_fmat;
size_t max_num_bins;
GHistIndexMatrix(DMatrix* x, int32_t max_bin) {
this->Init(x, max_bin);
}
// Create a global histogram matrix, given cut
void Init(DMatrix* p_fmat, int max_num_bins);
// specific method for sparse data, as there is no way to reduce the allocated memory
template <typename BinIdxType, typename GetOffset>
void SetIndexData(common::Span<BinIdxType> index_data_span,
size_t batch_threads, const SparsePage &batch,
size_t rbegin, size_t nbins, GetOffset get_offset) {
const xgboost::Entry *data_ptr = batch.data.HostVector().data();
const std::vector<bst_row_t> &offset_vec = batch.offset.HostVector();
const size_t batch_size = batch.Size();
CHECK_LT(batch_size, offset_vec.size());
BinIdxType* index_data = index_data_span.data();
common::ParallelFor(omp_ulong(batch_size), batch_threads, [&](omp_ulong i) {
const int tid = omp_get_thread_num();
size_t ibegin = row_ptr[rbegin + i];
size_t iend = row_ptr[rbegin + i + 1];
const size_t size = offset_vec[i + 1] - offset_vec[i];
SparsePage::Inst inst = {data_ptr + offset_vec[i], size};
CHECK_EQ(ibegin + inst.size(), iend);
for (bst_uint j = 0; j < inst.size(); ++j) {
uint32_t idx = cut.SearchBin(inst[j]);
index_data[ibegin + j] = get_offset(idx, j);
++hit_count_tloc_[tid * nbins + idx];
}
});
}
void ResizeIndex(const size_t n_index,
const bool isDense);
inline void GetFeatureCounts(size_t* counts) const {
auto nfeature = cut.Ptrs().size() - 1;
for (unsigned fid = 0; fid < nfeature; ++fid) {
auto ibegin = cut.Ptrs()[fid];
auto iend = cut.Ptrs()[fid + 1];
for (auto i = ibegin; i < iend; ++i) {
counts[fid] += hit_count[i];
}
}
}
inline bool IsDense() const {
return isDense_;
}
private:
std::vector<size_t> hit_count_tloc_;
bool isDense_;
};
} // namespace xgboost
#endif // XGBOOST_DATA_GRADIENT_INDEX_H_
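The new constructor lets callers build the matrix in a single step, which is how the updated tests below use it. A sketch combining it with GetFeatureCounts (RandomDataGenerator is the test helper used elsewhere in this repository):

// Sketch: one-step construction plus per-feature hit counts.
auto p_fmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix();
GHistIndexMatrix gmat(p_fmat.get(), /*max_bin=*/256);
std::vector<size_t> counts(10, 0);
gmat.GetFeatureCounts(counts.data());  // sums hit_count over each feature's cut range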


@@ -58,6 +58,10 @@ class IterativeDeviceDMatrix : public DMatrix {
LOG(FATAL) << "Not implemented.";
return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(nullptr));
}
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam&) override {
LOG(FATAL) << "Not implemented.";
return BatchSet<GHistIndexMatrix>(BatchIterator<GHistIndexMatrix>(nullptr));
}
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;


@@ -97,6 +97,10 @@ class DMatrixProxy : public DMatrix {
LOG(FATAL) << "Not implemented.";
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(nullptr));
}
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam&) override {
LOG(FATAL) << "Not implemented.";
return BatchSet<GHistIndexMatrix>(BatchIterator<GHistIndexMatrix>(nullptr));
}
dmlc::any Adapter() const {
return batch_;


@@ -17,6 +17,7 @@
#include "../common/random.h"
#include "../common/threading_utils.h"
#include "adapter.h"
#include "gradient_index.h"
namespace xgboost {
namespace data {
@@ -89,6 +90,20 @@ BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param)
return BatchSet<EllpackPage>(begin_iter);
}
BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& param) {
if (!(batch_param_ != BatchParam{})) {
CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
}
if (!gradient_index_ || (batch_param_ != param && param != BatchParam{})) {
CHECK_GE(param.max_bin, 2);
gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin));
batch_param_ = param;
}
auto begin_iter = BatchIterator<GHistIndexMatrix>(
new SimpleBatchIteratorImpl<GHistIndexMatrix>(gradient_index_.get()));
return BatchSet<GHistIndexMatrix>(begin_iter);
}
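The caching contract: the index is built on the first request, rebuilt only when a non-default BatchParam differs from the cached one, and a default-constructed BatchParam reuses whatever is cached (it is an error only when nothing has been built yet). A sketch of the resulting behavior (hypothetical caller):

// Sketch: caching behavior of SimpleDMatrix::GetGradientIndex.
BatchParam p256{GenericParameter::kCpuId, /*max_bin=*/256};
auto a = dmat->GetBatches<GHistIndexMatrix>(p256);          // builds and caches
auto b = dmat->GetBatches<GHistIndexMatrix>(p256);          // same param: cache hit
auto c = dmat->GetBatches<GHistIndexMatrix>(BatchParam{});  // default param: reuses the cache
BatchParam p128{GenericParameter::kCpuId, /*max_bin=*/128};
auto d = dmat->GetBatches<GHistIndexMatrix>(p128);          // different param: rebuilds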
template <typename AdapterT>
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
std::vector<uint64_t> qids;


@@ -13,6 +13,7 @@
#include <memory>
#include <string>
#include "gradient_index.h"
namespace xgboost {
namespace data {
@@ -43,12 +44,14 @@ class SimpleDMatrix : public DMatrix {
BatchSet<CSCPage> GetColumnBatches() override;
BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) override;
MetaInfo info_;
SparsePage sparse_page_; // Primary storage type
std::unique_ptr<CSCPage> column_page_;
std::unique_ptr<SortedCSCPage> sorted_column_page_;
std::unique_ptr<EllpackPage> ellpack_page_;
std::unique_ptr<GHistIndexMatrix> gradient_index_;
BatchParam batch_param_;
bool EllpackExists() const override {


@@ -47,6 +47,10 @@ class SparsePageDMatrix : public DMatrix {
BatchSet<CSCPage> GetColumnBatches() override;
BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam&) override {
LOG(FATAL) << "Not implemented.";
return BatchSet<GHistIndexMatrix>(BatchIterator<GHistIndexMatrix>(nullptr));
}
// source data pointers.
std::unique_ptr<SparsePageSource> row_source_;


@@ -69,18 +69,22 @@ template<typename GradientSumT>
void QuantileHistMaker::CallBuilderUpdate(const std::unique_ptr<Builder<GradientSumT>>& builder,
HostDeviceVector<GradientPair> *gpair,
DMatrix *dmat,
GHistIndexMatrix const& gmat,
const std::vector<RegTree *> &trees) {
for (auto tree : trees) {
builder->Update(gmat_, column_matrix_, gpair, dmat, tree);
builder->Update(gmat, column_matrix_, gpair, dmat, tree);
}
}
void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
DMatrix *dmat,
const std::vector<RegTree *> &trees) {
auto const &gmat =
*(dmat->GetBatches<GHistIndexMatrix>(
BatchParam{GenericParameter::kCpuId, param_.max_bin})
.begin());
if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) {
updater_monitor_.Start("GmatInitialization");
gmat_.Init(dmat, static_cast<uint32_t>(param_.max_bin));
column_matrix_.Init(gmat_, param_.sparse_threshold);
column_matrix_.Init(gmat, param_.sparse_threshold);
updater_monitor_.Stop("GmatInitialization");
// A proper solution is putting the cut matrix in DMatrix, see:
// https://github.com/dmlc/xgboost/issues/5143
@@ -96,12 +100,12 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
if (!float_builder_) {
SetBuilder(n_trees, &float_builder_, dmat);
}
CallBuilderUpdate(float_builder_, gpair, dmat, trees);
CallBuilderUpdate(float_builder_, gpair, dmat, gmat, trees);
} else {
if (!double_builder_) {
SetBuilder(n_trees, &double_builder_, dmat);
}
CallBuilderUpdate(double_builder_, gpair, dmat, trees);
CallBuilderUpdate(double_builder_, gpair, dmat, gmat, trees);
}
param_.learning_rate = lr;
@@ -678,7 +682,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
// We should check that the partitioning was done correctly
// and each row of the dataset fell into exactly one of the categories
}
MemStackAllocator<bool, 128> buff(this->nthread_);
common::MemStackAllocator<bool, 128> buff(this->nthread_);
bool* p_buff = buff.Get();
std::fill(p_buff, p_buff + this->nthread_, false);


@@ -75,43 +75,9 @@ struct RandomReplace {
}
};
/*!
* \brief A C-style array with stack-first allocation. As long as the requested size does not exceed MaxStackSize, the array is allocated on the stack; otherwise it is heap-allocated.
*/
template<typename T, size_t MaxStackSize>
class MemStackAllocator {
public:
explicit MemStackAllocator(size_t required_size): required_size_(required_size) {
}
T* Get() {
if (!ptr_) {
if (MaxStackSize >= required_size_) {
ptr_ = stack_mem_;
} else {
ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
do_free_ = true;
}
}
return ptr_;
}
~MemStackAllocator() {
if (do_free_) free(ptr_);
}
private:
T* ptr_ = nullptr;
bool do_free_ = false;
size_t required_size_;
T stack_mem_[MaxStackSize];
};
namespace tree {
using xgboost::common::GHistIndexMatrix;
using xgboost::GHistIndexMatrix;
using xgboost::common::GHistIndexRow;
using xgboost::common::HistCollection;
using xgboost::common::RowSetCollection;
@@ -243,8 +209,6 @@ class QuantileHistMaker: public TreeUpdater {
CPUHistMakerTrainParam hist_maker_param_;
// training parameter
TrainParam param_;
// quantized data matrix
GHistIndexMatrix gmat_;
// column accessor
ColumnMatrix column_matrix_;
DMatrix const* p_last_dmat_ {nullptr};
@@ -466,6 +430,7 @@ class QuantileHistMaker: public TreeUpdater {
void CallBuilderUpdate(const std::unique_ptr<Builder<GradientSumT>>& builder,
HostDeviceVector<GradientPair> *gpair,
DMatrix *dmat,
GHistIndexMatrix const& gmat,
const std::vector<RegTree *> &trees);
protected:


@@ -14,8 +14,7 @@ TEST(DenseColumn, Test) {
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2};
for (size_t max_num_bin : max_num_bins) {
auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix();
GHistIndexMatrix gmat;
gmat.Init(dmat.get(), max_num_bin);
GHistIndexMatrix gmat(dmat.get(), max_num_bin);
ColumnMatrix column_matrix;
column_matrix.Init(gmat, 0.2);
@@ -62,8 +61,7 @@ TEST(SparseColumn, Test) {
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2};
for (size_t max_num_bin : max_num_bins) {
auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix();
GHistIndexMatrix gmat;
gmat.Init(dmat.get(), max_num_bin);
GHistIndexMatrix gmat(dmat.get(), max_num_bin);
ColumnMatrix column_matrix;
column_matrix.Init(gmat, 0.5);
switch (column_matrix.GetTypeSize()) {
@@ -103,8 +101,7 @@ TEST(DenseColumnWithMissing, Test) {
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2 };
for (size_t max_num_bin : max_num_bins) {
auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix();
GHistIndexMatrix gmat;
gmat.Init(dmat.get(), max_num_bin);
GHistIndexMatrix gmat(dmat.get(), max_num_bin);
ColumnMatrix column_matrix;
column_matrix.Init(gmat, 0.2);
switch (column_matrix.GetTypeSize()) {
@@ -135,8 +132,7 @@ void TestGHistIndexMatrixCreation(size_t nthreads) {
/* This should create multiple sparse pages */
std::unique_ptr<DMatrix> dmat{ CreateSparsePageDMatrix(kEntries, kPageSize, filename) };
omp_set_num_threads(nthreads);
GHistIndexMatrix gmat;
gmat.Init(dmat.get(), 256);
GHistIndexMatrix gmat(dmat.get(), 256);
}
TEST(HistIndexCreationWithExternalMemory, Test) {


@@ -4,6 +4,7 @@
#include <utility>
#include "../../../src/common/hist_util.h"
#include "../../../src/data/gradient_index.h"
#include "../helpers.h"
#include "test_hist_util.h"
@@ -255,8 +256,7 @@ TEST(HistUtil, IndexBinBound) {
for (auto max_bin : bin_sizes) {
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
common::GHistIndexMatrix hmat;
hmat.Init(p_fmat.get(), max_bin);
GHistIndexMatrix hmat(p_fmat.get(), max_bin);
EXPECT_EQ(hmat.index.Size(), kRows*kCols);
EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.GetBinTypeSize());
}
@@ -264,7 +264,7 @@
template <typename T>
void CheckIndexData(T* data_ptr, uint32_t* offsets,
const common::GHistIndexMatrix& hmat, size_t n_cols) {
const GHistIndexMatrix& hmat, size_t n_cols) {
for (size_t i = 0; i < hmat.index.Size(); ++i) {
EXPECT_EQ(data_ptr[i] + offsets[i % n_cols], hmat.index[i]);
}
@@ -279,8 +279,7 @@ TEST(HistUtil, IndexBinData) {
for (auto max_bin : kBinSizes) {
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
common::GHistIndexMatrix hmat;
hmat.Init(p_fmat.get(), max_bin);
GHistIndexMatrix hmat(p_fmat.get(), max_bin);
uint32_t* offsets = hmat.index.Offset();
EXPECT_EQ(hmat.index.Size(), kRows*kCols);
switch (max_bin) {


@@ -344,8 +344,7 @@ class QuantileHistMock : public QuantileHistMaker {
auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
// dense, no missing values
common::GHistIndexMatrix gmat;
gmat.Init(dmat.get(), kMaxBins);
GHistIndexMatrix gmat(dmat.get(), kMaxBins);
RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
this->hist_.AddHistRow(0);
@@ -434,8 +433,7 @@ class QuantileHistMock : public QuantileHistMaker {
// kNRows samples with kNCols features
auto dmat = RandomDataGenerator(kNRows, kNCols, sparsity).Seed(3).GenerateDMatrix();
common::GHistIndexMatrix gmat;
gmat.Init(dmat.get(), kMaxBins);
GHistIndexMatrix gmat(dmat.get(), kMaxBins);
ColumnMatrix cm;
// treat everything as dense, as this is what we intend to test here
@@ -546,8 +544,7 @@ class QuantileHistMock : public QuantileHistMaker {
void TestInitData() {
size_t constexpr kMaxBins = 4;
common::GHistIndexMatrix gmat;
gmat.Init(dmat_.get(), kMaxBins);
GHistIndexMatrix gmat(dmat_.get(), kMaxBins);
RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);
@@ -564,8 +561,7 @@ class QuantileHistMock : public QuantileHistMaker {
void TestInitDataSampling() {
size_t constexpr kMaxBins = 4;
common::GHistIndexMatrix gmat;
gmat.Init(dmat_.get(), kMaxBins);
GHistIndexMatrix gmat(dmat_.get(), kMaxBins);
RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);
@@ -582,8 +578,7 @@ class QuantileHistMock : public QuantileHistMaker {
void TestAddHistRows() {
size_t constexpr kMaxBins = 4;
common::GHistIndexMatrix gmat;
gmat.Init(dmat_.get(), kMaxBins);
GHistIndexMatrix gmat(dmat_.get(), kMaxBins);
RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);
@@ -599,8 +594,7 @@ class QuantileHistMock : public QuantileHistMaker {
void TestSyncHistograms() {
size_t constexpr kMaxBins = 4;
common::GHistIndexMatrix gmat;
gmat.Init(dmat_.get(), kMaxBins);
GHistIndexMatrix gmat(dmat_.get(), kMaxBins);
RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);
@@ -620,8 +614,7 @@ class QuantileHistMock : public QuantileHistMaker {
tree.param.UpdateAllowUnknown(cfg_);
size_t constexpr kMaxBins = 4;
common::GHistIndexMatrix gmat;
gmat.Init(dmat_.get(), kMaxBins);
GHistIndexMatrix gmat(dmat_.get(), kMaxBins);
if (double_builder_) {
double_builder_->TestBuildHist(0, gmat, *dmat_, tree);
} else {