External memory support for hist (#7531)

* Generate column matrix from gHistIndex.
* Avoid synchronization with the sparse page once the cache is written.
* Cleanups: Remove member variables/functions, change the update routine to look like approx and gpu_hist.
* Remove pruner.
This commit is contained in:
Jiaming Yuan 2022-03-22 00:13:20 +08:00 committed by GitHub
parent cd55823112
commit 4d81c741e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 563 additions and 686 deletions

View File

@ -48,17 +48,18 @@
#include "../src/predictor/cpu_predictor.cc"
// trees
#include "../src/tree/constraints.cc"
#include "../src/tree/hist/param.cc"
#include "../src/tree/param.cc"
#include "../src/tree/tree_model.cc"
#include "../src/tree/tree_updater.cc"
#include "../src/tree/updater_approx.cc"
#include "../src/tree/updater_colmaker.cc"
#include "../src/tree/updater_quantile_hist.cc"
#include "../src/tree/updater_histmaker.cc"
#include "../src/tree/updater_prune.cc"
#include "../src/tree/updater_quantile_hist.cc"
#include "../src/tree/updater_refresh.cc"
#include "../src/tree/updater_sync.cc"
#include "../src/tree/updater_histmaker.cc"
#include "../src/tree/updater_approx.cc"
#include "../src/tree/constraints.cc"
// linear
#include "../src/linear/linear_updater.cc"

View File

@ -7,6 +7,9 @@ instead of Quantile DMatrix. The feature is not ready for production use yet.
.. versionadded:: 1.5.0
See :doc:`the tutorial </tutorials/external_memory>` for more details.
"""
import os
import xgboost
@ -77,9 +80,14 @@ def main(tmpdir: str) -> xgboost.Booster:
missing = np.NaN
Xy = xgboost.DMatrix(it, missing=missing, enable_categorical=False)
# Other tree methods including ``hist`` and ``gpu_hist`` also work, but has some
# caveats. This is still an experimental feature.
booster = xgboost.train({"tree_method": "approx"}, Xy, evals=[(Xy, "Train")])
# Other tree methods including ``hist`` and ``gpu_hist`` also work, see tutorial in
# doc for details.
booster = xgboost.train(
{"tree_method": "approx", "max_depth": 2},
Xy,
evals=[(Xy, "Train")],
num_boost_round=10,
)
return booster

View File

@ -27,7 +27,7 @@ def main(args):
dtrain.set_info(feature_weights=fw)
bst = xgboost.train({'tree_method': 'hist',
'colsample_bynode': 0.5},
'colsample_bynode': 0.2},
dtrain, num_boost_round=10,
evals=[(dtrain, 'd')])
feature_map = bst.get_fscore()

View File

@ -127,9 +127,12 @@ the tree method still concatenate all the chunks into 1 final histogram index du
performance reason, but in compressed format. So its scalability has an upper bound but
still has lower memory cost in general.
********
CPU Hist
********
***********
CPU Version
***********
It's limited by the same factor of GPU Hist, except that gradient based sampling is not
yet supported on CPU.
For CPU histogram based tree methods (``approx``, ``hist``) it's recommended to use
``grow_policy=depthwise`` for performance reason. Iterating over data batches is slow,
with ``depthwise`` policy XGBoost can build a entire layer of tree nodes with a few
iterations, while with ``lossguide`` XGBoost needs to iterate over the data set for each
tree node.

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2017 by Contributors
* Copyright 2017-2022 by Contributors
* \file column_matrix.h
* \brief Utility for fast column-wise access
* \author Philip Cho
@ -8,21 +8,22 @@
#ifndef XGBOOST_COMMON_COLUMN_MATRIX_H_
#define XGBOOST_COMMON_COLUMN_MATRIX_H_
#include <dmlc/endian.h>
#include <algorithm>
#include <limits>
#include <vector>
#include <memory>
#include "hist_util.h"
#include <vector>
#include "../data/gradient_index.h"
#include "hist_util.h"
namespace xgboost {
namespace common {
class ColumnMatrix;
/*! \brief column type */
enum ColumnType {
kDenseColumn,
kSparseColumn
};
enum ColumnType : uint8_t { kDenseColumn, kSparseColumn };
/*! \brief a column storage, to be used with ApplySplit. Note that each
bin id is stored as index[i] + index_base.
@ -34,9 +35,7 @@ class Column {
static constexpr int32_t kMissingId = -1;
Column(ColumnType type, common::Span<const BinIdxType> index, const uint32_t index_base)
: type_(type),
index_(index),
index_base_(index_base) {}
: type_(type), index_(index), index_base_(index_base) {}
virtual ~Column() = default;
@ -67,10 +66,9 @@ class Column {
template <typename BinIdxType>
class SparseColumn : public Column<BinIdxType> {
public:
SparseColumn(ColumnType type, common::Span<const BinIdxType> index,
uint32_t index_base, common::Span<const size_t> row_ind)
: Column<BinIdxType>(type, index, index_base),
row_ind_(row_ind) {}
SparseColumn(ColumnType type, common::Span<const BinIdxType> index, uint32_t index_base,
common::Span<const size_t> row_ind)
: Column<BinIdxType>(type, index, index_base), row_ind_(row_ind) {}
const size_t* GetRowData() const { return row_ind_.data(); }
@ -98,9 +96,7 @@ class SparseColumn: public Column<BinIdxType> {
return p - row_data;
}
size_t GetRowIdx(size_t idx) const {
return row_ind_.data()[idx];
}
size_t GetRowIdx(size_t idx) const { return row_ind_.data()[idx]; }
private:
/* indexes of rows */
@ -110,9 +106,8 @@ class SparseColumn: public Column<BinIdxType> {
template <typename BinIdxType, bool any_missing>
class DenseColumn : public Column<BinIdxType> {
public:
DenseColumn(ColumnType type, common::Span<const BinIdxType> index,
uint32_t index_base, const std::vector<bool>& missing_flags,
size_t feature_offset)
DenseColumn(ColumnType type, common::Span<const BinIdxType> index, uint32_t index_base,
const std::vector<bool>& missing_flags, size_t feature_offset)
: Column<BinIdxType>(type, index, index_base),
missing_flags_(missing_flags),
feature_offset_(feature_offset) {}
@ -126,9 +121,7 @@ class DenseColumn: public Column<BinIdxType> {
}
}
size_t GetInitialState(const size_t first_row_id) const {
return 0;
}
size_t GetInitialState(const size_t first_row_id) const { return 0; }
private:
/* flags for missing values in dense columns */
@ -141,28 +134,26 @@ class DenseColumn: public Column<BinIdxType> {
class ColumnMatrix {
public:
// get number of features
inline bst_uint GetNumFeature() const {
return static_cast<bst_uint>(type_.size());
}
bst_feature_t GetNumFeature() const { return static_cast<bst_feature_t>(type_.size()); }
// construct column matrix from GHistIndexMatrix
inline void Init(const GHistIndexMatrix& gmat, double sparse_threshold, int32_t n_threads) {
const int32_t nfeature = static_cast<int32_t>(gmat.cut.Ptrs().size() - 1);
inline void Init(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold,
int32_t n_threads) {
auto const nfeature = static_cast<bst_feature_t>(gmat.cut.Ptrs().size() - 1);
const size_t nrow = gmat.row_ptr.size() - 1;
// identify type of each column
feature_counts_.resize(nfeature);
type_.resize(nfeature);
std::fill(feature_counts_.begin(), feature_counts_.end(), 0);
uint32_t max_val = std::numeric_limits<uint32_t>::max();
for (int32_t fid = 0; fid < nfeature; ++fid) {
for (bst_feature_t fid = 0; fid < nfeature; ++fid) {
CHECK_LE(gmat.cut.Ptrs()[fid + 1] - gmat.cut.Ptrs()[fid], max_val);
}
bool all_dense = gmat.IsDense();
gmat.GetFeatureCounts(&feature_counts_[0]);
// classify features
for (int32_t fid = 0; fid < nfeature; ++fid) {
if (static_cast<double>(feature_counts_[fid])
< sparse_threshold * nrow) {
for (bst_feature_t fid = 0; fid < nfeature; ++fid) {
if (static_cast<double>(feature_counts_[fid]) < sparse_threshold * nrow) {
type_[fid] = kSparseColumn;
all_dense = false;
} else {
@ -175,7 +166,7 @@ class ColumnMatrix {
feature_offsets_.resize(nfeature + 1);
size_t accum_index_ = 0;
feature_offsets_[0] = accum_index_;
for (int32_t fid = 1; fid < nfeature + 1; ++fid) {
for (bst_feature_t fid = 1; fid < nfeature + 1; ++fid) {
if (type_[fid - 1] == kDenseColumn) {
accum_index_ += static_cast<size_t>(nrow);
} else {
@ -197,6 +188,7 @@ class ColumnMatrix {
const bool noMissingValues = NoMissingValues(gmat.row_ptr[nrow], nrow, nfeature);
any_missing_ = !noMissingValues;
missing_flags_.clear();
if (noMissingValues) {
missing_flags_.resize(feature_offsets_[nfeature], false);
} else {
@ -207,26 +199,26 @@ class ColumnMatrix {
if (all_dense) {
BinTypeSize gmat_bin_size = gmat.index.GetBinTypeSize();
if (gmat_bin_size == kUint8BinsTypeSize) {
SetIndexAllDense(gmat.index.data<uint8_t>(), gmat, nrow, nfeature, noMissingValues,
SetIndexAllDense(page, gmat.index.data<uint8_t>(), gmat, nrow, nfeature, noMissingValues,
n_threads);
} else if (gmat_bin_size == kUint16BinsTypeSize) {
SetIndexAllDense(gmat.index.data<uint16_t>(), gmat, nrow, nfeature, noMissingValues,
SetIndexAllDense(page, gmat.index.data<uint16_t>(), gmat, nrow, nfeature, noMissingValues,
n_threads);
} else {
CHECK_EQ(gmat_bin_size, kUint32BinsTypeSize);
SetIndexAllDense(gmat.index.data<uint32_t>(), gmat, nrow, nfeature, noMissingValues,
SetIndexAllDense(page, gmat.index.data<uint32_t>(), gmat, nrow, nfeature, noMissingValues,
n_threads);
}
/* For sparse DMatrix gmat.index.getBinTypeSize() returns always kUint32BinsTypeSize
but for ColumnMatrix we still have a chance to reduce the memory consumption */
} else {
if (bins_type_size_ == kUint8BinsTypeSize) {
SetIndex<uint8_t>(gmat.index.data<uint32_t>(), gmat, nfeature);
SetIndex<uint8_t>(page, gmat.index.data<uint32_t>(), gmat, nfeature);
} else if (bins_type_size_ == kUint16BinsTypeSize) {
SetIndex<uint16_t>(gmat.index.data<uint32_t>(), gmat, nfeature);
SetIndex<uint16_t>(page, gmat.index.data<uint32_t>(), gmat, nfeature);
} else {
CHECK_EQ(bins_type_size_, kUint32BinsTypeSize);
SetIndex<uint32_t>(gmat.index.data<uint32_t>(), gmat, nfeature);
SetIndex<uint32_t>(page, gmat.index.data<uint32_t>(), gmat, nfeature);
}
}
}
@ -250,8 +242,8 @@ class ColumnMatrix {
const size_t feature_offset = feature_offsets_[fid]; // to get right place for certain feature
const size_t column_size = feature_offsets_[fid + 1] - feature_offset;
common::Span<const BinIdxType> bin_index = { reinterpret_cast<const BinIdxType*>(
&index_[feature_offset * bins_type_size_]),
common::Span<const BinIdxType> bin_index = {
reinterpret_cast<const BinIdxType*>(&index_[feature_offset * bins_type_size_]),
column_size};
std::unique_ptr<const Column<BinIdxType> > res;
if (type_[fid] == ColumnType::kDenseColumn) {
@ -266,8 +258,8 @@ class ColumnMatrix {
}
template <typename T>
inline void SetIndexAllDense(T const* index, const GHistIndexMatrix& gmat, const size_t nrow,
const size_t nfeature, const bool noMissingValues,
inline void SetIndexAllDense(SparsePage const& page, T const* index, const GHistIndexMatrix& gmat,
const size_t nrow, const size_t nfeature, const bool noMissingValues,
int32_t n_threads) {
T* local_index = reinterpret_cast<T*>(&index_[0]);
@ -285,51 +277,34 @@ class ColumnMatrix {
});
} else {
/* to handle rows in all batches, sum of all batch sizes equal to gmat.row_ptr.size() - 1 */
size_t rbegin = 0;
for (const auto &batch : gmat.p_fmat->GetBatches<SparsePage>()) {
const xgboost::Entry* data_ptr = batch.data.HostVector().data();
const std::vector<bst_row_t>& offset_vec = batch.offset.HostVector();
const size_t batch_size = batch.Size();
CHECK_LT(batch_size, offset_vec.size());
for (size_t rid = 0; rid < batch_size; ++rid) {
const size_t size = offset_vec[rid + 1] - offset_vec[rid];
SparsePage::Inst inst = {data_ptr + offset_vec[rid], size};
const size_t ibegin = gmat.row_ptr[rbegin + rid];
const size_t iend = gmat.row_ptr[rbegin + rid + 1];
CHECK_EQ(ibegin + inst.size(), iend);
size_t j = 0;
size_t fid = 0;
for (size_t i = ibegin; i < iend; ++i, ++j) {
fid = inst[j].index;
auto get_bin_idx = [&](auto bin_id, auto rid, bst_feature_t fid) {
// T* begin = &local_index[feature_offsets_[fid]];
const size_t idx = feature_offsets_[fid];
/* rbegin allows to store indexes from specific SparsePage batch */
local_index[idx + rbegin + rid] = index[i];
missing_flags_[idx + rbegin + rid] = false;
}
}
rbegin += batch.Size();
}
local_index[idx + rid] = bin_id;
missing_flags_[idx + rid] = false;
};
this->SetIndexSparse(page, index, gmat, nfeature, get_bin_idx);
}
}
template<typename T>
inline void SetIndex(uint32_t const* index, const GHistIndexMatrix& gmat,
const size_t nfeature) {
std::vector<size_t> num_nonzeros;
num_nonzeros.resize(nfeature);
std::fill(num_nonzeros.begin(), num_nonzeros.end(), 0);
T* local_index = reinterpret_cast<T*>(&index_[0]);
size_t rbegin = 0;
for (const auto &batch : gmat.p_fmat->GetBatches<SparsePage>()) {
// FIXME(jiamingy): In the future we might want to simply use binary search to simplify
// this and remove the dependency on SparsePage. This way we can have quantilized
// matrix for host similar to `DeviceQuantileDMatrix`.
template <typename T, typename BinFn>
void SetIndexSparse(SparsePage const& batch, T* index, const GHistIndexMatrix& gmat,
const size_t nfeature, BinFn&& assign_bin) {
std::vector<size_t> num_nonzeros(nfeature, 0ul);
const xgboost::Entry* data_ptr = batch.data.HostVector().data();
const std::vector<bst_row_t>& offset_vec = batch.offset.HostVector();
const size_t batch_size = batch.Size();
auto rbegin = 0;
const size_t batch_size = gmat.Size();
CHECK_LT(batch_size, offset_vec.size());
for (size_t rid = 0; rid < batch_size; ++rid) {
const size_t ibegin = gmat.row_ptr[rbegin + rid];
const size_t iend = gmat.row_ptr[rbegin + rid + 1];
size_t fid = 0;
const size_t size = offset_vec[rid + 1] - offset_vec[rid];
SparsePage::Inst inst = {data_ptr + offset_vec[rid], size};
@ -337,36 +312,110 @@ class ColumnMatrix {
size_t j = 0;
for (size_t i = ibegin; i < iend; ++i, ++j) {
const uint32_t bin_id = index[i];
auto fid = inst[j].index;
assign_bin(bin_id, rid, fid);
}
}
}
fid = inst[j].index;
template <typename T>
inline void SetIndex(SparsePage const& page, uint32_t const* index, const GHistIndexMatrix& gmat,
const size_t nfeature) {
T* local_index = reinterpret_cast<T*>(&index_[0]);
std::vector<size_t> num_nonzeros;
num_nonzeros.resize(nfeature);
std::fill(num_nonzeros.begin(), num_nonzeros.end(), 0);
auto get_bin_idx = [&](auto bin_id, auto rid, bst_feature_t fid) {
if (type_[fid] == kDenseColumn) {
T* begin = &local_index[feature_offsets_[fid]];
begin[rid + rbegin] = bin_id - index_base_[fid];
missing_flags_[feature_offsets_[fid] + rid + rbegin] = false;
begin[rid] = bin_id - index_base_[fid];
missing_flags_[feature_offsets_[fid] + rid] = false;
} else {
T* begin = &local_index[feature_offsets_[fid]];
begin[num_nonzeros[fid]] = bin_id - index_base_[fid];
row_ind_[feature_offsets_[fid] + num_nonzeros[fid]] = rid + rbegin;
row_ind_[feature_offsets_[fid] + num_nonzeros[fid]] = rid;
++num_nonzeros[fid];
}
}
}
rbegin += batch.Size();
}
}
BinTypeSize GetTypeSize() const {
return bins_type_size_;
};
this->SetIndexSparse(page, index, gmat, nfeature, get_bin_idx);
}
BinTypeSize GetTypeSize() const { return bins_type_size_; }
// This is just an utility function
bool NoMissingValues(const size_t n_elements,
const size_t n_row, const size_t n_features) {
bool NoMissingValues(const size_t n_elements, const size_t n_row, const size_t n_features) {
return n_elements == n_features * n_row;
}
// And this returns part of state
bool AnyMissing() const {
return any_missing_;
bool AnyMissing() const { return any_missing_; }
// IO procedures for external memory.
bool Read(dmlc::SeekStream* fi, uint32_t const* index_base) {
fi->Read(&index_);
fi->Read(&feature_counts_);
#if !DMLC_LITTLE_ENDIAN
// s390x
std::vector<std::underlying_type<ColumnType>::type> int_types;
fi->Read(&int_types);
type_.resize(int_types.size());
std::transform(
int_types.begin(), int_types.end(), type_.begin(),
[](std::underlying_type<ColumnType>::type i) { return static_cast<ColumnType>(i); });
#else
fi->Read(&type_);
#endif // !DMLC_LITTLE_ENDIAN
fi->Read(&row_ind_);
fi->Read(&feature_offsets_);
index_base_ = index_base;
#if !DMLC_LITTLE_ENDIAN
std::underlying_type<BinTypeSize>::type v;
fi->Read(&v);
bins_type_size_ = static_cast<BinTypeSize>(v);
#else
fi->Read(&bins_type_size_);
#endif
fi->Read(&any_missing_);
return true;
}
size_t Write(dmlc::Stream* fo) const {
size_t bytes{0};
auto write_vec = [&](auto const& vec) {
fo->Write(vec);
bytes += vec.size() * sizeof(typename std::remove_reference_t<decltype(vec)>::value_type) +
sizeof(uint64_t);
};
write_vec(index_);
write_vec(feature_counts_);
#if !DMLC_LITTLE_ENDIAN
// s390x
std::vector<std::underlying_type<ColumnType>::type> int_types(type_.size());
std::transform(type_.begin(), type_.end(), int_types.begin(), [](ColumnType t) {
return static_cast<std::underlying_type<ColumnType>::type>(t);
});
write_vec(int_types);
#else
write_vec(type_);
#endif // !DMLC_LITTLE_ENDIAN
write_vec(row_ind_);
write_vec(feature_offsets_);
#if !DMLC_LITTLE_ENDIAN
auto v = static_cast<std::underlying_type<BinTypeSize>::type>(bins_type_size_);
fo->Write(v);
#else
fo->Write(bins_type_size_);
#endif // DMLC_LITTLE_ENDIAN
bytes += sizeof(bins_type_size_);
fo->Write(any_missing_);
bytes += sizeof(any_missing_);
return bytes;
}
private:

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2019-2021 XGBoost contributors
* Copyright 2019-2022 XGBoost contributors
*/
#include <memory>
#include <utility>
@ -12,6 +12,13 @@ namespace data {
void EllpackPageSource::Fetch() {
dh::safe_cuda(cudaSetDevice(param_.gpu_id));
if (!this->ReadCache()) {
if (count_ != 0 && !sync_) {
// source is initialized to be the 0th page during construction, so when count_ is 0
// there's no need to increment the source.
++(*source_);
}
// This is not read from cache so we still need it to be synced with sparse page source.
CHECK_EQ(count_, source_->Iter());
auto const &csr = source_->Page();
this->page_.reset(new EllpackPage{});
auto *impl = this->page_->Impl();

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2019-2021 by XGBoost Contributors
* Copyright 2019-2022 by XGBoost Contributors
*/
#ifndef XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
@ -25,15 +25,17 @@ class EllpackPageSource : public PageSourceIncMixIn<EllpackPage> {
std::unique_ptr<common::HistogramCuts> cuts_;
public:
EllpackPageSource(
float missing, int nthreads, bst_feature_t n_features, size_t n_batches,
EllpackPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches,
std::shared_ptr<Cache> cache, BatchParam param,
std::unique_ptr<common::HistogramCuts> cuts, bool is_dense,
size_t row_stride, common::Span<FeatureType const> feature_types,
std::unique_ptr<common::HistogramCuts> cuts, bool is_dense, size_t row_stride,
common::Span<FeatureType const> feature_types,
std::shared_ptr<SparsePageSource> source)
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache),
is_dense_{is_dense}, row_stride_{row_stride}, param_{std::move(param)},
feature_types_{feature_types}, cuts_{std::move(cuts)} {
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, false),
is_dense_{is_dense},
row_stride_{row_stride},
param_{std::move(param)},
feature_types_{feature_types},
cuts_{std::move(cuts)} {
this->source_ = source;
this->Fetch();
}

View File

@ -144,7 +144,6 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, double sparse_thresh,
hit_count.resize(nbins, 0);
hit_count_tloc_.resize(n_threads * nbins, 0);
this->p_fmat = p_fmat;
size_t new_size = 1;
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
new_size += batch.Size();
@ -164,6 +163,16 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, double sparse_thresh,
prev_sum = row_ptr[rbegin + batch.Size()];
rbegin += batch.Size();
}
this->columns_ = std::make_unique<common::ColumnMatrix>();
// hessian is empty when hist tree method is used or when dataset is empty
if (hess.empty() && !std::isnan(sparse_thresh)) {
// hist
CHECK(!sorted_sketch);
for (auto const &page : p_fmat->GetBatches<SparsePage>()) {
this->columns_->Init(page, *this, sparse_thresh, n_threads);
}
}
}
void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType const> ft,
@ -187,6 +196,10 @@ void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType co
size_t prev_sum = 0;
this->PushBatch(batch, ft, rbegin, prev_sum, nbins, n_threads);
this->columns_ = std::make_unique<common::ColumnMatrix>();
if (!std::isnan(sparse_thresh)) {
this->columns_->Init(batch, *this, sparse_thresh, n_threads);
}
}
void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
@ -205,4 +218,17 @@ void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
index.Resize((sizeof(uint32_t)) * n_index);
}
}
common::ColumnMatrix const &GHistIndexMatrix::Transpose() const {
CHECK(columns_);
return *columns_;
}
bool GHistIndexMatrix::ReadColumnPage(dmlc::SeekStream *fi) {
return this->columns_->Read(fi, this->cut.Ptrs().data());
}
size_t GHistIndexMatrix::WriteColumnPage(dmlc::Stream *fo) const {
return this->columns_->Write(fo);
}
} // namespace xgboost

View File

@ -40,7 +40,6 @@ class GHistIndexMatrix {
std::vector<size_t> hit_count;
/*! \brief The corresponding cuts */
common::HistogramCuts cut;
DMatrix* p_fmat;
/*! \brief max_bin for each feature. */
size_t max_num_bins;
/*! \brief base row index for current page (used by external memory) */
@ -119,8 +118,12 @@ class GHistIndexMatrix {
return row_ptr.empty() ? 0 : row_ptr.size() - 1;
}
bool ReadColumnPage(dmlc::SeekStream* fi);
size_t WriteColumnPage(dmlc::Stream* fo) const;
common::ColumnMatrix const& Transpose() const;
private:
// unused at the moment: https://github.com/dmlc/xgboost/pull/7531
std::unique_ptr<common::ColumnMatrix> columns_;
std::vector<size_t> hit_count_tloc_;
bool isDense_;

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2021 XGBoost contributors
* Copyright 2021-2022 XGBoost contributors
*/
#include "sparse_page_writer.h"
#include "gradient_index.h"
@ -7,7 +7,6 @@
namespace xgboost {
namespace data {
class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
public:
bool Read(GHistIndexMatrix* page, dmlc::SeekStream* fi) override {
@ -50,6 +49,8 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
if (is_dense) {
page->index.SetBinOffset(page->cut.Ptrs());
}
page->ReadColumnPage(fi);
return true;
}
@ -81,6 +82,8 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
bytes += sizeof(page.base_rowid);
fo->Write(page.IsDense());
bytes += sizeof(page.IsDense());
bytes += page.WriteColumnPage(fo);
return bytes;
}
};

View File

@ -7,11 +7,18 @@ namespace xgboost {
namespace data {
void GradientIndexPageSource::Fetch() {
if (!this->ReadCache()) {
if (count_ != 0 && !sync_) {
// source is initialized to be the 0th page during construction, so when count_ is 0
// there's no need to increment the source.
++(*source_);
}
// This is not read from cache so we still need it to be synced with sparse page source.
CHECK_EQ(count_, source_->Iter());
auto const& csr = source_->Page();
this->page_.reset(new GHistIndexMatrix());
CHECK_NE(cuts_.Values().size(), 0);
this->page_->Init(*csr, feature_types_, cuts_, max_bin_per_feat_, is_dense_,
sparse_thresh_, nthreads_);
this->page_->Init(*csr, feature_types_, cuts_, max_bin_per_feat_, is_dense_, sparse_thresh_,
nthreads_);
this->WriteCache();
}
}

View File

@ -22,13 +22,14 @@ class GradientIndexPageSource : public PageSourceIncMixIn<GHistIndexMatrix> {
public:
GradientIndexPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches,
std::shared_ptr<Cache> cache, BatchParam param,
common::HistogramCuts cuts, bool is_dense, int32_t max_bin_per_feat,
common::HistogramCuts cuts, bool is_dense,
common::Span<FeatureType const> feature_types,
std::shared_ptr<SparsePageSource> source)
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache),
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache,
std::isnan(param.sparse_thresh)),
cuts_{std::move(cuts)},
is_dense_{is_dense},
max_bin_per_feat_{max_bin_per_feat},
max_bin_per_feat_{param.max_bin},
feature_types_{feature_types},
sparse_thresh_{param.sparse_thresh} {
this->source_ = source;

View File

@ -159,21 +159,6 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam &param) {
CHECK_GE(param.max_bin, 2);
if (param.hess.empty() && !param.regen) {
// hist method doesn't support full external memory implementation, so we concatenate
// all index here.
if (!ghist_index_page_ || (param != batch_param_ && param != BatchParam{})) {
this->InitializeSparsePage();
ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin, param.sparse_thresh,
param.regen, ctx_.Threads()});
this->InitializeSparsePage();
batch_param_ = param;
}
auto begin_iter = BatchIterator<GHistIndexMatrix>(
new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_index_page_));
return BatchSet<GHistIndexMatrix>(begin_iter);
}
auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
this->InitializeSparsePage();
if (!cache_info_.at(id)->written || RegenGHist(batch_param_, param)) {
@ -190,10 +175,9 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam
ghist_index_source_.reset();
CHECK_NE(cuts.Values().size(), 0);
auto ft = this->info_.feature_types.ConstHostSpan();
ghist_index_source_.reset(
new GradientIndexPageSource(this->missing_, this->ctx_.Threads(), this->Info().num_col_,
this->n_batches_, cache_info_.at(id), param, std::move(cuts),
this->IsDense(), param.max_bin, ft, sparse_page_source_));
ghist_index_source_.reset(new GradientIndexPageSource(
this->missing_, this->ctx_.Threads(), this->Info().num_col_, this->n_batches_,
cache_info_.at(id), param, std::move(cuts), this->IsDense(), ft, sparse_page_source_));
} else {
CHECK(ghist_index_source_);
ghist_index_source_->Reset();

View File

@ -11,6 +11,9 @@ namespace data {
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& param) {
CHECK_GE(param.gpu_id, 0);
CHECK_GE(param.max_bin, 2);
if (!(batch_param_ != BatchParam{})) {
CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
}
auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
size_t row_stride = 0;
this->InitializeSparsePage();

View File

@ -23,6 +23,7 @@
#include "proxy_dmatrix.h"
#include "../common/common.h"
#include "../common/timer.h"
namespace xgboost {
namespace data {
@ -118,26 +119,30 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
size_t n_prefetch_batches = std::min(kPreFetch, n_batches_);
CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
size_t fetch_it = count_;
for (size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
fetch_it %= n_batches_; // ring
if (ring_->at(fetch_it).valid()) { continue; }
if (ring_->at(fetch_it).valid()) {
continue;
}
auto const *self = this; // make sure it's const
CHECK_LT(fetch_it, cache_info_->offset.size());
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self]() {
common::Timer timer;
timer.Start();
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
auto n = self->cache_info_->ShardName();
size_t offset = self->cache_info_->offset.at(fetch_it);
std::unique_ptr<dmlc::SeekStream> fi{
dmlc::SeekStream::CreateForRead(n.c_str())};
std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(n.c_str())};
fi->Seek(offset);
CHECK_EQ(fi->Tell(), offset);
auto page = std::make_shared<S>();
CHECK(fmt->Read(page.get(), fi.get()));
LOG(INFO) << "Read a page in " << timer.ElapsedSeconds() << " seconds.";
return page;
});
}
CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(),
[](auto const &f) { return f.valid(); }),
CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }),
n_prefetch_batches)
<< "Sparse DMatrix assumes forward iteration.";
page_ = (*ring_)[count_].get();
@ -146,12 +151,18 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
void WriteCache() {
CHECK(!cache_info_->written);
common::Timer timer;
timer.Start();
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
if (!fo_) {
auto n = cache_info_->ShardName();
fo_.reset(dmlc::Stream::Create(n.c_str(), "w"));
}
auto bytes = fmt->Write(*page_, fo_.get());
timer.Stop();
LOG(INFO) << static_cast<double>(bytes) / 1024.0 / 1024.0 << " MB written in "
<< timer.ElapsedSeconds() << " seconds.";
cache_info_->offset.push_back(bytes);
}
@ -280,15 +291,24 @@ template <typename S>
class PageSourceIncMixIn : public SparsePageSourceImpl<S> {
protected:
std::shared_ptr<SparsePageSource> source_;
using Super = SparsePageSourceImpl<S>;
// synchronize the row page, `hist` and `gpu_hist` don't need the original sparse page
// so we avoid fetching it.
bool sync_{true};
public:
using SparsePageSourceImpl<S>::SparsePageSourceImpl;
PageSourceIncMixIn(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
std::shared_ptr<Cache> cache, bool sync)
: Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {}
PageSourceIncMixIn& operator++() final {
TryLockGuard guard{this->single_threaded_};
if (sync_) {
++(*source_);
}
++this->count_;
this->at_end_ = source_->AtEnd();
this->at_end_ = this->count_ == this->n_batches_;
if (this->at_end_) {
this->cache_info_->Commit();
@ -299,7 +319,10 @@ class PageSourceIncMixIn : public SparsePageSourceImpl<S> {
} else {
this->Fetch();
}
if (sync_) {
CHECK_EQ(source_->Iter(), this->count_);
}
return *this;
}
};
@ -318,12 +341,9 @@ class CSCPageSource : public PageSourceIncMixIn<CSCPage> {
}
public:
CSCPageSource(
float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
std::shared_ptr<Cache> cache,
std::shared_ptr<SparsePageSource> source)
: PageSourceIncMixIn(missing, nthreads, n_features,
n_batches, cache) {
CSCPageSource(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
std::shared_ptr<Cache> cache, std::shared_ptr<SparsePageSource> source)
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, true) {
this->source_ = source;
this->Fetch();
}
@ -349,7 +369,7 @@ class SortedCSCPageSource : public PageSourceIncMixIn<SortedCSCPage> {
SortedCSCPageSource(float missing, int nthreads, bst_feature_t n_features,
uint32_t n_batches, std::shared_ptr<Cache> cache,
std::shared_ptr<SparsePageSource> source)
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache) {
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, true) {
this->source_ = source;
this->Fetch();
}

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2021 by XGBoost Contributors
* Copyright 2021-2022 by XGBoost Contributors
*/
#ifndef XGBOOST_TREE_HIST_HISTOGRAM_H_
#define XGBOOST_TREE_HIST_HISTOGRAM_H_
@ -8,10 +8,11 @@
#include <limits>
#include <vector>
#include "rabit/rabit.h"
#include "xgboost/tree_model.h"
#include "../../common/hist_util.h"
#include "../../data/gradient_index.h"
#include "expand_entry.h"
#include "rabit/rabit.h"
#include "xgboost/tree_model.h"
namespace xgboost {
namespace tree {
@ -323,6 +324,25 @@ template <typename GradientSumT, typename ExpandEntry> class HistogramBuilder {
(*sync_count) = std::max(1, n_left);
}
};
// Construct a work space for building histogram. Eventually we should move this
// function into histogram builder once hist tree method supports external memory.
template <typename Partitioner>
common::BlockedSpace2d ConstructHistSpace(Partitioner const &partitioners,
std::vector<CPUExpandEntry> const &nodes_to_build) {
std::vector<size_t> partition_size(nodes_to_build.size(), 0);
for (auto const &partition : partitioners) {
size_t k = 0;
for (auto node : nodes_to_build) {
auto n_rows_in_node = partition.Partitions()[node.nid].Size();
partition_size[k] = std::max(partition_size[k], n_rows_in_node);
k++;
}
}
common::BlockedSpace2d space{
nodes_to_build.size(), [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; }, 256};
return space;
}
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TREE_HIST_HISTOGRAM_H_

10
src/tree/hist/param.cc Normal file
View File

@ -0,0 +1,10 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "param.h"
namespace xgboost {
namespace tree {
DMLC_REGISTER_PARAMETER(CPUHistMakerTrainParam);
} // namespace tree
} // namespace xgboost

View File

@ -94,7 +94,7 @@ class GloablApproxBuilder {
rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double *>(&root_sum), 2);
std::vector<CPUExpandEntry> nodes{best};
size_t i = 0;
auto space = this->ConstructHistSpace(nodes);
auto space = ConstructHistSpace(partitioner_, nodes);
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess))) {
histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes,
{}, gpair);
@ -123,25 +123,6 @@ class GloablApproxBuilder {
monitor_->Stop(__func__);
}
// Construct a work space for building histogram. Eventually we should move this
// function into histogram builder once hist tree method supports external memory.
common::BlockedSpace2d ConstructHistSpace(
std::vector<CPUExpandEntry> const &nodes_to_build) const {
std::vector<size_t> partition_size(nodes_to_build.size(), 0);
for (auto const &partition : partitioner_) {
size_t k = 0;
for (auto node : nodes_to_build) {
auto n_rows_in_node = partition.Partitions()[node.nid].Size();
partition_size[k] = std::max(partition_size[k], n_rows_in_node);
k++;
}
}
common::BlockedSpace2d space{nodes_to_build.size(),
[&](size_t nidx_in_set) { return partition_size[nidx_in_set]; },
256};
return space;
}
void BuildHistogram(DMatrix *p_fmat, RegTree *p_tree,
std::vector<CPUExpandEntry> const &valid_candidates,
std::vector<GradientPair> const &gpair, common::Span<float> hess) {
@ -164,7 +145,7 @@ class GloablApproxBuilder {
}
size_t i = 0;
auto space = this->ConstructHistSpace(nodes_to_build);
auto space = ConstructHistSpace(partitioner_, nodes_to_build);
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess))) {
histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
nodes_to_build, nodes_to_sub, gpair);
@ -191,7 +172,7 @@ class GloablApproxBuilder {
Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy));
auto &tree = *p_tree;
driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)});
bst_node_t num_leaves = 1;
bst_node_t num_leaves{1};
auto expand_set = driver.Pop();
/**
@ -223,10 +204,10 @@ class GloablApproxBuilder {
}
monitor_->Start("UpdatePosition");
size_t i = 0;
size_t page_id = 0;
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess))) {
partitioner_.at(i).UpdatePosition(ctx_, page, applied, p_tree);
i++;
partitioner_.at(page_id).UpdatePosition(ctx_, page, applied, p_tree);
page_id++;
}
monitor_->Stop("UpdatePosition");
@ -288,9 +269,9 @@ class GlobalApproxUpdater : public TreeUpdater {
out["hist_param"] = ToJson(hist_param_);
}
void InitData(TrainParam const &param, HostDeviceVector<GradientPair> *gpair,
void InitData(TrainParam const &param, HostDeviceVector<GradientPair> const *gpair,
std::vector<GradientPair> *sampled) {
auto const &h_gpair = gpair->HostVector();
auto const &h_gpair = gpair->ConstHostVector();
sampled->resize(h_gpair.size());
std::copy(h_gpair.cbegin(), h_gpair.cend(), sampled->begin());
auto &rnd = common::GlobalRandom();

View File

@ -4,80 +4,39 @@
* \brief use quantized feature values to construct a tree
* \author Philip Cho, Tianqi Checn, Egor Smirnov
*/
#include <dmlc/timer.h>
#include "./updater_quantile_hist.h"
#include <rabit/rabit.h>
#include <algorithm>
#include <cmath>
#include <iomanip>
#include <memory>
#include <numeric>
#include <queue>
#include <string>
#include <utility>
#include <vector>
#include "../common/column_matrix.h"
#include "../common/hist_util.h"
#include "../common/random.h"
#include "../common/threading_utils.h"
#include "constraints.h"
#include "hist/evaluate_splits.h"
#include "param.h"
#include "xgboost/logging.h"
#include "xgboost/tree_updater.h"
#include "constraints.h"
#include "param.h"
#include "./updater_quantile_hist.h"
#include "./split_evaluator.h"
#include "../common/random.h"
#include "../common/hist_util.h"
#include "../common/row_set.h"
#include "../common/column_matrix.h"
#include "../common/threading_utils.h"
namespace xgboost {
namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_quantile_hist);
DMLC_REGISTER_PARAMETER(CPUHistMakerTrainParam);
void QuantileHistMaker::Configure(const Args &args) {
// initialize pruner
if (!pruner_) {
pruner_.reset(TreeUpdater::Create("prune", ctx_, task_));
}
pruner_->Configure(args);
param_.UpdateAllowUnknown(args);
hist_maker_param_.UpdateAllowUnknown(args);
}
template <typename GradientSumT>
void QuantileHistMaker::SetBuilder(const size_t n_trees,
std::unique_ptr<Builder<GradientSumT>>* builder, DMatrix* dmat) {
builder->reset(
new Builder<GradientSumT>(n_trees, param_, std::move(pruner_), dmat, task_, ctx_));
}
template<typename GradientSumT>
void QuantileHistMaker::CallBuilderUpdate(const std::unique_ptr<Builder<GradientSumT>>& builder,
HostDeviceVector<GradientPair> *gpair,
DMatrix *dmat,
GHistIndexMatrix const& gmat,
void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair, DMatrix *dmat,
const std::vector<RegTree *> &trees) {
for (auto tree : trees) {
builder->Update(gmat, column_matrix_, gpair, dmat, tree);
}
}
void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
DMatrix *dmat,
const std::vector<RegTree *> &trees) {
auto it = dmat->GetBatches<GHistIndexMatrix>(HistBatch(param_)).begin();
auto p_gmat = it.Page();
if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) {
updater_monitor_.Start("GmatInitialization");
column_matrix_.Init(*p_gmat, param_.sparse_threshold, ctx_->Threads());
updater_monitor_.Stop("GmatInitialization");
// A proper solution is puting cut matrix in DMatrix, see:
// https://github.com/dmlc/xgboost/issues/5143
is_gmat_initialized_ = true;
}
// rescale learning rate according to size of trees
float lr = param_.learning_rate;
param_.learning_rate = lr / trees.size();
@ -86,19 +45,23 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
const size_t n_trees = trees.size();
if (hist_maker_param_.single_precision_histogram) {
if (!float_builder_) {
this->SetBuilder(n_trees, &float_builder_, dmat);
float_builder_.reset(new Builder<float>(n_trees, param_, dmat, task_, ctx_));
}
CallBuilderUpdate(float_builder_, gpair, dmat, *p_gmat, trees);
} else {
if (!double_builder_) {
SetBuilder(n_trees, &double_builder_, dmat);
double_builder_.reset(new Builder<double>(n_trees, param_, dmat, task_, ctx_));
}
}
for (auto p_tree : trees) {
if (hist_maker_param_.single_precision_histogram) {
this->float_builder_->UpdateTree(gpair, dmat, p_tree);
} else {
this->double_builder_->UpdateTree(gpair, dmat, p_tree);
}
CallBuilderUpdate(double_builder_, gpair, dmat, *p_gmat, trees);
}
param_.learning_rate = lr;
p_last_dmat_ = dmat;
}
bool QuantileHistMaker::UpdatePredictionCache(const DMatrix *data,
@ -113,23 +76,18 @@ bool QuantileHistMaker::UpdatePredictionCache(const DMatrix *data,
}
template <typename GradientSumT>
template <bool any_missing>
void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
DMatrix *p_fmat, RegTree *p_tree, const std::vector<GradientPair> &gpair_h,
int *num_leaves, std::vector<CPUExpandEntry> *expand) {
CPUExpandEntry QuantileHistMaker::Builder<GradientSumT>::InitRoot(
DMatrix *p_fmat, RegTree *p_tree, const std::vector<GradientPair> &gpair_h) {
CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0), 0.0f);
nodes_for_explicit_hist_build_.clear();
nodes_for_subtraction_trick_.clear();
nodes_for_explicit_hist_build_.push_back(node);
auto const& row_set_collection = partitioner_.front().Partitions();
size_t page_id = 0;
for (auto const& gidx :
p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
this->histogram_builder_->BuildHist(
page_id, gidx, p_tree, row_set_collection,
nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, gpair_h);
auto space = ConstructHistSpace(partitioner_, {node});
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
std::vector<CPUExpandEntry> nodes_to_build{node};
std::vector<CPUExpandEntry> nodes_to_sub;
this->histogram_builder_->BuildHist(page_id, space, gidx, p_tree,
partitioner_.at(page_id).Partitions(), nodes_to_build,
nodes_to_sub, gpair_h);
++page_id;
}
@ -165,148 +123,118 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
(*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight);
std::vector<CPUExpandEntry> entries{node};
builder_monitor_->Start("EvaluateSplits");
monitor_->Start("EvaluateSplits");
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft,
*p_tree, &entries);
evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft, *p_tree, &entries);
break;
}
builder_monitor_->Stop("EvaluateSplits");
monitor_->Stop("EvaluateSplits");
node = entries.front();
}
expand->push_back(node);
++(*num_leaves);
return node;
}
template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::AddSplitsToTree(
const std::vector<CPUExpandEntry>& expand,
RegTree *p_tree,
int *num_leaves,
std::vector<CPUExpandEntry>* nodes_for_apply_split) {
for (auto const& entry : expand) {
if (entry.IsValid(param_, *num_leaves)) {
nodes_for_apply_split->push_back(entry);
evaluator_->ApplyTreeSplit(entry, p_tree);
(*num_leaves)++;
}
void QuantileHistMaker::Builder<GradientSumT>::BuildHistogram(
DMatrix *p_fmat, RegTree *p_tree, std::vector<CPUExpandEntry> const &valid_candidates,
std::vector<GradientPair> const &gpair) {
std::vector<CPUExpandEntry> nodes_to_build(valid_candidates.size());
std::vector<CPUExpandEntry> nodes_to_sub(valid_candidates.size());
size_t n_idx = 0;
for (auto const &c : valid_candidates) {
auto left_nidx = (*p_tree)[c.nid].LeftChild();
auto right_nidx = (*p_tree)[c.nid].RightChild();
auto fewer_right = c.split.right_sum.GetHess() < c.split.left_sum.GetHess();
auto build_nidx = left_nidx;
auto subtract_nidx = right_nidx;
if (fewer_right) {
std::swap(build_nidx, subtract_nidx);
}
nodes_to_build[n_idx] = CPUExpandEntry{build_nidx, p_tree->GetDepth(build_nidx), {}};
nodes_to_sub[n_idx] = CPUExpandEntry{subtract_nidx, p_tree->GetDepth(subtract_nidx), {}};
n_idx++;
}
// Split nodes to 2 sets depending on amount of rows in each node
// Histograms for small nodes will be built explicitly
// Histograms for big nodes will be built by 'Subtraction Trick'
// Exception: in distributed setting, we always build the histogram for the left child node
// and use 'Subtraction Trick' to built the histogram for the right child node.
// This ensures that the workers operate on the same set of tree nodes.
template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::SplitSiblings(
const std::vector<CPUExpandEntry> &nodes_for_apply_split,
std::vector<CPUExpandEntry> *nodes_to_evaluate, RegTree *p_tree) {
builder_monitor_->Start("SplitSiblings");
auto const& row_set_collection = this->partitioner_.front().Partitions();
for (auto const& entry : nodes_for_apply_split) {
int nid = entry.nid;
const int cleft = (*p_tree)[nid].LeftChild();
const int cright = (*p_tree)[nid].RightChild();
const CPUExpandEntry left_node = CPUExpandEntry(cleft, p_tree->GetDepth(cleft), 0.0);
const CPUExpandEntry right_node = CPUExpandEntry(cright, p_tree->GetDepth(cright), 0.0);
nodes_to_evaluate->push_back(left_node);
nodes_to_evaluate->push_back(right_node);
if (row_set_collection[cleft].Size() < row_set_collection[cright].Size()) {
nodes_for_explicit_hist_build_.push_back(left_node);
nodes_for_subtraction_trick_.push_back(right_node);
} else {
nodes_for_explicit_hist_build_.push_back(right_node);
nodes_for_subtraction_trick_.push_back(left_node);
size_t page_id{0};
auto space = ConstructHistSpace(partitioner_, nodes_to_build);
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
histogram_builder_->BuildHist(page_id, space, gidx, p_tree,
partitioner_.at(page_id).Partitions(), nodes_to_build,
nodes_to_sub, gpair);
++page_id;
}
}
CHECK_EQ(nodes_for_subtraction_trick_.size(), nodes_for_explicit_hist_build_.size());
builder_monitor_->Stop("SplitSiblings");
}
template <typename GradientSumT>
template <bool any_missing>
void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
const GHistIndexMatrix& gmat,
const common::ColumnMatrix& column_matrix,
DMatrix* p_fmat,
RegTree* p_tree,
const std::vector<GradientPair>& gpair_h) {
builder_monitor_->Start("ExpandTree");
int num_leaves = 0;
DMatrix *p_fmat, RegTree *p_tree, const std::vector<GradientPair> &gpair_h) {
monitor_->Start(__func__);
Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy));
std::vector<CPUExpandEntry> expand;
InitRoot<any_missing>(p_fmat, p_tree, gpair_h, &num_leaves, &expand);
driver.Push(expand[0]);
driver.Push(this->InitRoot(p_fmat, p_tree, gpair_h));
bst_node_t num_leaves{1};
auto expand_set = driver.Pop();
int32_t depth = 0;
while (!driver.IsEmpty()) {
expand = driver.Pop();
depth = expand[0].depth + 1;
std::vector<CPUExpandEntry> nodes_for_apply_split;
std::vector<CPUExpandEntry> nodes_to_evaluate;
nodes_for_explicit_hist_build_.clear();
nodes_for_subtraction_trick_.clear();
AddSplitsToTree(expand, p_tree, &num_leaves, &nodes_for_apply_split);
if (nodes_for_apply_split.size() != 0) {
HistRowPartitioner &partitioner = this->partitioner_.front();
if (gmat.cut.HasCategorical()) {
partitioner.UpdatePosition<any_missing, true>(this->ctx_, gmat, column_matrix,
nodes_for_apply_split, p_tree);
} else {
partitioner.UpdatePosition<any_missing, false>(this->ctx_, gmat, column_matrix,
nodes_for_apply_split, p_tree);
while (!expand_set.empty()) {
// candidates that can be further splited.
std::vector<CPUExpandEntry> valid_candidates;
// candidaates that can be applied.
std::vector<CPUExpandEntry> applied;
int32_t depth = expand_set.front().depth + 1;
for (auto const& candidate : expand_set) {
if (!candidate.IsValid(param_, num_leaves)) {
continue;
}
evaluator_->ApplyTreeSplit(candidate, p_tree);
applied.push_back(candidate);
num_leaves++;
if (CPUExpandEntry::ChildIsValid(param_, depth, num_leaves)) {
valid_candidates.emplace_back(candidate);
}
}
SplitSiblings(nodes_for_apply_split, &nodes_to_evaluate, p_tree);
if (param_.max_depth == 0 || depth < param_.max_depth) {
size_t i = 0;
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
this->histogram_builder_->BuildHist(i, gidx, p_tree, partitioner_.front().Partitions(),
nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_, gpair_h);
++i;
}
} else {
int starting_index = std::numeric_limits<int>::max();
int sync_count = 0;
this->histogram_builder_->AddHistRows(
&starting_index, &sync_count, nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_, p_tree);
monitor_->Start("UpdatePosition");
size_t page_id{0};
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
partitioner_.at(page_id).UpdatePosition(ctx_, page, applied, p_tree);
++page_id;
}
monitor_->Stop("UpdatePosition");
builder_monitor_->Start("EvaluateSplits");
std::vector<CPUExpandEntry> best_splits;
if (!valid_candidates.empty()) {
this->BuildHistogram(p_fmat, p_tree, valid_candidates, gpair_h);
auto const &tree = *p_tree;
for (auto const &candidate : valid_candidates) {
int left_child_nidx = tree[candidate.nid].LeftChild();
int right_child_nidx = tree[candidate.nid].RightChild();
CPUExpandEntry l_best{left_child_nidx, depth, 0.0};
CPUExpandEntry r_best{right_child_nidx, depth, 0.0};
best_splits.push_back(l_best);
best_splits.push_back(r_best);
}
auto const &histograms = histogram_builder_->Histogram();
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
evaluator_->EvaluateSplits(this->histogram_builder_->Histogram(),
gmat.cut, ft, *p_tree, &nodes_to_evaluate);
builder_monitor_->Stop("EvaluateSplits");
for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
evaluator_->EvaluateSplits(histograms, gmat.cut, ft, *p_tree, &best_splits);
break;
}
}
driver.Push(best_splits.begin(), best_splits.end());
expand_set = driver.Pop();
}
for (size_t i = 0; i < nodes_for_apply_split.size(); ++i) {
CPUExpandEntry left_node = nodes_to_evaluate.at(i * 2 + 0);
CPUExpandEntry right_node = nodes_to_evaluate.at(i * 2 + 1);
driver.Push(left_node);
driver.Push(right_node);
}
}
}
builder_monitor_->Stop("ExpandTree");
monitor_->Stop(__func__);
}
template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::Update(
const GHistIndexMatrix &gmat,
const common::ColumnMatrix &column_matrix,
HostDeviceVector<GradientPair> *gpair,
void QuantileHistMaker::Builder<GradientSumT>::UpdateTree(HostDeviceVector<GradientPair> *gpair,
DMatrix *p_fmat, RegTree *p_tree) {
builder_monitor_->Start("Update");
monitor_->Start(__func__);
std::vector<GradientPair> *gpair_ptr = &(gpair->HostVector());
// in case 'num_parallel_trees != 1' no posibility to change initial gpair
@ -315,18 +243,12 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
gpair_local_ = *gpair_ptr;
gpair_ptr = &gpair_local_;
}
p_last_fmat_mutable_ = p_fmat;
this->InitData(gmat, p_fmat, *p_tree, gpair_ptr);
this->InitData(p_fmat, *p_tree, gpair_ptr);
if (column_matrix.AnyMissing()) {
ExpandTree<true>(gmat, column_matrix, p_fmat, p_tree, *gpair_ptr);
} else {
ExpandTree<false>(gmat, column_matrix, p_fmat, p_tree, *gpair_ptr);
}
pruner_->Update(gpair, p_fmat, std::vector<RegTree*>{p_tree});
ExpandTree(p_fmat, p_tree, *gpair_ptr);
builder_monitor_->Stop("Update");
monitor_->Stop(__func__);
}
template <typename GradientSumT>
@ -334,20 +256,20 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
DMatrix const *data, linalg::VectorView<float> out_preds) const {
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
// conjunction with Update().
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_ ||
p_last_fmat_ != p_last_fmat_mutable_) {
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) {
return false;
}
builder_monitor_->Start(__func__);
monitor_->Start(__func__);
CHECK_EQ(out_preds.Size(), data->Info().num_row_);
UpdatePredictionCacheImpl(ctx_, p_last_tree_, partitioner_, *evaluator_, param_, out_preds);
builder_monitor_->Stop(__func__);
monitor_->Stop(__func__);
return true;
}
template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const DMatrix &fmat,
std::vector<GradientPair> *gpair) {
monitor_->Start(__func__);
const auto &info = fmat.Info();
auto& rnd = common::GlobalRandom();
std::vector<GradientPair>& gpair_ref = *gpair;
@ -380,6 +302,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const DMatrix &fmat,
}
exc.Rethrow();
#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG
monitor_->Stop(__func__);
}
template<typename GradientSumT>
size_t QuantileHistMaker::Builder<GradientSumT>::GetNumberOfTrees() {
@ -387,10 +310,9 @@ size_t QuantileHistMaker::Builder<GradientSumT>::GetNumberOfTrees() {
}
template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix &gmat, DMatrix *fmat,
const RegTree &tree,
void QuantileHistMaker::Builder<GradientSumT>::InitData(DMatrix *fmat, const RegTree &tree,
std::vector<GradientPair> *gpair) {
builder_monitor_->Start("InitData");
monitor_->Start(__func__);
const auto& info = fmat->Info();
{
@ -406,18 +328,14 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix &
partitioner_.emplace_back(page.Size(), page.base_rowid, this->ctx_->Threads());
++page_id;
}
histogram_builder_->Reset(n_total_bins, BatchParam{param_.max_bin, param_.sparse_threshold},
ctx_->Threads(), page_id, rabit::IsDistributed());
histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
rabit::IsDistributed());
if (param_.subsample < 1.0f) {
CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
<< "Only uniform sampling is supported, "
<< "gradient-based sampling is only support by GPU Hist.";
builder_monitor_->Start("InitSampling");
InitSampling(*fmat, gpair);
builder_monitor_->Stop("InitSampling");
// We should check that the partitioning was done correctly
// and each row of the dataset fell into exactly one of the categories
}
}
@ -426,7 +344,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix &
evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
param_, info, this->ctx_->Threads(), column_sampler_, task_});
builder_monitor_->Stop("InitData");
monitor_->Stop(__func__);
}
void HistRowPartitioner::FindSplitConditions(const std::vector<CPUExpandEntry> &nodes,
@ -470,21 +388,8 @@ void HistRowPartitioner::AddSplitsToRowSet(const std::vector<CPUExpandEntry> &no
template struct QuantileHistMaker::Builder<float>;
template struct QuantileHistMaker::Builder<double>;
XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker")
.describe("(Deprecated, use grow_quantile_histmaker instead.)"
" Grow tree using quantized histogram.")
.set_body(
[](ObjInfo task) {
LOG(WARNING) << "grow_fast_histmaker is deprecated, "
<< "use grow_quantile_histmaker instead.";
return new QuantileHistMaker(task);
});
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
.describe("Grow tree using quantized histogram.")
.set_body(
[](ObjInfo task) {
return new QuantileHistMaker(task);
});
.set_body([](ObjInfo task) { return new QuantileHistMaker(task); });
} // namespace tree
} // namespace xgboost

View File

@ -7,7 +7,6 @@
#ifndef XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
#define XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
#include <dmlc/timer.h>
#include <rabit/rabit.h>
#include <xgboost/tree_updater.h>
@ -29,7 +28,6 @@
#include "constraints.h"
#include "./param.h"
#include "./driver.h"
#include "./split_evaluator.h"
#include "../common/random.h"
#include "../common/timer.h"
#include "../common/hist_util.h"
@ -194,6 +192,24 @@ class HistRowPartitioner {
AddSplitsToRowSet(nodes, p_tree);
}
void UpdatePosition(GenericParameter const* ctx, GHistIndexMatrix const& page,
std::vector<CPUExpandEntry> const& applied, RegTree const* p_tree) {
auto const& column_matrix = page.Transpose();
if (page.cut.HasCategorical()) {
if (column_matrix.AnyMissing()) {
this->template UpdatePosition<true, true>(ctx, page, column_matrix, applied, p_tree);
} else {
this->template UpdatePosition<false, true>(ctx, page, column_matrix, applied, p_tree);
}
} else {
if (column_matrix.AnyMissing()) {
this->template UpdatePosition<true, false>(ctx, page, column_matrix, applied, p_tree);
} else {
this->template UpdatePosition<false, false>(ctx, page, column_matrix, applied, p_tree);
}
}
}
auto const& Partitions() const { return row_set_collection_; }
size_t Size() const {
return std::distance(row_set_collection_.begin(), row_set_collection_.end());
@ -209,9 +225,7 @@ inline BatchParam HistBatch(TrainParam const& param) {
/*! \brief construct a tree using quantized feature values */
class QuantileHistMaker: public TreeUpdater {
public:
explicit QuantileHistMaker(ObjInfo task) : task_{task} {
updater_monitor_.Init("QuantileHistMaker");
}
explicit QuantileHistMaker(ObjInfo task) : task_{task} {}
void Configure(const Args& args) override;
void Update(HostDeviceVector<GradientPair>* gpair,
@ -256,10 +270,6 @@ class QuantileHistMaker: public TreeUpdater {
CPUHistMakerTrainParam hist_maker_param_;
// training parameter
TrainParam param_;
// column accessor
common::ColumnMatrix column_matrix_;
DMatrix const* p_last_dmat_ {nullptr};
bool is_gmat_initialized_ {false};
// actual builder that runs the algorithm
template<typename GradientSumT>
@ -267,60 +277,40 @@ class QuantileHistMaker: public TreeUpdater {
public:
using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;
// constructor
explicit Builder(const size_t n_trees, const TrainParam& param,
std::unique_ptr<TreeUpdater> pruner, DMatrix const* fmat, ObjInfo task,
GenericParameter const* ctx)
explicit Builder(const size_t n_trees, const TrainParam& param, DMatrix const* fmat,
ObjInfo task, GenericParameter const* ctx)
: n_trees_(n_trees),
param_(param),
pruner_(std::move(pruner)),
p_last_fmat_(fmat),
histogram_builder_{new HistogramBuilder<GradientSumT, CPUExpandEntry>},
task_{task},
ctx_{ctx},
builder_monitor_{std::make_unique<common::Monitor>()} {
builder_monitor_->Init("Quantile::Builder");
monitor_{std::make_unique<common::Monitor>()} {
monitor_->Init("Quantile::Builder");
}
// update one tree, growing
void Update(const GHistIndexMatrix& gmat, const common::ColumnMatrix& column_matrix,
HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree);
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree);
bool UpdatePredictionCache(DMatrix const* data, linalg::VectorView<float> out_preds) const;
protected:
private:
// initialize temp data structure
void InitData(const GHistIndexMatrix& gmat, DMatrix* fmat, const RegTree& tree,
std::vector<GradientPair>* gpair);
void InitData(DMatrix* fmat, const RegTree& tree, std::vector<GradientPair>* gpair);
size_t GetNumberOfTrees();
void InitSampling(const DMatrix& fmat, std::vector<GradientPair>* gpair);
template <bool any_missing>
void InitRoot(DMatrix* p_fmat,
RegTree *p_tree,
const std::vector<GradientPair> &gpair_h,
int *num_leaves, std::vector<CPUExpandEntry> *expand);
// Split nodes to 2 sets depending on amount of rows in each node
// Histograms for small nodes will be built explicitly
// Histograms for big nodes will be built by 'Subtraction Trick'
void SplitSiblings(const std::vector<CPUExpandEntry>& nodes,
std::vector<CPUExpandEntry>* nodes_to_evaluate,
RegTree *p_tree);
void AddSplitsToTree(const std::vector<CPUExpandEntry>& expand,
RegTree *p_tree,
int *num_leaves,
std::vector<CPUExpandEntry>* nodes_for_apply_split);
template <bool any_missing>
void ExpandTree(const GHistIndexMatrix& gmat,
const common::ColumnMatrix& column_matrix,
DMatrix* p_fmat,
RegTree* p_tree,
CPUExpandEntry InitRoot(DMatrix* p_fmat, RegTree* p_tree,
const std::vector<GradientPair>& gpair_h);
// --data fields--
void BuildHistogram(DMatrix* p_fmat, RegTree* p_tree,
std::vector<CPUExpandEntry> const& valid_candidates,
std::vector<GradientPair> const& gpair);
void ExpandTree(DMatrix* p_fmat, RegTree* p_tree, const std::vector<GradientPair>& gpair_h);
private:
const size_t n_trees_;
const TrainParam& param_;
std::shared_ptr<common::ColumnSampler> column_sampler_{
@ -328,48 +318,24 @@ class QuantileHistMaker: public TreeUpdater {
std::vector<GradientPair> gpair_local_;
std::unique_ptr<TreeUpdater> pruner_;
std::unique_ptr<HistEvaluator<GradientSumT, CPUExpandEntry>> evaluator_;
// Right now there's only 1 partitioner in this vector, when external memory is fully
// supported we will have number of partitioners equal to number of pages.
std::vector<HistRowPartitioner> partitioner_;
// back pointers to tree and data matrix
const RegTree* p_last_tree_{nullptr};
DMatrix const* const p_last_fmat_;
DMatrix* p_last_fmat_mutable_;
// key is the node id which should be calculated by Subtraction Trick, value is the node which
// provides the evidence for subtraction
std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
// list of nodes whose histograms would be built explicitly.
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
enum class DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
std::unique_ptr<HistogramBuilder<GradientSumT, CPUExpandEntry>> histogram_builder_;
ObjInfo task_;
// Context for number of threads
GenericParameter const* ctx_;
std::unique_ptr<common::Monitor> builder_monitor_;
std::unique_ptr<common::Monitor> monitor_;
};
common::Monitor updater_monitor_;
template<typename GradientSumT>
void SetBuilder(const size_t n_trees, std::unique_ptr<Builder<GradientSumT>>*, DMatrix *dmat);
template<typename GradientSumT>
void CallBuilderUpdate(const std::unique_ptr<Builder<GradientSumT>>& builder,
HostDeviceVector<GradientPair> *gpair,
DMatrix *dmat,
GHistIndexMatrix const& gmat,
const std::vector<RegTree *> &trees);
protected:
std::unique_ptr<Builder<float>> float_builder_;
std::unique_ptr<Builder<double>> double_builder_;
std::unique_ptr<TreeUpdater> pruner_;
ObjInfo task_;
};
} // namespace tree

View File

@ -21,7 +21,9 @@ TEST(DenseColumn, Test) {
GHistIndexMatrix gmat{dmat.get(), max_num_bin, sparse_thresh, false,
common::OmpGetNumThreads(0)};
ColumnMatrix column_matrix;
column_matrix.Init(gmat, 0.2, common::OmpGetNumThreads(0));
for (auto const& page : dmat->GetBatches<SparsePage>()) {
column_matrix.Init(page, gmat, sparse_thresh, common::OmpGetNumThreads(0));
}
for (auto i = 0ull; i < dmat->Info().num_row_; i++) {
for (auto j = 0ull; j < dmat->Info().num_col_; j++) {
@ -68,7 +70,9 @@ TEST(SparseColumn, Test) {
auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix();
GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.5f, false, common::OmpGetNumThreads(0)};
ColumnMatrix column_matrix;
column_matrix.Init(gmat, 0.5, common::OmpGetNumThreads(0));
for (auto const& page : dmat->GetBatches<SparsePage>()) {
column_matrix.Init(page, gmat, 1.0, common::OmpGetNumThreads(0));
}
switch (column_matrix.GetTypeSize()) {
case kUint8BinsTypeSize: {
auto col = column_matrix.GetColumn<uint8_t, true>(0);
@ -106,9 +110,11 @@ TEST(DenseColumnWithMissing, Test) {
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
for (int32_t max_num_bin : max_num_bins) {
auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix();
GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.2, false, common::OmpGetNumThreads(0)};
GHistIndexMatrix gmat(dmat.get(), max_num_bin, 0.2, false, common::OmpGetNumThreads(0));
ColumnMatrix column_matrix;
column_matrix.Init(gmat, 0.2, common::OmpGetNumThreads(0));
for (auto const& page : dmat->GetBatches<SparsePage>()) {
column_matrix.Init(page, gmat, 0.2, common::OmpGetNumThreads(0));
}
switch (column_matrix.GetTypeSize()) {
case kUint8BinsTypeSize: {
auto col = column_matrix.GetColumn<uint8_t, true>(0);

View File

@ -3,6 +3,7 @@
*/
#include <gtest/gtest.h>
#include "../../../src/common/column_matrix.h"
#include "../../../src/data/gradient_index.h"
#include "../../../src/data/sparse_page_source.h"
#include "../helpers.h"
@ -15,33 +16,31 @@ TEST(GHistIndexPageRawFormat, IO) {
auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix();
dmlc::TemporaryDirectory tmpdir;
std::string path = tmpdir.path + "/ghistindex.page";
auto batch = BatchParam{256, 0.5};
{
std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")};
for (auto const &index :
m->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 256})) {
for (auto const &index : m->GetBatches<GHistIndexMatrix>(batch)) {
format->Write(index, fo.get());
}
}
GHistIndexMatrix page;
std::unique_ptr<dmlc::SeekStream> fi{
dmlc::SeekStream::CreateForRead(path.c_str())};
std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(path.c_str())};
format->Read(&page, fi.get());
for (auto const &gidx :
m->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 256})) {
for (auto const &gidx : m->GetBatches<GHistIndexMatrix>(batch)) {
auto const &loaded = gidx;
ASSERT_EQ(loaded.cut.Ptrs(), page.cut.Ptrs());
ASSERT_EQ(loaded.cut.MinValues(), page.cut.MinValues());
ASSERT_EQ(loaded.cut.Values(), page.cut.Values());
ASSERT_EQ(loaded.base_rowid, page.base_rowid);
ASSERT_EQ(loaded.IsDense(), page.IsDense());
ASSERT_TRUE(std::equal(loaded.index.begin(), loaded.index.end(),
page.index.begin()));
ASSERT_TRUE(std::equal(loaded.index.Offset(),
loaded.index.Offset() + loaded.index.OffsetSize(),
ASSERT_TRUE(std::equal(loaded.index.begin(), loaded.index.end(), page.index.begin()));
ASSERT_TRUE(std::equal(loaded.index.Offset(), loaded.index.Offset() + loaded.index.OffsetSize(),
page.index.Offset()));
ASSERT_EQ(loaded.Transpose().GetTypeSize(), loaded.Transpose().GetTypeSize());
}
}
} // namespace data

View File

@ -446,6 +446,12 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx) {
TEST(CPUHistogram, ExternalMemory) {
int32_t constexpr kBins = 256;
TestHistogramExternalMemory(BatchParam{kBins, common::Span<float>{}, false}, true);
float sparse_thresh{0.5};
TestHistogramExternalMemory({kBins, sparse_thresh}, false);
sparse_thresh = std::numeric_limits<float>::quiet_NaN();
TestHistogramExternalMemory({kBins, sparse_thresh}, false);
}
} // namespace tree
} // namespace xgboost

View File

@ -18,138 +18,6 @@
namespace xgboost {
namespace tree {
class QuantileHistMock : public QuantileHistMaker {
static double constexpr kEps = 1e-6;
template <typename GradientSumT>
struct BuilderMock : public QuantileHistMaker::Builder<GradientSumT> {
using RealImpl = QuantileHistMaker::Builder<GradientSumT>;
BuilderMock(const TrainParam &param, std::unique_ptr<TreeUpdater> pruner,
DMatrix const *fmat, GenericParameter const* ctx)
: RealImpl(1, param, std::move(pruner), fmat, ObjInfo{ObjInfo::kRegression}, ctx) {}
public:
void TestInitData(const GHistIndexMatrix& gmat,
std::vector<GradientPair>* gpair,
DMatrix* p_fmat,
const RegTree& tree) {
RealImpl::InitData(gmat, p_fmat, tree, gpair);
/* The creation of HistCutMatrix and GHistIndexMatrix are not technically
* part of QuantileHist updater logic, but we include it here because
* QuantileHist updater object currently stores GHistIndexMatrix
* internally. According to https://github.com/dmlc/xgboost/pull/3803,
* we should eventually move GHistIndexMatrix out of the QuantileHist
* updater. */
const size_t num_row = p_fmat->Info().num_row_;
const size_t num_col = p_fmat->Info().num_col_;
/* Validate HistCutMatrix */
ASSERT_EQ(gmat.cut.Ptrs().size(), num_col + 1);
for (size_t fid = 0; fid < num_col; ++fid) {
const size_t ibegin = gmat.cut.Ptrs()[fid];
const size_t iend = gmat.cut.Ptrs()[fid + 1];
// Ordered, but empty feature is allowed.
ASSERT_LE(ibegin, iend);
for (size_t i = ibegin; i < iend - 1; ++i) {
// Quantile points must be sorted in ascending order
// No duplicates allowed
ASSERT_LT(gmat.cut.Values()[i], gmat.cut.Values()[i + 1])
<< "ibegin: " << ibegin << ", "
<< "iend: " << iend;
}
}
/* Validate GHistIndexMatrix */
ASSERT_EQ(gmat.row_ptr.size(), num_row + 1);
ASSERT_LT(*std::max_element(gmat.index.begin(), gmat.index.end()),
gmat.cut.Ptrs().back());
for (const auto& batch : p_fmat->GetBatches<xgboost::SparsePage>()) {
auto page = batch.GetView();
for (size_t i = 0; i < batch.Size(); ++i) {
const size_t rid = batch.base_rowid + i;
ASSERT_LT(rid, num_row);
const size_t gmat_row_offset = gmat.row_ptr[rid];
ASSERT_LT(gmat_row_offset, gmat.index.Size());
SparsePage::Inst inst = page[i];
ASSERT_EQ(gmat.row_ptr[rid] + inst.size(), gmat.row_ptr[rid + 1]);
for (size_t j = 0; j < inst.size(); ++j) {
// Each entry of GHistIndexMatrix represents a bin ID
const size_t bin_id = gmat.index[gmat_row_offset + j];
const size_t fid = inst[j].index;
// The bin ID must correspond to correct feature
ASSERT_GE(bin_id, gmat.cut.Ptrs()[fid]);
ASSERT_LT(bin_id, gmat.cut.Ptrs()[fid + 1]);
// The bin ID must correspond to a region between two
// suitable quantile points
ASSERT_LT(inst[j].fvalue, gmat.cut.Values()[bin_id]);
if (bin_id > gmat.cut.Ptrs()[fid]) {
ASSERT_GE(inst[j].fvalue, gmat.cut.Values()[bin_id - 1]);
} else {
ASSERT_GE(inst[j].fvalue, gmat.cut.MinValues()[fid]);
}
}
}
}
}
};
int static constexpr kNRows = 8, kNCols = 16;
std::shared_ptr<xgboost::DMatrix> dmat_;
GenericParameter ctx_;
const std::vector<std::pair<std::string, std::string> > cfg_;
std::shared_ptr<BuilderMock<float> > float_builder_;
std::shared_ptr<BuilderMock<double> > double_builder_;
public:
explicit QuantileHistMock(
const std::vector<std::pair<std::string, std::string> >& args,
const bool single_precision_histogram = false, bool batch = true) :
QuantileHistMaker{ObjInfo{ObjInfo::kRegression}}, cfg_{args} {
QuantileHistMaker::Configure(args);
dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
ctx_.UpdateAllowUnknown(Args{});
if (single_precision_histogram) {
float_builder_.reset(new BuilderMock<float>(param_, std::move(pruner_), dmat_.get(), &ctx_));
} else {
double_builder_.reset(
new BuilderMock<double>(param_, std::move(pruner_), dmat_.get(), &ctx_));
}
}
~QuantileHistMock() override = default;
static size_t GetNumColumns() { return kNCols; }
void TestInitData() {
int32_t constexpr kMaxBins = 4;
GHistIndexMatrix gmat{dmat_.get(), kMaxBins, 0.0f, false, common::OmpGetNumThreads(0)};
RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);
std::vector<GradientPair> gpair =
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
if (double_builder_) {
double_builder_->TestInitData(gmat, &gpair, dmat_.get(), tree);
} else {
float_builder_->TestInitData(gmat, &gpair, dmat_.get(), tree);
}
}
};
TEST(QuantileHist, InitData) {
std::vector<std::pair<std::string, std::string>> cfg
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
QuantileHistMock maker(cfg);
maker.TestInitData();
const bool single_precision_histogram = true;
QuantileHistMock maker_float(cfg, single_precision_histogram);
maker_float.TestInitData();
}
TEST(QuantileHist, Partitioner) {
size_t n_samples = 1024, n_features = 1, base_rowid = 0;
GenericParameter ctx;
@ -163,45 +31,44 @@ TEST(QuantileHist, Partitioner) {
auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};
auto grad = GenerateRandomGradients(n_samples);
std::vector<float> hess(grad.Size());
std::transform(grad.HostVector().cbegin(), grad.HostVector().cend(), hess.begin(),
[](auto gpair) { return gpair.GetHess(); });
auto cuts = common::SketchOnDMatrix(Xy.get(), 64, ctx.Threads());
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>({64, 0.5})) {
for (auto const& page : Xy->GetBatches<SparsePage>()) {
GHistIndexMatrix gmat;
gmat.Init(page, {}, cuts, 64, false, 0.5, ctx.Threads());
bst_feature_t const split_ind = 0;
common::ColumnMatrix column_indices;
column_indices.Init(page, 0.5, ctx.Threads());
column_indices.Init(page, gmat, 0.5, ctx.Threads());
{
auto min_value = page.cut.MinValues()[split_ind];
auto min_value = gmat.cut.MinValues()[split_ind];
RegTree tree;
HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()};
GetSplit(&tree, min_value, &candidates);
partitioner.UpdatePosition<false, true>(&ctx, page, column_indices, candidates, &tree);
partitioner.UpdatePosition<false, true>(&ctx, gmat, column_indices, candidates, &tree);
ASSERT_EQ(partitioner.Size(), 3);
ASSERT_EQ(partitioner[1].Size(), 0);
ASSERT_EQ(partitioner[2].Size(), n_samples);
}
{
HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()};
auto ptr = page.cut.Ptrs()[split_ind + 1];
float split_value = page.cut.Values().at(ptr / 2);
auto ptr = gmat.cut.Ptrs()[split_ind + 1];
float split_value = gmat.cut.Values().at(ptr / 2);
RegTree tree;
GetSplit(&tree, split_value, &candidates);
auto left_nidx = tree[RegTree::kRoot].LeftChild();
partitioner.UpdatePosition<false, true>(&ctx, page, column_indices, candidates, &tree);
partitioner.UpdatePosition<false, true>(&ctx, gmat, column_indices, candidates, &tree);
auto elem = partitioner[left_nidx];
ASSERT_LT(elem.Size(), n_samples);
ASSERT_GT(elem.Size(), 1);
for (auto it = elem.begin; it != elem.end; ++it) {
auto value = page.cut.Values().at(page.index[*it]);
auto value = gmat.cut.Values().at(gmat.index[*it]);
ASSERT_LE(value, split_value);
}
auto right_nidx = tree[RegTree::kRoot].RightChild();
elem = partitioner[right_nidx];
for (auto it = elem.begin; it != elem.end; ++it) {
auto value = page.cut.Values().at(page.index[*it]);
auto value = gmat.cut.Values().at(gmat.index[*it]);
ASSERT_GT(value, split_value) << *it;
}
}

View File

@ -1,7 +1,7 @@
import xgboost as xgb
from xgboost.data import SingleBatchInternalIter as SingleBatch
import numpy as np
from testing import IteratorForTest
from testing import IteratorForTest, non_increasing
from typing import Tuple, List
import pytest
from hypothesis import given, strategies, settings
@ -108,7 +108,7 @@ def run_data_iterator(
evals_result=results_from_it,
verbose_eval=False,
)
it_predt = from_it.predict(Xy)
assert non_increasing(results_from_it["Train"]["rmse"])
X, y = it.as_arrays()
Xy = xgb.DMatrix(X, y)
@ -125,13 +125,13 @@ def run_data_iterator(
verbose_eval=False,
)
arr_predt = from_arrays.predict(Xy)
assert non_increasing(results_from_arrays["Train"]["rmse"])
if tree_method != "gpu_hist":
rtol = 1e-1 # flaky
else:
# Model can be sensitive to quantiles, use 1e-2 to relax the test.
np.testing.assert_allclose(it_predt, arr_predt, rtol=1e-2)
rtol = 1e-6
rtol = 1e-2
# CPU sketching is more memory efficient but less consistent due to small chunks
it_predt = from_it.predict(Xy)
arr_predt = from_arrays.predict(Xy)
np.testing.assert_allclose(it_predt, arr_predt, rtol=rtol)
np.testing.assert_allclose(
results_from_it["Train"]["rmse"],