Small cleanup to gradient index and hist. (#7668)

* Code comments.
* Const accessor to index.
* Remove some weird variables in the `Index` class.
* Simplify the `MemStackAllocator`.
Jiaming Yuan 2022-02-23 11:37:21 +08:00 committed by GitHub
parent 49c74a5369
commit 6762c45494
12 changed files with 149 additions and 148 deletions

src/common/column_matrix.h

@@ -266,9 +266,9 @@ class ColumnMatrix {
   }
   template <typename T>
-  inline void SetIndexAllDense(T *index, const GHistIndexMatrix &gmat,
-                               const size_t nrow, const size_t nfeature,
-                               const bool noMissingValues, int32_t n_threads) {
+  inline void SetIndexAllDense(T const* index, const GHistIndexMatrix& gmat, const size_t nrow,
+                               const size_t nfeature, const bool noMissingValues,
+                               int32_t n_threads) {
     T* local_index = reinterpret_cast<T*>(&index_[0]);
     /* missing values make sense only for column with type kDenseColumn,
@@ -313,7 +313,7 @@ class ColumnMatrix {
   }
   template<typename T>
-  inline void SetIndex(uint32_t* index, const GHistIndexMatrix& gmat,
+  inline void SetIndex(uint32_t const* index, const GHistIndexMatrix& gmat,
                        const size_t nfeature) {
     std::vector<size_t> num_nonzeros;
     num_nonzeros.resize(nfeature);

src/common/hist_util.h

@@ -197,19 +197,27 @@ enum BinTypeSize : uint32_t {
   kUint32BinsTypeSize = 4
 };
+/**
+ * \brief Optionally compressed gradient index. The compression works only with dense
+ *        data.
+ *
+ * The main body of construction code is in gradient_index.cc, this struct is only a
+ * storage class.
+ */
 struct Index {
-  Index() {
-    SetBinTypeSize(binTypeSize_);
-  }
+  Index() { SetBinTypeSize(binTypeSize_); }
   Index(const Index& i) = delete;
   Index& operator=(Index i) = delete;
   Index(Index&& i) = delete;
   Index& operator=(Index&& i) = delete;
   uint32_t operator[](size_t i) const {
-    if (offset_ptr_ != nullptr) {
-      return func_(data_ptr_, i) + offset_ptr_[i % p_];
+    if (!bin_offset_.empty()) {
+      // dense, compressed
+      auto fidx = i % bin_offset_.size();
+      // restore the index by adding back its feature offset.
+      return func_(data_.data(), i) + bin_offset_[fidx];
     } else {
-      return func_(data_ptr_, i);
+      return func_(data_.data(), i);
     }
   }
   void SetBinTypeSize(BinTypeSize binTypeSize) {
@@ -225,8 +233,7 @@ struct Index {
       func_ = &GetValueFromUint32;
       break;
     default:
-      CHECK(binTypeSize == kUint8BinsTypeSize ||
-            binTypeSize == kUint16BinsTypeSize ||
+      CHECK(binTypeSize == kUint8BinsTypeSize || binTypeSize == kUint16BinsTypeSize ||
             binTypeSize == kUint32BinsTypeSize);
     }
   }
@@ -234,26 +241,24 @@ struct Index {
     return binTypeSize_;
   }
   template <typename T>
-  T* data() const {  // NOLINT
-    return static_cast<T*>(data_ptr_);
+  T const* data() const {  // NOLINT
+    return reinterpret_cast<T const*>(data_.data());
   }
-  uint32_t* Offset() const {
-    return offset_ptr_;
+  template <typename T>
+  T* data() {  // NOLINT
+    return reinterpret_cast<T*>(data_.data());
   }
-  size_t OffsetSize() const {
-    return offset_.size();
-  }
-  size_t Size() const {
-    return data_.size() / (binTypeSize_);
-  }
-  void Resize(const size_t nBytesData) {
-    data_.resize(nBytesData);
-    data_ptr_ = reinterpret_cast<void*>(data_.data());
-  }
-  void ResizeOffset(const size_t nDisps) {
-    offset_.resize(nDisps);
-    offset_ptr_ = offset_.data();
-    p_ = nDisps;
+  uint32_t const* Offset() const { return bin_offset_.data(); }
+  size_t OffsetSize() const { return bin_offset_.size(); }
+  size_t Size() const { return data_.size() / (binTypeSize_); }
+  void Resize(const size_t n_bytes) {
+    data_.resize(n_bytes);
+  }
+  // set the offset used in compression, cut_ptrs is the CSC indptr in HistogramCuts
+  void SetBinOffset(std::vector<uint32_t> const& cut_ptrs) {
+    bin_offset_.resize(cut_ptrs.size() - 1);  // resize to number of features.
+    std::copy_n(cut_ptrs.begin(), bin_offset_.size(), bin_offset_.begin());
   }
   std::vector<uint8_t>::const_iterator begin() const {  // NOLINT
     return data_.begin();
@@ -270,24 +275,23 @@ struct Index {
   }
  private:
-  static uint32_t GetValueFromUint8(void *t, size_t i) {
-    return reinterpret_cast<uint8_t*>(t)[i];
-  }
-  static uint32_t GetValueFromUint16(void* t, size_t i) {
-    return reinterpret_cast<uint16_t*>(t)[i];
+  // Functions to decompress the index.
+  static uint32_t GetValueFromUint8(uint8_t const* t, size_t i) { return t[i]; }
+  static uint32_t GetValueFromUint16(uint8_t const* t, size_t i) {
+    return reinterpret_cast<uint16_t const*>(t)[i];
   }
-  static uint32_t GetValueFromUint32(void* t, size_t i) {
-    return reinterpret_cast<uint32_t*>(t)[i];
+  static uint32_t GetValueFromUint32(uint8_t const* t, size_t i) {
+    return reinterpret_cast<uint32_t const*>(t)[i];
   }
-  using Func = uint32_t (*)(void*, size_t);
+  using Func = uint32_t (*)(uint8_t const*, size_t);
   std::vector<uint8_t> data_;
-  std::vector<uint32_t> offset_;  // size of this field is equal to number of features
-  void* data_ptr_;
+  // starting position of each feature inside the cut values (the indptr of the CSC cut matrix
+  // HistogramCuts without the last entry.) Used for bin compression.
+  std::vector<uint32_t> bin_offset_;
   BinTypeSize binTypeSize_ {kUint8BinsTypeSize};
-  size_t p_ {1};
-  uint32_t* offset_ptr_ {nullptr};
   Func func_;
 };
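The scheme behind `bin_offset_` is small enough to reproduce in isolation. Below is a minimal, self-contained sketch of the same idea; it is not XGBoost code, and `CompressedIndex` and all of its members are invented for illustration. With dense data every row stores exactly one bin per feature, so a global bin id can be rebased against its feature's first bin (the CSC indptr copied into `bin_offset_`) before being truncated to a narrow type, and restored on read the same way `operator[]` above does.

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Hypothetical sketch of offset-based bin compression; not the real Index class.
    struct CompressedIndex {
      std::vector<uint8_t> data;         // compressed bin indices, one per (row, feature)
      std::vector<uint32_t> bin_offset;  // first global bin id of each feature

      // store: subtract the feature's offset so the value fits into uint8_t
      void Push(uint32_t global_bin, std::size_t fidx) {
        data.push_back(static_cast<uint8_t>(global_bin - bin_offset[fidx]));
      }
      // load: for dense data entry i belongs to feature i % n_features; add the offset back
      uint32_t operator[](std::size_t i) const {
        return data[i] + bin_offset[i % bin_offset.size()];
      }
    };

    int main() {
      CompressedIndex idx;
      idx.bin_offset = {0, 256, 512};  // 3 features, 256 bins each
      idx.Push(5, 0);                  // row 0, feature 0: stored as 5
      idx.Push(300, 1);                // feature 1: stored as 300 - 256 = 44
      idx.Push(513, 2);                // feature 2: stored as 513 - 512 = 1
      assert(idx[1] == 300);           // the feature offset is added back on access
      return 0;
    }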
@@ -304,9 +308,11 @@ int32_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(size_t begin, size_t end,
     }
     previous_middle = middle;
+    // index into all the bins
     auto gidx = data[middle];
     if (gidx >= fidx_begin && gidx < fidx_end) {
+      // Found the intersection.
       return static_cast<int32_t>(gidx);
     } else if (gidx < fidx_begin) {
       begin = middle;
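`BinarySearchBin` relies on two invariants: a sparse row stores its global bin ids in increasing order, and each feature owns one contiguous range of bins, so locating a row's entry for a feature is an ordinary range search. A standalone sketch of the same logic, simplified to search a whole row held in a vector (the real function works on begin/end positions inside the page-wide index and guards against a repeated middle):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Return the row's global bin id for the feature owning [fidx_begin, fidx_end),
    // or -1 when the feature is missing from this row.
    int32_t SearchBin(std::vector<uint32_t> const& row, uint32_t fidx_begin,
                      uint32_t fidx_end) {
      std::size_t begin = 0, end = row.size();
      while (begin < end) {
        std::size_t middle = begin + (end - begin) / 2;
        uint32_t gidx = row[middle];            // index into all the bins
        if (gidx >= fidx_begin && gidx < fidx_end) {
          return static_cast<int32_t>(gidx);    // found the intersection
        } else if (gidx < fidx_begin) {
          begin = middle + 1;
        } else {
          end = middle;
        }
      }
      return -1;  // no entry for this feature
    }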
@@ -636,42 +642,6 @@ class GHistBuilder {
   /*! \brief number of all bins over all features */
   uint32_t nbins_ { 0 };
 };
-/*!
- * \brief A C-style array with in-stack allocation. As long as the array is smaller than
- *        MaxStackSize, it will be allocated inside the stack. Otherwise, it will be
- *        heap-allocated.
- */
-template<typename T, size_t MaxStackSize>
-class MemStackAllocator {
- public:
-  explicit MemStackAllocator(size_t required_size): required_size_(required_size) {
-  }
-  T* Get() {
-    if (!ptr_) {
-      if (MaxStackSize >= required_size_) {
-        ptr_ = stack_mem_;
-      } else {
-        ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
-        do_free_ = true;
-      }
-    }
-    return ptr_;
-  }
-  ~MemStackAllocator() {
-    if (do_free_) free(ptr_);
-  }
- private:
-  T* ptr_ = nullptr;
-  bool do_free_ = false;
-  size_t required_size_;
-  T stack_mem_[MaxStackSize];
-};
 }  // namespace common
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_HIST_UTIL_H_

src/common/threading_utils.h

@@ -246,6 +246,43 @@ inline int32_t OmpGetNumThreads(int32_t n_threads) {
   n_threads = std::max(n_threads, 1);
   return n_threads;
 }
+
+/*!
+ * \brief A C-style array with in-stack allocation. As long as the array is smaller than
+ *        MaxStackSize, it will be allocated inside the stack. Otherwise, it will be
+ *        heap-allocated.
+ */
+template <typename T, size_t MaxStackSize>
+class MemStackAllocator {
+ public:
+  explicit MemStackAllocator(size_t required_size) : required_size_(required_size) {
+    if (MaxStackSize >= required_size_) {
+      ptr_ = stack_mem_;
+    } else {
+      ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
+    }
+    if (!ptr_) {
+      throw std::bad_alloc{};
+    }
+  }
+  ~MemStackAllocator() {
+    if (required_size_ > MaxStackSize) {
+      free(ptr_);
+    }
+  }
+  T& operator[](size_t i) { return ptr_[i]; }
+  T const& operator[](size_t i) const { return ptr_[i]; }
+  // FIXME(jiamingy): Remove this once we merge partitioner cleanup for hist.
+  auto Get() { return ptr_; }
+
+ private:
+  T* ptr_ = nullptr;
+  size_t required_size_;
+  T stack_mem_[MaxStackSize];
+};
+
 }  // namespace common
 }  // namespace xgboost
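Compared with the version removed from hist_util.h, allocation now happens in the constructor and the free in the destructor (plain RAII), so the lazy-init branch and the `do_free_` flag disappear, and `operator[]` lets callers drop the raw pointer from `Get()`. A short usage sketch, assuming the class above is in scope; `n_threads` is a stand-in variable:

    // Per-thread scratch: lives on the stack for up to 128 threads, otherwise on the
    // heap, and the destructor releases the heap block automatically.
    xgboost::common::MemStackAllocator<std::size_t, 128> partial_sums(n_threads);
    for (std::int32_t tid = 0; tid < n_threads; ++tid) {
      partial_sums[tid] = 0;  // operator[] replaces Get() + pointer arithmetic
    }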

src/data/gradient_index.cc

@@ -10,6 +10,7 @@
 #include "../common/column_matrix.h"
 #include "../common/hist_util.h"
+#include "../common/threading_utils.h"

 namespace xgboost {
@@ -34,7 +35,6 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
       std::max(static_cast<size_t>(1), std::min(batch.Size(), static_cast<size_t>(n_threads)));
   auto page = batch.GetView();
   common::MemStackAllocator<size_t, 128> partial_sums(batch_threads);
-  size_t *p_part = partial_sums.Get();
   size_t block_size = batch.Size() / batch_threads;
@@ -48,10 +48,10 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
         size_t iend = (tid == (batch_threads - 1) ? batch.Size()
                                                   : (block_size * (tid + 1)));
-        size_t sum = 0;
-        for (size_t i = ibegin; i < iend; ++i) {
-          sum += page[i].size();
-          row_ptr[rbegin + 1 + i] = sum;
+        size_t running_sum = 0;
+        for (size_t ridx = ibegin; ridx < iend; ++ridx) {
+          running_sum += page[ridx].size();
+          row_ptr[rbegin + 1 + ridx] = running_sum;
         }
       });
     }
@@ -59,9 +59,9 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
 #pragma omp single
       {
         exc.Run([&]() {
-          p_part[0] = prev_sum;
+          partial_sums[0] = prev_sum;
           for (size_t i = 1; i < batch_threads; ++i) {
-            p_part[i] = p_part[i - 1] + row_ptr[rbegin + i * block_size];
+            partial_sums[i] = partial_sums[i - 1] + row_ptr[rbegin + i * block_size];
           }
         });
       }
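Taken together, the hunks above (and the one that follows) form a classic blocked prefix sum over the row sizes: each thread builds a local running sum for its block, a single thread folds the per-block totals into `partial_sums`, and each thread then adds its block's prefix. A sequential sketch of the same dependency structure, with plain loops standing in for `ParallelFor` and `omp single` (not the real code):

    #include <cstddef>
    #include <vector>

    // Turn per-element sizes into a global inclusive prefix sum, block by block.
    void BlockedPrefixSum(std::vector<std::size_t>* p_vals, std::size_t n_blocks) {
      auto& vals = *p_vals;
      std::size_t block = vals.size() / n_blocks;
      std::vector<std::size_t> partial(n_blocks, 0);
      // phase 1: local running sums, independent per block (parallel in PushBatch)
      for (std::size_t b = 0; b < n_blocks; ++b) {
        std::size_t beg = b * block, end = (b + 1 == n_blocks) ? vals.size() : beg + block;
        for (std::size_t i = beg + 1; i < end; ++i) vals[i] += vals[i - 1];
      }
      // phase 2: fold block totals into prefixes (the single-threaded section)
      for (std::size_t b = 1; b < n_blocks; ++b) {
        partial[b] = partial[b - 1] + vals[b * block - 1];
      }
      // phase 3: each block adds its prefix, again independent per block
      for (std::size_t b = 1; b < n_blocks; ++b) {
        std::size_t beg = b * block, end = (b + 1 == n_blocks) ? vals.size() : beg + block;
        for (std::size_t i = beg; i < end; ++i) vals[i] += partial[b];
      }
    }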
@@ -74,55 +74,52 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
                                                   : (block_size * (tid + 1)));
         for (size_t i = ibegin; i < iend; ++i) {
-          row_ptr[rbegin + 1 + i] += p_part[tid];
+          row_ptr[rbegin + 1 + i] += partial_sums[tid];
         }
       });
     }
   }
   exc.Rethrow();
-  const size_t n_offsets = cut.Ptrs().size() - 1;
-  const size_t n_index = row_ptr[rbegin + batch.Size()];
+  const size_t n_index = row_ptr[rbegin + batch.Size()];  // number of entries in this page
   ResizeIndex(n_index, isDense_);
   CHECK_GT(cut.Values().size(), 0U);
-  uint32_t *offsets = nullptr;
   if (isDense_) {
-    index.ResizeOffset(n_offsets);
-    offsets = index.Offset();
-    for (size_t i = 0; i < n_offsets; ++i) {
-      offsets[i] = cut.Ptrs()[i];
-    }
+    index.SetBinOffset(cut.Ptrs());
   }
+  uint32_t const *offsets = index.Offset();
   if (isDense_) {
+    // Inside the lambda functions, bin_idx is the index for cut value across all
+    // features. By subtracting it with starting pointer of each feature, we can reduce
+    // it to smaller value and compress it to smaller types.
     common::BinTypeSize curent_bin_size = index.GetBinTypeSize();
     if (curent_bin_size == common::kUint8BinsTypeSize) {
       common::Span<uint8_t> index_data_span = {index.data<uint8_t>(), n_index};
       SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
-                   [offsets](auto idx, auto j) {
-                     return static_cast<uint8_t>(idx - offsets[j]);
+                   [offsets](auto bin_idx, auto fidx) {
+                     return static_cast<uint8_t>(bin_idx - offsets[fidx]);
                    });
     } else if (curent_bin_size == common::kUint16BinsTypeSize) {
       common::Span<uint16_t> index_data_span = {index.data<uint16_t>(), n_index};
       SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
-                   [offsets](auto idx, auto j) {
-                     return static_cast<uint16_t>(idx - offsets[j]);
+                   [offsets](auto bin_idx, auto fidx) {
+                     return static_cast<uint16_t>(bin_idx - offsets[fidx]);
                    });
     } else {
       CHECK_EQ(curent_bin_size, common::kUint32BinsTypeSize);
       common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
       SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
-                   [offsets](auto idx, auto j) {
-                     return static_cast<uint32_t>(idx - offsets[j]);
+                   [offsets](auto bin_idx, auto fidx) {
+                     return static_cast<uint32_t>(bin_idx - offsets[fidx]);
                    });
     }
-    /* For sparse DMatrix we have to store index of feature for each bin
-       in index field to chose right offset. So offset is nullptr and index is
-       not reduced */
   } else {
+    /* For sparse DMatrix we have to store index of feature for each bin
+       in index field to chose right offset. So offset is nullptr and index is
+       not reduced */
     common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
     SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
                  [](auto idx, auto) { return idx; });
@@ -194,11 +191,13 @@ void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType co
 void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
   if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense) {
+    // compress dense index to uint8
     index.SetBinTypeSize(common::kUint8BinsTypeSize);
     index.Resize((sizeof(uint8_t)) * n_index);
   } else if ((max_num_bins - 1 > static_cast<int>(std::numeric_limits<uint8_t>::max()) &&
              max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) &&
             isDense) {
+    // compress dense index to uint16
     index.SetBinTypeSize(common::kUint16BinsTypeSize);
     index.Resize((sizeof(uint16_t)) * n_index);
   } else {
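The branches encode a simple rule: only dense pages are compressed (for sparse data the feature of an entry cannot be recovered from its position alone), and after rebasing every stored value is bounded by max_num_bins - 1, so the narrowest integer type that holds that bound is chosen. A sketch of the same decision, with a made-up enum standing in for common::BinTypeSize:

    #include <cstdint>
    #include <limits>

    enum class BinWidth : int { kU8 = 1, kU16 = 2, kU32 = 4 };  // bytes per entry

    // Dense pages rebase bin ids per feature, so max_num_bins - 1 bounds every stored
    // value; sparse pages keep global bin ids and always need 32 bits.
    BinWidth ChooseBinWidth(int max_num_bins, bool is_dense) {
      if (!is_dense) return BinWidth::kU32;
      if (max_num_bins - 1 <= std::numeric_limits<uint8_t>::max()) return BinWidth::kU8;
      if (max_num_bins - 1 <= std::numeric_limits<uint16_t>::max()) return BinWidth::kU16;
      return BinWidth::kU32;
    }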

src/data/gradient_index.h

@@ -21,6 +21,13 @@ namespace xgboost {
  * index for CPU histogram. On GPU ellpack page is used.
  */
 class GHistIndexMatrix {
+  /**
+   * \brief Push a page into index matrix, the function is only necessary because hist has
+   *        partial support for external memory.
+   *
+   * \param rbegin The beginning row index of current page. (total rows in previous pages)
+   * \param prev_sum Total number of entries in previous pages.
+   */
   void PushBatch(SparsePage const& batch, common::Span<FeatureType const> ft, size_t rbegin,
                  size_t prev_sum, uint32_t nbins, int32_t n_threads);
@@ -64,12 +71,12 @@ class GHistIndexMatrix {
     BinIdxType* index_data = index_data_span.data();
     auto const& ptrs = cut.Ptrs();
     auto const& values = cut.Values();
-    common::ParallelFor(batch_size, batch_threads, [&](omp_ulong i) {
+    common::ParallelFor(batch_size, batch_threads, [&](omp_ulong ridx) {
       const int tid = omp_get_thread_num();
-      size_t ibegin = row_ptr[rbegin + i];
-      size_t iend = row_ptr[rbegin + i + 1];
-      const size_t size = offset_vec[i + 1] - offset_vec[i];
-      SparsePage::Inst inst = {data_ptr + offset_vec[i], size};
+      size_t ibegin = row_ptr[rbegin + ridx];    // index of first entry for current block
+      size_t iend = row_ptr[rbegin + ridx + 1];  // first entry for next block
+      const size_t size = offset_vec[ridx + 1] - offset_vec[ridx];
+      SparsePage::Inst inst = {data_ptr + offset_vec[ridx], size};
       CHECK_EQ(ibegin + inst.size(), iend);
       for (bst_uint j = 0; j < inst.size(); ++j) {
         auto e = inst[j];
@@ -103,6 +110,10 @@ class GHistIndexMatrix {
     return isDense_;
   }
   void SetDense(bool is_dense) { isDense_ = is_dense; }
+  /**
+   * \brief Get the local row index.
+   */
+  size_t RowIdx(size_t ridx) const { return row_ptr[ridx - base_rowid]; }
   bst_row_t Size() const {
     return row_ptr.empty() ? 0 : row_ptr.size() - 1;
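`RowIdx` centralizes the global-to-local translation that the approx partitioner previously did with an ad-hoc lambda: `row_ptr` is indexed by page-local rows, while callers pass global row ids, so `base_rowid` has to be subtracted first. A small self-contained sketch with made-up numbers:

    #include <cstddef>
    #include <vector>

    int main() {
      // A page holding global rows [100, 200) has base_rowid == 100 and an indptr
      // (row_ptr) with 101 entries; global row 150 is local row 50.
      std::size_t base_rowid = 100;
      std::vector<std::size_t> row_ptr(101, 0);
      auto RowIdx = [&](std::size_t ridx) { return row_ptr[ridx - base_rowid]; };
      return static_cast<int>(RowIdx(150));  // start of global row 150 in the index
    }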

src/data/gradient_index_format.cc

@@ -16,14 +16,6 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
   }
   // indptr
   fi->Read(&page->row_ptr);
-  // offset
-  using OffsetT = std::iterator_traits<decltype(page->index.Offset())>::value_type;
-  std::vector<OffsetT> offset;
-  if (!fi->Read(&offset)) {
-    return false;
-  }
-  page->index.ResizeOffset(offset.size());
-  std::copy(offset.begin(), offset.end(), page->index.Offset());
   // data
   std::vector<uint8_t> data;
   if (!fi->Read(&data)) {
@@ -55,6 +47,9 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
     return false;
   }
   page->SetDense(is_dense);
+  if (is_dense) {
+    page->index.SetBinOffset(page->cut.Ptrs());
+  }
   return true;
 }
@@ -65,13 +60,6 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
   fo->Write(page.row_ptr);
   bytes += page.row_ptr.size() * sizeof(decltype(page.row_ptr)::value_type) +
            sizeof(uint64_t);
-  // offset
-  using OffsetT = std::iterator_traits<decltype(page.index.Offset())>::value_type;
-  std::vector<OffsetT> offset(page.index.OffsetSize());
-  std::copy(page.index.Offset(),
-            page.index.Offset() + page.index.OffsetSize(), offset.begin());
-  fo->Write(offset);
-  bytes += page.index.OffsetSize() * sizeof(OffsetT) + sizeof(uint64_t);
   // data
   std::vector<uint8_t> data(page.index.begin(), page.index.end());
   fo->Write(data);

src/tree/updater_approx.h

@@ -35,14 +35,12 @@ class ApproxRowPartitioner {
                                 std::vector<uint32_t> const &cut_ptrs,
                                 std::vector<float> const &cut_values) {
     int32_t gidx = -1;
-    auto const &row_ptr = index.row_ptr;
-    auto get_rid = [&](size_t ridx) { return row_ptr[ridx - index.base_rowid]; };
-
     if (index.IsDense()) {
-      gidx = index.index[get_rid(ridx) + fidx];
+      // RowIdx returns the starting pos of this row
+      gidx = index.index[index.RowIdx(ridx) + fidx];
     } else {
-      auto begin = get_rid(ridx);
-      auto end = get_rid(ridx + 1);
+      auto begin = index.RowIdx(ridx);
+      auto end = index.RowIdx(ridx + 1);
       auto f_begin = cut_ptrs[fidx];
       auto f_end = cut_ptrs[fidx + 1];
       gidx = common::BinarySearchBin(begin, end, index.index, f_begin, f_end);
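The two branches are two different lookups: a dense page stores exactly one entry per feature per row, so the bin for (ridx, fidx) sits at a fixed offset from the row start, while a sparse page stores only the features that are present and must be searched within the feature's global bin range. A hypothetical standalone helper showing both paths; a linear scan stands in for common::BinarySearchBin, sketched earlier:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // index: global (uncompressed) bin ids; row_begin/row_end: from RowIdx;
    // cut_ptrs: per-feature bin ranges. Returns -1 for a missing value.
    int32_t GetGidx(bool is_dense, std::vector<uint32_t> const& index,
                    std::size_t row_begin, std::size_t row_end, std::size_t fidx,
                    std::vector<uint32_t> const& cut_ptrs) {
      if (is_dense) {
        // one entry per feature: direct addressing from the start of the row
        return static_cast<int32_t>(index[row_begin + fidx]);
      }
      for (std::size_t i = row_begin; i < row_end; ++i) {
        if (index[i] >= cut_ptrs[fidx] && index[i] < cut_ptrs[fidx + 1]) {
          return static_cast<int32_t>(index[i]);
        }
      }
      return -1;
    }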

src/tree/updater_quantile_hist.cc

@@ -135,7 +135,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
   {
     auto nid = RegTree::kRoot;
-    GHistRowT hist = this->histogram_builder_->Histogram()[nid];
+    auto hist = this->histogram_builder_->Histogram()[nid];
     GradientPairT grad_stat;
     if (data_layout_ == DataLayout::kDenseDataZeroBased ||
         data_layout_ == DataLayout::kDenseDataOneBased) {
@@ -149,7 +149,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
         grad_stat.Add(et.GetGrad(), et.GetHess());
       }
     } else {
-      const RowSetCollection::Elem e = row_set_collection_[nid];
+      const common::RowSetCollection::Elem e = row_set_collection_[nid];
       for (const size_t *it = e.begin; it < e.end; ++it) {
         grad_stat.Add(gpair_h[*it].GetGrad(), gpair_h[*it].GetHess());
       }
@@ -229,7 +229,7 @@ template<typename GradientSumT>
 template <bool any_missing>
 void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
     const GHistIndexMatrix& gmat,
-    const ColumnMatrix& column_matrix,
+    const common::ColumnMatrix& column_matrix,
     DMatrix* p_fmat,
     RegTree* p_tree,
     const std::vector<GradientPair>& gpair_h) {

src/tree/updater_quantile_hist.h

@@ -147,7 +147,7 @@ class QuantileHistMaker: public TreeUpdater {
   // training parameter
   TrainParam param_;
   // column accessor
-  ColumnMatrix column_matrix_;
+  common::ColumnMatrix column_matrix_;
   DMatrix const* p_last_dmat_ {nullptr};
   bool is_gmat_initialized_ {false};
@@ -155,7 +155,6 @@ class QuantileHistMaker: public TreeUpdater {
 template<typename GradientSumT>
 struct Builder {
  public:
-  using GHistRowT = GHistRow<GradientSumT>;
   using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;
   // constructor
   explicit Builder(const size_t n_trees, const TrainParam& param,
@@ -164,7 +163,6 @@ class QuantileHistMaker: public TreeUpdater {
       : n_trees_(n_trees),
         param_(param),
         pruner_(std::move(pruner)),
-        p_last_tree_(nullptr),
         p_last_fmat_(fmat),
         histogram_builder_{new HistogramBuilder<GradientSumT, CPUExpandEntry>},
         task_{task},
@@ -172,7 +170,7 @@ class QuantileHistMaker: public TreeUpdater {
     builder_monitor_.Init("Quantile::Builder");
   }
   // update one tree, growing
-  void Update(const GHistIndexMatrix& gmat, const ColumnMatrix& column_matrix,
+  void Update(const GHistIndexMatrix& gmat, const common::ColumnMatrix& column_matrix,
               HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree);
   bool UpdatePredictionCache(const DMatrix* data,

tests/cpp/common/test_hist_util.cc

@@ -306,8 +306,8 @@ TEST(HistUtil, IndexBinBound) {
 }
 template <typename T>
-void CheckIndexData(T* data_ptr, uint32_t* offsets,
-                    const GHistIndexMatrix& hmat, size_t n_cols) {
+void CheckIndexData(T const* data_ptr, uint32_t const* offsets, const GHistIndexMatrix& hmat,
+                    size_t n_cols) {
   for (size_t i = 0; i < hmat.index.Size(); ++i) {
     EXPECT_EQ(data_ptr[i] + offsets[i % n_cols], hmat.index[i]);
   }
@@ -323,7 +323,7 @@ TEST(HistUtil, IndexBinData) {
   for (auto max_bin : kBinSizes) {
     auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
     GHistIndexMatrix hmat(p_fmat.get(), max_bin, 0.5, false, common::OmpGetNumThreads(0));
-    uint32_t* offsets = hmat.index.Offset();
+    uint32_t const* offsets = hmat.index.Offset();
     EXPECT_EQ(hmat.index.Size(), kRows*kCols);
     switch (max_bin) {
       case kBinSizes[0]:

tests/cpp/tree/hist/test_histogram.cc

@@ -6,15 +6,16 @@
 #include <limits>

 #include "../../../../src/common/categorical.h"
+#include "../../../../src/common/row_set.h"
+#include "../../../../src/tree/hist/expand_entry.h"
 #include "../../../../src/tree/hist/histogram.h"
-#include "../../../../src/tree/updater_quantile_hist.h"
 #include "../../categorical_helpers.h"
 #include "../../helpers.h"

 namespace xgboost {
 namespace tree {
 namespace {
-void InitRowPartitionForTest(RowSetCollection *row_set, size_t n_samples, size_t base_rowid = 0) {
+void InitRowPartitionForTest(common::RowSetCollection *row_set, size_t n_samples,
+                             size_t base_rowid = 0) {
   auto &row_indices = *row_set->Data();
   row_indices.resize(n_samples);
   std::iota(row_indices.begin(), row_indices.end(), base_rowid);
@@ -91,7 +92,7 @@ void TestSyncHist(bool is_distributed) {
   uint32_t total_bins = gmat.cut.Ptrs().back();
   histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed);
-  RowSetCollection row_set_collection_;
+  common::RowSetCollection row_set_collection_;
   {
     row_set_collection_.Clear();
     std::vector<size_t> &row_indices = *row_set_collection_.Data();
@@ -256,7 +257,7 @@ void TestBuildHistogram(bool is_distributed) {
   RegTree tree;
-  RowSetCollection row_set_collection;
+  common::RowSetCollection row_set_collection;
   row_set_collection.Clear();
   std::vector<size_t> &row_indices = *row_set_collection.Data();
   row_indices.resize(kNRows);
@@ -318,7 +319,7 @@ void TestHistogramCategorical(size_t n_categories) {
   auto gpair = GenerateRandomGradients(kRows, 0, 2);
-  RowSetCollection row_set_collection;
+  common::RowSetCollection row_set_collection;
   row_set_collection.Clear();
   std::vector<size_t> &row_indices = *row_set_collection.Data();
   row_indices.resize(kRows);
@@ -381,13 +382,13 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx) {
   std::vector<CPUExpandEntry> nodes;
   nodes.emplace_back(0, tree.GetDepth(0), 0.0f);
-  GHistRow<double> multi_page;
+  common::GHistRow<double> multi_page;
   HistogramBuilder<double, CPUExpandEntry> multi_build;
   {
     /**
      * Multi page
      */
-    std::vector<RowSetCollection> rows_set;
+    std::vector<common::RowSetCollection> rows_set;
     for (auto const &page : m->GetBatches<GHistIndexMatrix>(batch_param)) {
       CHECK_LT(page.base_rowid, m->Info().num_row_);
       auto n_rows_in_node = page.Size();
@@ -417,12 +418,12 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx) {
   }
   HistogramBuilder<double, CPUExpandEntry> single_build;
-  GHistRow<double> single_page;
+  common::GHistRow<double> single_page;
   {
     /**
      * Single page
      */
-    RowSetCollection row_set_collection;
+    common::RowSetCollection row_set_collection;
     InitRowPartitionForTest(&row_set_collection, n_samples);
     single_build.Reset(total_bins, batch_param, common::OmpGetNumThreads(0), 1, false);

tests/cpp/tree/test_quantile_hist.cc

@@ -24,7 +24,6 @@ class QuantileHistMock : public QuantileHistMaker {
   template <typename GradientSumT>
   struct BuilderMock : public QuantileHistMaker::Builder<GradientSumT> {
     using RealImpl = QuantileHistMaker::Builder<GradientSumT>;
-    using GHistRowT = typename RealImpl::GHistRowT;
     BuilderMock(const TrainParam &param, std::unique_ptr<TreeUpdater> pruner,
                 DMatrix const *fmat, GenericParameter const* ctx)