External memory support for hist (#7531)

* Generate the column matrix from the gradient index (GHistIndexMatrix).
* Avoid synchronization with the sparse page once the cache is written.
* Cleanups: remove member variables/functions and change the update routine to match approx and gpu_hist.
* Remove the pruner.
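To make the user-facing effect concrete, here is a minimal, illustrative sketch of the external-memory workflow that now feeds the hist tree method page by page. It is not part of the patch: the Iterator class, the on-disk file layout, and the parameter values are assumptions for demonstration, following the xgboost.DataIter interface used by the external-memory demo.

import os
import tempfile

import numpy as np
import xgboost


class Iterator(xgboost.DataIter):
    """Stream pre-generated (X, y) chunks from disk, one batch at a time."""

    def __init__(self, file_paths):
        # file_paths is a hypothetical list of (X.npy, y.npy) path pairs.
        self._file_paths = file_paths
        self._it = 0
        # XGBoost writes the external-memory cache files next to this prefix.
        super().__init__(cache_prefix=os.path.join(tempfile.gettempdir(), "cache"))

    def next(self, input_data):
        if self._it == len(self._file_paths):
            return 0  # signal that iteration is done
        X_path, y_path = self._file_paths[self._it]
        # input_data is a callback provided by XGBoost that accepts one batch.
        input_data(data=np.load(X_path), label=np.load(y_path))
        self._it += 1
        return 1

    def reset(self):
        self._it = 0


def train_from_disk(file_paths):
    it = Iterator(file_paths)
    Xy = xgboost.DMatrix(it, missing=np.nan, enable_categorical=False)
    # With this change the hist updater iterates over the cached gradient-index
    # pages instead of concatenating them into one in-core index.
    return xgboost.train({"tree_method": "hist"}, Xy, num_boost_round=10)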
Jiaming Yuan 2022-03-22 00:13:20 +08:00 committed by GitHub
parent cd55823112
commit 4d81c741e9
25 changed files with 563 additions and 686 deletions


@@ -48,17 +48,18 @@
 #include "../src/predictor/cpu_predictor.cc"
 // trees
+#include "../src/tree/constraints.cc"
+#include "../src/tree/hist/param.cc"
 #include "../src/tree/param.cc"
 #include "../src/tree/tree_model.cc"
 #include "../src/tree/tree_updater.cc"
+#include "../src/tree/updater_approx.cc"
 #include "../src/tree/updater_colmaker.cc"
-#include "../src/tree/updater_quantile_hist.cc"
+#include "../src/tree/updater_histmaker.cc"
 #include "../src/tree/updater_prune.cc"
+#include "../src/tree/updater_quantile_hist.cc"
 #include "../src/tree/updater_refresh.cc"
 #include "../src/tree/updater_sync.cc"
-#include "../src/tree/updater_histmaker.cc"
-#include "../src/tree/updater_approx.cc"
-#include "../src/tree/constraints.cc"
 // linear
 #include "../src/linear/linear_updater.cc"


@@ -7,6 +7,9 @@ instead of Quantile DMatrix. The feature is not ready for production use yet.
 .. versionadded:: 1.5.0
+
+See :doc:`the tutorial </tutorials/external_memory>` for more details.
+
 """
 import os
 import xgboost
@@ -77,9 +80,14 @@ def main(tmpdir: str) -> xgboost.Booster:
     missing = np.NaN
     Xy = xgboost.DMatrix(it, missing=missing, enable_categorical=False)
-    # Other tree methods including ``hist`` and ``gpu_hist`` also work, but has some
-    # caveats.  This is still an experimental feature.
-    booster = xgboost.train({"tree_method": "approx"}, Xy, evals=[(Xy, "Train")])
+    # Other tree methods including ``hist`` and ``gpu_hist`` also work, see tutorial in
+    # doc for details.
+    booster = xgboost.train(
+        {"tree_method": "approx", "max_depth": 2},
+        Xy,
+        evals=[(Xy, "Train")],
+        num_boost_round=10,
+    )
     return booster


@@ -27,7 +27,7 @@ def main(args):
     dtrain.set_info(feature_weights=fw)

     bst = xgboost.train({'tree_method': 'hist',
-                         'colsample_bynode': 0.5},
+                         'colsample_bynode': 0.2},
                         dtrain, num_boost_round=10,
                         evals=[(dtrain, 'd')])
     feature_map = bst.get_fscore()


@@ -127,9 +127,12 @@ the tree method still concatenate all the chunks into 1 final histogram index du
 performance reason, but in compressed format.  So its scalability has an upper bound but
 still has lower memory cost in general.

-********
-CPU Hist
-********
+***********
+CPU Version
+***********

-It's limited by the same factor of GPU Hist, except that gradient based sampling is not
-yet supported on CPU.
+For the CPU histogram-based tree methods (``approx``, ``hist``) it is recommended to use
+``grow_policy=depthwise`` for performance reasons.  Iterating over data batches is slow,
+and with the ``depthwise`` policy XGBoost can build an entire layer of tree nodes in a few
+iterations, while with ``lossguide`` XGBoost needs to iterate over the data set for each
+tree node.
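To illustrate the paragraph added above, a small sketch of the only user-side knob involved. It assumes ``Xy`` is a DMatrix backed by a data iterator (as in the demo sketch earlier) and is illustrative rather than part of the patch.

import xgboost


def train_external_memory(Xy: xgboost.DMatrix) -> xgboost.Booster:
    # Recommended for external memory: a few passes over the cached pages can
    # build an entire level of the tree.
    params = {"tree_method": "hist", "grow_policy": "depthwise", "max_depth": 6}
    # With "lossguide" the updater has to iterate over the data set for each
    # expanded node, so the cost of reading the external cache dominates:
    # params = {"tree_method": "hist", "grow_policy": "lossguide", "max_leaves": 64}
    return xgboost.train(params, Xy, num_boost_round=10)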


@@ -1,5 +1,5 @@
 /*!
- * Copyright 2017 by Contributors
+ * Copyright 2017-2022 by Contributors
  * \file column_matrix.h
  * \brief Utility for fast column-wise access
  * \author Philip Cho
@@ -8,21 +8,22 @@
 #ifndef XGBOOST_COMMON_COLUMN_MATRIX_H_
 #define XGBOOST_COMMON_COLUMN_MATRIX_H_

+#include <dmlc/endian.h>
+
+#include <algorithm>
 #include <limits>
-#include <vector>
 #include <memory>
-#include "hist_util.h"
+#include <vector>

 #include "../data/gradient_index.h"
+#include "hist_util.h"

 namespace xgboost {
 namespace common {

 class ColumnMatrix;

 /*! \brief column type */
-enum ColumnType {
-  kDenseColumn,
-  kSparseColumn
-};
+enum ColumnType : uint8_t { kDenseColumn, kSparseColumn };

 /*! \brief a column storage, to be used with ApplySplit. Note that each
     bin id is stored as index[i] + index_base.
@@ -34,9 +35,7 @@ class Column {
   static constexpr int32_t kMissingId = -1;

   Column(ColumnType type, common::Span<const BinIdxType> index, const uint32_t index_base)
-      : type_(type),
-        index_(index),
-        index_base_(index_base) {}
+      : type_(type), index_(index), index_base_(index_base) {}

   virtual ~Column() = default;

@@ -67,10 +66,9 @@ class Column {
 template <typename BinIdxType>
 class SparseColumn : public Column<BinIdxType> {
  public:
-  SparseColumn(ColumnType type, common::Span<const BinIdxType> index,
-               uint32_t index_base, common::Span<const size_t> row_ind)
-      : Column<BinIdxType>(type, index, index_base),
-        row_ind_(row_ind) {}
+  SparseColumn(ColumnType type, common::Span<const BinIdxType> index, uint32_t index_base,
+               common::Span<const size_t> row_ind)
+      : Column<BinIdxType>(type, index, index_base), row_ind_(row_ind) {}

   const size_t* GetRowData() const { return row_ind_.data(); }

@@ -98,9 +96,7 @@ class SparseColumn: public Column<BinIdxType> {
     return p - row_data;
   }

-  size_t GetRowIdx(size_t idx) const {
-    return row_ind_.data()[idx];
-  }
+  size_t GetRowIdx(size_t idx) const { return row_ind_.data()[idx]; }

  private:
  /* indexes of rows */
@@ -110,9 +106,8 @@ class SparseColumn: public Column<BinIdxType> {
 template <typename BinIdxType, bool any_missing>
 class DenseColumn : public Column<BinIdxType> {
  public:
-  DenseColumn(ColumnType type, common::Span<const BinIdxType> index,
-              uint32_t index_base, const std::vector<bool>& missing_flags,
-              size_t feature_offset)
+  DenseColumn(ColumnType type, common::Span<const BinIdxType> index, uint32_t index_base,
+              const std::vector<bool>& missing_flags, size_t feature_offset)
       : Column<BinIdxType>(type, index, index_base),
         missing_flags_(missing_flags),
         feature_offset_(feature_offset) {}
@@ -126,9 +121,7 @@ class DenseColumn: public Column<BinIdxType> {
     }
   }

-  size_t GetInitialState(const size_t first_row_id) const {
-    return 0;
-  }
+  size_t GetInitialState(const size_t first_row_id) const { return 0; }

  private:
  /* flags for missing values in dense columns */
@@ -141,28 +134,26 @@ class DenseColumn: public Column<BinIdxType> {
 class ColumnMatrix {
  public:
   // get number of features
-  inline bst_uint GetNumFeature() const {
-    return static_cast<bst_uint>(type_.size());
-  }
+  bst_feature_t GetNumFeature() const { return static_cast<bst_feature_t>(type_.size()); }

   // construct column matrix from GHistIndexMatrix
-  inline void Init(const GHistIndexMatrix& gmat, double sparse_threshold, int32_t n_threads) {
-    const int32_t nfeature = static_cast<int32_t>(gmat.cut.Ptrs().size() - 1);
+  inline void Init(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold,
+                   int32_t n_threads) {
+    auto const nfeature = static_cast<bst_feature_t>(gmat.cut.Ptrs().size() - 1);
     const size_t nrow = gmat.row_ptr.size() - 1;
     // identify type of each column
     feature_counts_.resize(nfeature);
     type_.resize(nfeature);
     std::fill(feature_counts_.begin(), feature_counts_.end(), 0);
     uint32_t max_val = std::numeric_limits<uint32_t>::max();
-    for (int32_t fid = 0; fid < nfeature; ++fid) {
+    for (bst_feature_t fid = 0; fid < nfeature; ++fid) {
       CHECK_LE(gmat.cut.Ptrs()[fid + 1] - gmat.cut.Ptrs()[fid], max_val);
     }

     bool all_dense = gmat.IsDense();
     gmat.GetFeatureCounts(&feature_counts_[0]);
     // classify features
-    for (int32_t fid = 0; fid < nfeature; ++fid) {
-      if (static_cast<double>(feature_counts_[fid])
-                 < sparse_threshold * nrow) {
+    for (bst_feature_t fid = 0; fid < nfeature; ++fid) {
+      if (static_cast<double>(feature_counts_[fid]) < sparse_threshold * nrow) {
         type_[fid] = kSparseColumn;
         all_dense = false;
       } else {
@@ -175,7 +166,7 @@ class ColumnMatrix {
     feature_offsets_.resize(nfeature + 1);
     size_t accum_index_ = 0;
     feature_offsets_[0] = accum_index_;
-    for (int32_t fid = 1; fid < nfeature + 1; ++fid) {
+    for (bst_feature_t fid = 1; fid < nfeature + 1; ++fid) {
       if (type_[fid - 1] == kDenseColumn) {
         accum_index_ += static_cast<size_t>(nrow);
       } else {
@@ -197,6 +188,7 @@ class ColumnMatrix {
     const bool noMissingValues = NoMissingValues(gmat.row_ptr[nrow], nrow, nfeature);
     any_missing_ = !noMissingValues;

+    missing_flags_.clear();
     if (noMissingValues) {
       missing_flags_.resize(feature_offsets_[nfeature], false);
     } else {
@@ -207,26 +199,26 @@ class ColumnMatrix {
     if (all_dense) {
       BinTypeSize gmat_bin_size = gmat.index.GetBinTypeSize();
       if (gmat_bin_size == kUint8BinsTypeSize) {
-        SetIndexAllDense(gmat.index.data<uint8_t>(), gmat, nrow, nfeature, noMissingValues,
+        SetIndexAllDense(page, gmat.index.data<uint8_t>(), gmat, nrow, nfeature, noMissingValues,
                          n_threads);
       } else if (gmat_bin_size == kUint16BinsTypeSize) {
-        SetIndexAllDense(gmat.index.data<uint16_t>(), gmat, nrow, nfeature, noMissingValues,
+        SetIndexAllDense(page, gmat.index.data<uint16_t>(), gmat, nrow, nfeature, noMissingValues,
                          n_threads);
       } else {
         CHECK_EQ(gmat_bin_size, kUint32BinsTypeSize);
-        SetIndexAllDense(gmat.index.data<uint32_t>(), gmat, nrow, nfeature, noMissingValues,
+        SetIndexAllDense(page, gmat.index.data<uint32_t>(), gmat, nrow, nfeature, noMissingValues,
                          n_threads);
       }
       /* For sparse DMatrix gmat.index.getBinTypeSize() returns always kUint32BinsTypeSize
          but for ColumnMatrix we still have a chance to reduce the memory consumption */
     } else {
       if (bins_type_size_ == kUint8BinsTypeSize) {
-        SetIndex<uint8_t>(gmat.index.data<uint32_t>(), gmat, nfeature);
+        SetIndex<uint8_t>(page, gmat.index.data<uint32_t>(), gmat, nfeature);
       } else if (bins_type_size_ == kUint16BinsTypeSize) {
-        SetIndex<uint16_t>(gmat.index.data<uint32_t>(), gmat, nfeature);
+        SetIndex<uint16_t>(page, gmat.index.data<uint32_t>(), gmat, nfeature);
       } else {
         CHECK_EQ(bins_type_size_, kUint32BinsTypeSize);
-        SetIndex<uint32_t>(gmat.index.data<uint32_t>(), gmat, nfeature);
+        SetIndex<uint32_t>(page, gmat.index.data<uint32_t>(), gmat, nfeature);
       }
     }
   }
@@ -250,8 +242,8 @@ class ColumnMatrix {
     const size_t feature_offset = feature_offsets_[fid];  // to get right place for certain feature
     const size_t column_size = feature_offsets_[fid + 1] - feature_offset;
-    common::Span<const BinIdxType> bin_index = { reinterpret_cast<const BinIdxType*>(
-                                                     &index_[feature_offset * bins_type_size_]),
-                                                 column_size};
+    common::Span<const BinIdxType> bin_index = {
+        reinterpret_cast<const BinIdxType*>(&index_[feature_offset * bins_type_size_]),
+        column_size};
     std::unique_ptr<const Column<BinIdxType> > res;
     if (type_[fid] == ColumnType::kDenseColumn) {
@@ -266,8 +258,8 @@ class ColumnMatrix {
   }

   template <typename T>
-  inline void SetIndexAllDense(T const* index, const GHistIndexMatrix& gmat, const size_t nrow,
-                               const size_t nfeature, const bool noMissingValues,
+  inline void SetIndexAllDense(SparsePage const& page, T const* index, const GHistIndexMatrix& gmat,
+                               const size_t nrow, const size_t nfeature, const bool noMissingValues,
                                int32_t n_threads) {
     T* local_index = reinterpret_cast<T*>(&index_[0]);
@@ -285,51 +277,34 @@ class ColumnMatrix {
       });
     } else {
       /* to handle rows in all batches, sum of all batch sizes equal to gmat.row_ptr.size() - 1 */
-      size_t rbegin = 0;
-      for (const auto &batch : gmat.p_fmat->GetBatches<SparsePage>()) {
-        const xgboost::Entry* data_ptr = batch.data.HostVector().data();
-        const std::vector<bst_row_t>& offset_vec = batch.offset.HostVector();
-        const size_t batch_size = batch.Size();
-        CHECK_LT(batch_size, offset_vec.size());
-        for (size_t rid = 0; rid < batch_size; ++rid) {
-          const size_t size = offset_vec[rid + 1] - offset_vec[rid];
-          SparsePage::Inst inst = {data_ptr + offset_vec[rid], size};
-          const size_t ibegin = gmat.row_ptr[rbegin + rid];
-          const size_t iend = gmat.row_ptr[rbegin + rid + 1];
-          CHECK_EQ(ibegin + inst.size(), iend);
-          size_t j = 0;
-          size_t fid = 0;
-          for (size_t i = ibegin; i < iend; ++i, ++j) {
-            fid = inst[j].index;
-            const size_t idx = feature_offsets_[fid];
-            /* rbegin allows to store indexes from specific SparsePage batch */
-            local_index[idx + rbegin + rid] = index[i];
-            missing_flags_[idx + rbegin + rid] = false;
-          }
-        }
-        rbegin += batch.Size();
-      }
+      auto get_bin_idx = [&](auto bin_id, auto rid, bst_feature_t fid) {
+        // T* begin = &local_index[feature_offsets_[fid]];
+        const size_t idx = feature_offsets_[fid];
+        /* rbegin allows to store indexes from specific SparsePage batch */
+        local_index[idx + rid] = bin_id;
+        missing_flags_[idx + rid] = false;
+      };
+      this->SetIndexSparse(page, index, gmat, nfeature, get_bin_idx);
     }
   }

@@ -337,36 +312,110 @@ class ColumnMatrix {
-  template<typename T>
-  inline void SetIndex(uint32_t const* index, const GHistIndexMatrix& gmat,
-                       const size_t nfeature) {
-    std::vector<size_t> num_nonzeros;
-    num_nonzeros.resize(nfeature);
-    std::fill(num_nonzeros.begin(), num_nonzeros.end(), 0);
-
-    T* local_index = reinterpret_cast<T*>(&index_[0]);
-    size_t rbegin = 0;
-    for (const auto &batch : gmat.p_fmat->GetBatches<SparsePage>()) {
-      const xgboost::Entry* data_ptr = batch.data.HostVector().data();
-      const std::vector<bst_row_t>& offset_vec = batch.offset.HostVector();
-      const size_t batch_size = batch.Size();
-      CHECK_LT(batch_size, offset_vec.size());
-      for (size_t rid = 0; rid < batch_size; ++rid) {
-        const size_t ibegin = gmat.row_ptr[rbegin + rid];
-        const size_t iend = gmat.row_ptr[rbegin + rid + 1];
-        size_t fid = 0;
-        const size_t size = offset_vec[rid + 1] - offset_vec[rid];
-        SparsePage::Inst inst = {data_ptr + offset_vec[rid], size};
-        size_t j = 0;
-        for (size_t i = ibegin; i < iend; ++i, ++j) {
-          const uint32_t bin_id = index[i];
-          fid = inst[j].index;
-          if (type_[fid] == kDenseColumn) {
-            T* begin = &local_index[feature_offsets_[fid]];
-            begin[rid + rbegin] = bin_id - index_base_[fid];
-            missing_flags_[feature_offsets_[fid] + rid + rbegin] = false;
-          } else {
-            T* begin = &local_index[feature_offsets_[fid]];
-            begin[num_nonzeros[fid]] = bin_id - index_base_[fid];
-            row_ind_[feature_offsets_[fid] + num_nonzeros[fid]] = rid + rbegin;
-            ++num_nonzeros[fid];
-          }
-        }
-      }
-      rbegin += batch.Size();
-    }
-  }
-
-  BinTypeSize GetTypeSize() const {
-    return bins_type_size_;
-  }
+  // FIXME(jiamingy): In the future we might want to simply use binary search to simplify
+  // this and remove the dependency on SparsePage.  This way we can have quantilized
+  // matrix for host similar to `DeviceQuantileDMatrix`.
+  template <typename T, typename BinFn>
+  void SetIndexSparse(SparsePage const& batch, T* index, const GHistIndexMatrix& gmat,
+                      const size_t nfeature, BinFn&& assign_bin) {
+    std::vector<size_t> num_nonzeros(nfeature, 0ul);
+    const xgboost::Entry* data_ptr = batch.data.HostVector().data();
+    const std::vector<bst_row_t>& offset_vec = batch.offset.HostVector();
+    auto rbegin = 0;
+    const size_t batch_size = gmat.Size();
+    CHECK_LT(batch_size, offset_vec.size());
+
+    for (size_t rid = 0; rid < batch_size; ++rid) {
+      const size_t ibegin = gmat.row_ptr[rbegin + rid];
+      const size_t iend = gmat.row_ptr[rbegin + rid + 1];
+      const size_t size = offset_vec[rid + 1] - offset_vec[rid];
+      SparsePage::Inst inst = {data_ptr + offset_vec[rid], size};
+
+      size_t j = 0;
+      for (size_t i = ibegin; i < iend; ++i, ++j) {
+        const uint32_t bin_id = index[i];
+        auto fid = inst[j].index;
+        assign_bin(bin_id, rid, fid);
+      }
+    }
+  }
+
+  template <typename T>
+  inline void SetIndex(SparsePage const& page, uint32_t const* index, const GHistIndexMatrix& gmat,
+                       const size_t nfeature) {
+    T* local_index = reinterpret_cast<T*>(&index_[0]);
+
+    std::vector<size_t> num_nonzeros;
+    num_nonzeros.resize(nfeature);
+    std::fill(num_nonzeros.begin(), num_nonzeros.end(), 0);
+
+    auto get_bin_idx = [&](auto bin_id, auto rid, bst_feature_t fid) {
+      if (type_[fid] == kDenseColumn) {
+        T* begin = &local_index[feature_offsets_[fid]];
+        begin[rid] = bin_id - index_base_[fid];
+        missing_flags_[feature_offsets_[fid] + rid] = false;
+      } else {
+        T* begin = &local_index[feature_offsets_[fid]];
+        begin[num_nonzeros[fid]] = bin_id - index_base_[fid];
+        row_ind_[feature_offsets_[fid] + num_nonzeros[fid]] = rid;
+        ++num_nonzeros[fid];
+      }
+    };
+    this->SetIndexSparse(page, index, gmat, nfeature, get_bin_idx);
+  }
+
+  BinTypeSize GetTypeSize() const { return bins_type_size_; }

   // This is just an utility function
-  bool NoMissingValues(const size_t n_elements,
-                       const size_t n_row, const size_t n_features) {
+  bool NoMissingValues(const size_t n_elements, const size_t n_row, const size_t n_features) {
     return n_elements == n_features * n_row;
   }

   // And this returns part of state
-  bool AnyMissing() const {
-    return any_missing_;
-  }
+  bool AnyMissing() const { return any_missing_; }
+
+  // IO procedures for external memory.
+  bool Read(dmlc::SeekStream* fi, uint32_t const* index_base) {
+    fi->Read(&index_);
+    fi->Read(&feature_counts_);
+#if !DMLC_LITTLE_ENDIAN
+    // s390x
+    std::vector<std::underlying_type<ColumnType>::type> int_types;
+    fi->Read(&int_types);
+    type_.resize(int_types.size());
+    std::transform(
+        int_types.begin(), int_types.end(), type_.begin(),
+        [](std::underlying_type<ColumnType>::type i) { return static_cast<ColumnType>(i); });
+#else
+    fi->Read(&type_);
+#endif  // !DMLC_LITTLE_ENDIAN
+    fi->Read(&row_ind_);
+    fi->Read(&feature_offsets_);
+    index_base_ = index_base;
+#if !DMLC_LITTLE_ENDIAN
+    std::underlying_type<BinTypeSize>::type v;
+    fi->Read(&v);
+    bins_type_size_ = static_cast<BinTypeSize>(v);
+#else
+    fi->Read(&bins_type_size_);
+#endif
+    fi->Read(&any_missing_);
+    return true;
+  }
+
+  size_t Write(dmlc::Stream* fo) const {
+    size_t bytes{0};
+    auto write_vec = [&](auto const& vec) {
+      fo->Write(vec);
+      bytes += vec.size() * sizeof(typename std::remove_reference_t<decltype(vec)>::value_type) +
+               sizeof(uint64_t);
+    };
+    write_vec(index_);
+    write_vec(feature_counts_);
+#if !DMLC_LITTLE_ENDIAN
+    // s390x
+    std::vector<std::underlying_type<ColumnType>::type> int_types(type_.size());
+    std::transform(type_.begin(), type_.end(), int_types.begin(), [](ColumnType t) {
+      return static_cast<std::underlying_type<ColumnType>::type>(t);
+    });
+    write_vec(int_types);
+#else
+    write_vec(type_);
+#endif  // !DMLC_LITTLE_ENDIAN
+    write_vec(row_ind_);
+    write_vec(feature_offsets_);
+
+#if !DMLC_LITTLE_ENDIAN
+    auto v = static_cast<std::underlying_type<BinTypeSize>::type>(bins_type_size_);
+    fo->Write(v);
+#else
+    fo->Write(bins_type_size_);
+#endif  // DMLC_LITTLE_ENDIAN
+    bytes += sizeof(bins_type_size_);
+    fo->Write(any_missing_);
+    bytes += sizeof(any_missing_);
+    return bytes;
+  }

  private:


@@ -1,5 +1,5 @@
 /*!
- * Copyright 2019-2021 XGBoost contributors
+ * Copyright 2019-2022 XGBoost contributors
  */
 #include <memory>
 #include <utility>
@@ -12,6 +12,13 @@ namespace data {
 void EllpackPageSource::Fetch() {
   dh::safe_cuda(cudaSetDevice(param_.gpu_id));
   if (!this->ReadCache()) {
+    if (count_ != 0 && !sync_) {
+      // source is initialized to be the 0th page during construction, so when count_ is 0
+      // there's no need to increment the source.
+      ++(*source_);
+    }
+    // This is not read from cache so we still need it to be synced with sparse page source.
+    CHECK_EQ(count_, source_->Iter());
     auto const &csr = source_->Page();
     this->page_.reset(new EllpackPage{});
     auto *impl = this->page_->Impl();


@@ -1,5 +1,5 @@
 /*!
- * Copyright 2019-2021 by XGBoost Contributors
+ * Copyright 2019-2022 by XGBoost Contributors
  */
 #ifndef XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
 #define XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
@@ -25,15 +25,17 @@ class EllpackPageSource : public PageSourceIncMixIn<EllpackPage> {
   std::unique_ptr<common::HistogramCuts> cuts_;

  public:
-  EllpackPageSource(
-      float missing, int nthreads, bst_feature_t n_features, size_t n_batches,
-      std::shared_ptr<Cache> cache, BatchParam param,
-      std::unique_ptr<common::HistogramCuts> cuts, bool is_dense,
-      size_t row_stride, common::Span<FeatureType const> feature_types,
-      std::shared_ptr<SparsePageSource> source)
-      : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache),
-        is_dense_{is_dense}, row_stride_{row_stride}, param_{std::move(param)},
-        feature_types_{feature_types}, cuts_{std::move(cuts)} {
+  EllpackPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches,
+                    std::shared_ptr<Cache> cache, BatchParam param,
+                    std::unique_ptr<common::HistogramCuts> cuts, bool is_dense, size_t row_stride,
+                    common::Span<FeatureType const> feature_types,
+                    std::shared_ptr<SparsePageSource> source)
+      : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, false),
+        is_dense_{is_dense},
+        row_stride_{row_stride},
+        param_{std::move(param)},
+        feature_types_{feature_types},
+        cuts_{std::move(cuts)} {
     this->source_ = source;
     this->Fetch();
   }


@@ -144,7 +144,6 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, double sparse_thresh,
   hit_count.resize(nbins, 0);
   hit_count_tloc_.resize(n_threads * nbins, 0);

-  this->p_fmat = p_fmat;
   size_t new_size = 1;
   for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
     new_size += batch.Size();
@@ -164,6 +163,16 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, double sparse_thresh,
     prev_sum = row_ptr[rbegin + batch.Size()];
     rbegin += batch.Size();
   }
+
+  this->columns_ = std::make_unique<common::ColumnMatrix>();
+  // hessian is empty when hist tree method is used or when dataset is empty
+  if (hess.empty() && !std::isnan(sparse_thresh)) {
+    // hist
+    CHECK(!sorted_sketch);
+    for (auto const &page : p_fmat->GetBatches<SparsePage>()) {
+      this->columns_->Init(page, *this, sparse_thresh, n_threads);
+    }
+  }
 }

 void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType const> ft,
@@ -187,6 +196,10 @@ void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType co
   size_t prev_sum = 0;

   this->PushBatch(batch, ft, rbegin, prev_sum, nbins, n_threads);
+  this->columns_ = std::make_unique<common::ColumnMatrix>();
+  if (!std::isnan(sparse_thresh)) {
+    this->columns_->Init(batch, *this, sparse_thresh, n_threads);
+  }
 }

 void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
@@ -205,4 +218,17 @@ void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
     index.Resize((sizeof(uint32_t)) * n_index);
   }
 }
+
+common::ColumnMatrix const &GHistIndexMatrix::Transpose() const {
+  CHECK(columns_);
+  return *columns_;
+}
+
+bool GHistIndexMatrix::ReadColumnPage(dmlc::SeekStream *fi) {
+  return this->columns_->Read(fi, this->cut.Ptrs().data());
+}
+
+size_t GHistIndexMatrix::WriteColumnPage(dmlc::Stream *fo) const {
+  return this->columns_->Write(fo);
+}
 }  // namespace xgboost


@@ -40,7 +40,6 @@ class GHistIndexMatrix {
   std::vector<size_t> hit_count;
   /*! \brief The corresponding cuts */
   common::HistogramCuts cut;
-  DMatrix* p_fmat;
   /*! \brief max_bin for each feature. */
   size_t max_num_bins;
   /*! \brief base row index for current page (used by external memory) */
@@ -119,8 +118,12 @@ class GHistIndexMatrix {
     return row_ptr.empty() ? 0 : row_ptr.size() - 1;
   }

+  bool ReadColumnPage(dmlc::SeekStream* fi);
+  size_t WriteColumnPage(dmlc::Stream* fo) const;
+
+  common::ColumnMatrix const& Transpose() const;
+
  private:
-  // unused at the moment: https://github.com/dmlc/xgboost/pull/7531
   std::unique_ptr<common::ColumnMatrix> columns_;
   std::vector<size_t> hit_count_tloc_;
   bool isDense_;


@@ -1,5 +1,5 @@
 /*!
- * Copyright 2021 XGBoost contributors
+ * Copyright 2021-2022 XGBoost contributors
  */
 #include "sparse_page_writer.h"
 #include "gradient_index.h"
@@ -7,7 +7,6 @@
 namespace xgboost {
 namespace data {

 class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
  public:
   bool Read(GHistIndexMatrix* page, dmlc::SeekStream* fi) override {
@@ -50,6 +49,8 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
     if (is_dense) {
       page->index.SetBinOffset(page->cut.Ptrs());
     }
+    page->ReadColumnPage(fi);
+
     return true;
   }

@@ -81,6 +82,8 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
     bytes += sizeof(page.base_rowid);
     fo->Write(page.IsDense());
     bytes += sizeof(page.IsDense());
+    bytes += page.WriteColumnPage(fo);
+
     return bytes;
   }
 };


@@ -7,11 +7,18 @@ namespace xgboost {
 namespace data {
 void GradientIndexPageSource::Fetch() {
   if (!this->ReadCache()) {
+    if (count_ != 0 && !sync_) {
+      // source is initialized to be the 0th page during construction, so when count_ is 0
+      // there's no need to increment the source.
+      ++(*source_);
+    }
+    // This is not read from cache so we still need it to be synced with sparse page source.
+    CHECK_EQ(count_, source_->Iter());
     auto const& csr = source_->Page();
     this->page_.reset(new GHistIndexMatrix());
     CHECK_NE(cuts_.Values().size(), 0);
-    this->page_->Init(*csr, feature_types_, cuts_, max_bin_per_feat_, is_dense_,
-                      sparse_thresh_, nthreads_);
+    this->page_->Init(*csr, feature_types_, cuts_, max_bin_per_feat_, is_dense_, sparse_thresh_,
+                      nthreads_);
     this->WriteCache();
   }
 }


@@ -22,13 +22,14 @@ class GradientIndexPageSource : public PageSourceIncMixIn<GHistIndexMatrix> {
  public:
   GradientIndexPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches,
                           std::shared_ptr<Cache> cache, BatchParam param,
-                          common::HistogramCuts cuts, bool is_dense, int32_t max_bin_per_feat,
+                          common::HistogramCuts cuts, bool is_dense,
                           common::Span<FeatureType const> feature_types,
                           std::shared_ptr<SparsePageSource> source)
-      : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache),
+      : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache,
+                           std::isnan(param.sparse_thresh)),
         cuts_{std::move(cuts)},
         is_dense_{is_dense},
-        max_bin_per_feat_{max_bin_per_feat},
+        max_bin_per_feat_{param.max_bin},
         feature_types_{feature_types},
         sparse_thresh_{param.sparse_thresh} {
     this->source_ = source;


@@ -159,21 +159,6 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
 BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam &param) {
   CHECK_GE(param.max_bin, 2);
-  if (param.hess.empty() && !param.regen) {
-    // hist method doesn't support full external memory implementation, so we concatenate
-    // all index here.
-    if (!ghist_index_page_ || (param != batch_param_ && param != BatchParam{})) {
-      this->InitializeSparsePage();
-      ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin, param.sparse_thresh,
-                                                   param.regen, ctx_.Threads()});
-      this->InitializeSparsePage();
-      batch_param_ = param;
-    }
-    auto begin_iter = BatchIterator<GHistIndexMatrix>(
-        new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_index_page_));
-    return BatchSet<GHistIndexMatrix>(begin_iter);
-  }
   auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
   this->InitializeSparsePage();
   if (!cache_info_.at(id)->written || RegenGHist(batch_param_, param)) {
@@ -190,10 +175,9 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam
     ghist_index_source_.reset();
     CHECK_NE(cuts.Values().size(), 0);
     auto ft = this->info_.feature_types.ConstHostSpan();
-    ghist_index_source_.reset(
-        new GradientIndexPageSource(this->missing_, this->ctx_.Threads(), this->Info().num_col_,
-                                    this->n_batches_, cache_info_.at(id), param, std::move(cuts),
-                                    this->IsDense(), param.max_bin, ft, sparse_page_source_));
+    ghist_index_source_.reset(new GradientIndexPageSource(
+        this->missing_, this->ctx_.Threads(), this->Info().num_col_, this->n_batches_,
+        cache_info_.at(id), param, std::move(cuts), this->IsDense(), ft, sparse_page_source_));
   } else {
     CHECK(ghist_index_source_);
     ghist_index_source_->Reset();


@@ -11,6 +11,9 @@ namespace data {
 BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& param) {
   CHECK_GE(param.gpu_id, 0);
   CHECK_GE(param.max_bin, 2);
+  if (!(batch_param_ != BatchParam{})) {
+    CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
+  }
   auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
   size_t row_stride = 0;
   this->InitializeSparsePage();


@@ -23,6 +23,7 @@
 #include "proxy_dmatrix.h"
 #include "../common/common.h"
+#include "../common/timer.h"

 namespace xgboost {
 namespace data {
@@ -118,26 +119,30 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
     size_t n_prefetch_batches = std::min(kPreFetch, n_batches_);
     CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
     size_t fetch_it = count_;
     for (size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
       fetch_it %= n_batches_;  // ring
-      if (ring_->at(fetch_it).valid()) { continue; }
+      if (ring_->at(fetch_it).valid()) {
+        continue;
+      }
       auto const *self = this;  // make sure it's const
       CHECK_LT(fetch_it, cache_info_->offset.size());
       ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self]() {
+        common::Timer timer;
+        timer.Start();
         std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
         auto n = self->cache_info_->ShardName();
         size_t offset = self->cache_info_->offset.at(fetch_it);
-        std::unique_ptr<dmlc::SeekStream> fi{
-            dmlc::SeekStream::CreateForRead(n.c_str())};
+        std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(n.c_str())};
         fi->Seek(offset);
         CHECK_EQ(fi->Tell(), offset);
         auto page = std::make_shared<S>();
         CHECK(fmt->Read(page.get(), fi.get()));
+        LOG(INFO) << "Read a page in " << timer.ElapsedSeconds() << " seconds.";
         return page;
       });
     }
-    CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(),
-                           [](auto const &f) { return f.valid(); }),
+
+    CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }),
              n_prefetch_batches)
         << "Sparse DMatrix assumes forward iteration.";
     page_ = (*ring_)[count_].get();
@@ -146,12 +151,18 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {

   void WriteCache() {
     CHECK(!cache_info_->written);
+    common::Timer timer;
+    timer.Start();
     std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
     if (!fo_) {
       auto n = cache_info_->ShardName();
       fo_.reset(dmlc::Stream::Create(n.c_str(), "w"));
     }
     auto bytes = fmt->Write(*page_, fo_.get());
+    timer.Stop();
+    LOG(INFO) << static_cast<double>(bytes) / 1024.0 / 1024.0 << " MB written in "
+              << timer.ElapsedSeconds() << " seconds.";
     cache_info_->offset.push_back(bytes);
   }

@@ -280,15 +291,24 @@ template <typename S>
 class PageSourceIncMixIn : public SparsePageSourceImpl<S> {
  protected:
   std::shared_ptr<SparsePageSource> source_;
+  using Super = SparsePageSourceImpl<S>;
+  // synchronize the row page, `hist` and `gpu_hist` don't need the original sparse page
+  // so we avoid fetching it.
+  bool sync_{true};

  public:
-  using SparsePageSourceImpl<S>::SparsePageSourceImpl;
+  PageSourceIncMixIn(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
+                     std::shared_ptr<Cache> cache, bool sync)
+      : Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {}

   PageSourceIncMixIn& operator++() final {
     TryLockGuard guard{this->single_threaded_};
-    ++(*source_);
+    if (sync_) {
+      ++(*source_);
+    }
+
     ++this->count_;
-    this->at_end_ = source_->AtEnd();
+    this->at_end_ = this->count_ == this->n_batches_;

     if (this->at_end_) {
       this->cache_info_->Commit();
@@ -299,7 +319,10 @@ class PageSourceIncMixIn : public SparsePageSourceImpl<S> {
     } else {
       this->Fetch();
     }
-    CHECK_EQ(source_->Iter(), this->count_);
+
+    if (sync_) {
+      CHECK_EQ(source_->Iter(), this->count_);
+    }
     return *this;
   }
 };
@@ -318,12 +341,9 @@ class CSCPageSource : public PageSourceIncMixIn<CSCPage> {
   }

  public:
-  CSCPageSource(
-      float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
-      std::shared_ptr<Cache> cache,
-      std::shared_ptr<SparsePageSource> source)
-      : PageSourceIncMixIn(missing, nthreads, n_features,
-                           n_batches, cache) {
+  CSCPageSource(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
+                std::shared_ptr<Cache> cache, std::shared_ptr<SparsePageSource> source)
+      : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, true) {
     this->source_ = source;
     this->Fetch();
   }
@@ -349,7 +369,7 @@ class SortedCSCPageSource : public PageSourceIncMixIn<SortedCSCPage> {
   SortedCSCPageSource(float missing, int nthreads, bst_feature_t n_features,
                       uint32_t n_batches, std::shared_ptr<Cache> cache,
                       std::shared_ptr<SparsePageSource> source)
-      : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache) {
+      : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, true) {
     this->source_ = source;
     this->Fetch();
   }


@@ -1,5 +1,5 @@
 /*!
- * Copyright 2021 by XGBoost Contributors
+ * Copyright 2021-2022 by XGBoost Contributors
  */
 #ifndef XGBOOST_TREE_HIST_HISTOGRAM_H_
 #define XGBOOST_TREE_HIST_HISTOGRAM_H_
@@ -8,10 +8,11 @@
 #include <limits>
 #include <vector>

-#include "rabit/rabit.h"
-#include "xgboost/tree_model.h"
 #include "../../common/hist_util.h"
 #include "../../data/gradient_index.h"
+#include "expand_entry.h"
+#include "rabit/rabit.h"
+#include "xgboost/tree_model.h"

 namespace xgboost {
 namespace tree {
@@ -323,6 +324,25 @@ template <typename GradientSumT, typename ExpandEntry> class HistogramBuilder {
       (*sync_count) = std::max(1, n_left);
   }
 };
+
+// Construct a work space for building histogram.  Eventually we should move this
+// function into histogram builder once hist tree method supports external memory.
+template <typename Partitioner>
+common::BlockedSpace2d ConstructHistSpace(Partitioner const &partitioners,
+                                          std::vector<CPUExpandEntry> const &nodes_to_build) {
+  std::vector<size_t> partition_size(nodes_to_build.size(), 0);
+  for (auto const &partition : partitioners) {
+    size_t k = 0;
+    for (auto node : nodes_to_build) {
+      auto n_rows_in_node = partition.Partitions()[node.nid].Size();
+      partition_size[k] = std::max(partition_size[k], n_rows_in_node);
+      k++;
+    }
+  }
+  common::BlockedSpace2d space{
+      nodes_to_build.size(), [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; }, 256};
+  return space;
+}
 }  // namespace tree
 }  // namespace xgboost
 #endif  // XGBOOST_TREE_HIST_HISTOGRAM_H_

src/tree/hist/param.cc (new file, 10 additions)

@@ -0,0 +1,10 @@
+/*!
+ * Copyright 2022 XGBoost contributors
+ */
+#include "param.h"
+
+namespace xgboost {
+namespace tree {
+DMLC_REGISTER_PARAMETER(CPUHistMakerTrainParam);
+}  // namespace tree
+}  // namespace xgboost


@@ -94,7 +94,7 @@ class GloablApproxBuilder {
     rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double *>(&root_sum), 2);
     std::vector<CPUExpandEntry> nodes{best};
     size_t i = 0;
-    auto space = this->ConstructHistSpace(nodes);
+    auto space = ConstructHistSpace(partitioner_, nodes);
     for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess))) {
       histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes,
                                    {}, gpair);
@@ -123,25 +123,6 @@ class GloablApproxBuilder {
     monitor_->Stop(__func__);
   }

-  // Construct a work space for building histogram.  Eventually we should move this
-  // function into histogram builder once hist tree method supports external memory.
-  common::BlockedSpace2d ConstructHistSpace(
-      std::vector<CPUExpandEntry> const &nodes_to_build) const {
-    std::vector<size_t> partition_size(nodes_to_build.size(), 0);
-    for (auto const &partition : partitioner_) {
-      size_t k = 0;
-      for (auto node : nodes_to_build) {
-        auto n_rows_in_node = partition.Partitions()[node.nid].Size();
-        partition_size[k] = std::max(partition_size[k], n_rows_in_node);
-        k++;
-      }
-    }
-    common::BlockedSpace2d space{nodes_to_build.size(),
-                                 [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; },
-                                 256};
-    return space;
-  }
-
   void BuildHistogram(DMatrix *p_fmat, RegTree *p_tree,
                       std::vector<CPUExpandEntry> const &valid_candidates,
                       std::vector<GradientPair> const &gpair, common::Span<float> hess) {
@@ -164,7 +145,7 @@ class GloablApproxBuilder {
     }

     size_t i = 0;
-    auto space = this->ConstructHistSpace(nodes_to_build);
+    auto space = ConstructHistSpace(partitioner_, nodes_to_build);
     for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess))) {
       histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
                                    nodes_to_build, nodes_to_sub, gpair);
@@ -191,7 +172,7 @@ class GloablApproxBuilder {
     Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy));
     auto &tree = *p_tree;
     driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)});
-    bst_node_t num_leaves = 1;
+    bst_node_t num_leaves{1};
     auto expand_set = driver.Pop();

     /**
@@ -223,10 +204,10 @@ class GloablApproxBuilder {
     }

     monitor_->Start("UpdatePosition");
-    size_t i = 0;
+    size_t page_id = 0;
     for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess))) {
-      partitioner_.at(i).UpdatePosition(ctx_, page, applied, p_tree);
-      i++;
+      partitioner_.at(page_id).UpdatePosition(ctx_, page, applied, p_tree);
+      page_id++;
     }
     monitor_->Stop("UpdatePosition");
@@ -288,9 +269,9 @@ class GlobalApproxUpdater : public TreeUpdater {
     out["hist_param"] = ToJson(hist_param_);
   }

-  void InitData(TrainParam const &param, HostDeviceVector<GradientPair> *gpair,
+  void InitData(TrainParam const &param, HostDeviceVector<GradientPair> const *gpair,
                 std::vector<GradientPair> *sampled) {
-    auto const &h_gpair = gpair->HostVector();
+    auto const &h_gpair = gpair->ConstHostVector();
     sampled->resize(h_gpair.size());
     std::copy(h_gpair.cbegin(), h_gpair.cend(), sampled->begin());
     auto &rnd = common::GlobalRandom();


@ -4,80 +4,39 @@
* \brief use quantized feature values to construct a tree * \brief use quantized feature values to construct a tree
* \author Philip Cho, Tianqi Checn, Egor Smirnov * \author Philip Cho, Tianqi Checn, Egor Smirnov
*/ */
#include <dmlc/timer.h> #include "./updater_quantile_hist.h"
#include <rabit/rabit.h> #include <rabit/rabit.h>
#include <algorithm> #include <algorithm>
#include <cmath>
#include <iomanip>
#include <memory> #include <memory>
#include <numeric> #include <numeric>
#include <queue>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "../common/column_matrix.h"
#include "../common/hist_util.h"
#include "../common/random.h"
#include "../common/threading_utils.h"
#include "constraints.h"
#include "hist/evaluate_splits.h"
#include "param.h"
#include "xgboost/logging.h" #include "xgboost/logging.h"
#include "xgboost/tree_updater.h" #include "xgboost/tree_updater.h"
#include "constraints.h"
#include "param.h"
#include "./updater_quantile_hist.h"
#include "./split_evaluator.h"
#include "../common/random.h"
#include "../common/hist_util.h"
#include "../common/row_set.h"
#include "../common/column_matrix.h"
#include "../common/threading_utils.h"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_quantile_hist); DMLC_REGISTRY_FILE_TAG(updater_quantile_hist);
DMLC_REGISTER_PARAMETER(CPUHistMakerTrainParam);
void QuantileHistMaker::Configure(const Args &args) { void QuantileHistMaker::Configure(const Args &args) {
// initialize pruner
if (!pruner_) {
pruner_.reset(TreeUpdater::Create("prune", ctx_, task_));
}
pruner_->Configure(args);
param_.UpdateAllowUnknown(args); param_.UpdateAllowUnknown(args);
hist_maker_param_.UpdateAllowUnknown(args); hist_maker_param_.UpdateAllowUnknown(args);
} }
template <typename GradientSumT> void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair, DMatrix *dmat,
void QuantileHistMaker::SetBuilder(const size_t n_trees,
std::unique_ptr<Builder<GradientSumT>>* builder, DMatrix* dmat) {
builder->reset(
new Builder<GradientSumT>(n_trees, param_, std::move(pruner_), dmat, task_, ctx_));
}
template<typename GradientSumT>
void QuantileHistMaker::CallBuilderUpdate(const std::unique_ptr<Builder<GradientSumT>>& builder,
HostDeviceVector<GradientPair> *gpair,
DMatrix *dmat,
GHistIndexMatrix const& gmat,
const std::vector<RegTree *> &trees) { const std::vector<RegTree *> &trees) {
for (auto tree : trees) {
builder->Update(gmat, column_matrix_, gpair, dmat, tree);
}
}
void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
DMatrix *dmat,
const std::vector<RegTree *> &trees) {
auto it = dmat->GetBatches<GHistIndexMatrix>(HistBatch(param_)).begin();
auto p_gmat = it.Page();
if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) {
updater_monitor_.Start("GmatInitialization");
column_matrix_.Init(*p_gmat, param_.sparse_threshold, ctx_->Threads());
updater_monitor_.Stop("GmatInitialization");
// A proper solution is puting cut matrix in DMatrix, see:
// https://github.com/dmlc/xgboost/issues/5143
is_gmat_initialized_ = true;
}
// rescale learning rate according to size of trees // rescale learning rate according to size of trees
float lr = param_.learning_rate; float lr = param_.learning_rate;
param_.learning_rate = lr / trees.size(); param_.learning_rate = lr / trees.size();
@ -86,19 +45,23 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
const size_t n_trees = trees.size(); const size_t n_trees = trees.size();
if (hist_maker_param_.single_precision_histogram) { if (hist_maker_param_.single_precision_histogram) {
if (!float_builder_) { if (!float_builder_) {
this->SetBuilder(n_trees, &float_builder_, dmat); float_builder_.reset(new Builder<float>(n_trees, param_, dmat, task_, ctx_));
} }
CallBuilderUpdate(float_builder_, gpair, dmat, *p_gmat, trees);
} else { } else {
if (!double_builder_) { if (!double_builder_) {
SetBuilder(n_trees, &double_builder_, dmat); double_builder_.reset(new Builder<double>(n_trees, param_, dmat, task_, ctx_));
}
}
for (auto p_tree : trees) {
if (hist_maker_param_.single_precision_histogram) {
this->float_builder_->UpdateTree(gpair, dmat, p_tree);
} else {
this->double_builder_->UpdateTree(gpair, dmat, p_tree);
} }
CallBuilderUpdate(double_builder_, gpair, dmat, *p_gmat, trees);
} }
param_.learning_rate = lr; param_.learning_rate = lr;
p_last_dmat_ = dmat;
} }
bool QuantileHistMaker::UpdatePredictionCache(const DMatrix *data, bool QuantileHistMaker::UpdatePredictionCache(const DMatrix *data,
@ -113,23 +76,18 @@ bool QuantileHistMaker::UpdatePredictionCache(const DMatrix *data,
} }
template <typename GradientSumT> template <typename GradientSumT>
template <bool any_missing> CPUExpandEntry QuantileHistMaker::Builder<GradientSumT>::InitRoot(
void QuantileHistMaker::Builder<GradientSumT>::InitRoot( DMatrix *p_fmat, RegTree *p_tree, const std::vector<GradientPair> &gpair_h) {
DMatrix *p_fmat, RegTree *p_tree, const std::vector<GradientPair> &gpair_h,
int *num_leaves, std::vector<CPUExpandEntry> *expand) {
CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0), 0.0f); CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0), 0.0f);
nodes_for_explicit_hist_build_.clear();
nodes_for_subtraction_trick_.clear();
nodes_for_explicit_hist_build_.push_back(node);
auto const& row_set_collection = partitioner_.front().Partitions();
size_t page_id = 0; size_t page_id = 0;
for (auto const& gidx : auto space = ConstructHistSpace(partitioner_, {node});
p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) { for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
this->histogram_builder_->BuildHist( std::vector<CPUExpandEntry> nodes_to_build{node};
page_id, gidx, p_tree, row_set_collection, std::vector<CPUExpandEntry> nodes_to_sub;
nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, gpair_h); this->histogram_builder_->BuildHist(page_id, space, gidx, p_tree,
partitioner_.at(page_id).Partitions(), nodes_to_build,
nodes_to_sub, gpair_h);
++page_id; ++page_id;
} }
@ -165,148 +123,118 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
(*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight); (*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight);
std::vector<CPUExpandEntry> entries{node}; std::vector<CPUExpandEntry> entries{node};
builder_monitor_->Start("EvaluateSplits"); monitor_->Start("EvaluateSplits");
auto ft = p_fmat->Info().feature_types.ConstHostSpan(); auto ft = p_fmat->Info().feature_types.ConstHostSpan();
for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) { for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft, evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft, *p_tree, &entries);
*p_tree, &entries);
break; break;
} }
builder_monitor_->Stop("EvaluateSplits"); monitor_->Stop("EvaluateSplits");
node = entries.front(); node = entries.front();
} }
expand->push_back(node); return node;
++(*num_leaves);
} }
template <typename GradientSumT> template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::AddSplitsToTree( void QuantileHistMaker::Builder<GradientSumT>::BuildHistogram(
const std::vector<CPUExpandEntry>& expand, DMatrix *p_fmat, RegTree *p_tree, std::vector<CPUExpandEntry> const &valid_candidates,
RegTree *p_tree, std::vector<GradientPair> const &gpair) {
int *num_leaves, std::vector<CPUExpandEntry> nodes_to_build(valid_candidates.size());
std::vector<CPUExpandEntry>* nodes_for_apply_split) { std::vector<CPUExpandEntry> nodes_to_sub(valid_candidates.size());
for (auto const& entry : expand) {
if (entry.IsValid(param_, *num_leaves)) { size_t n_idx = 0;
nodes_for_apply_split->push_back(entry); for (auto const &c : valid_candidates) {
evaluator_->ApplyTreeSplit(entry, p_tree); auto left_nidx = (*p_tree)[c.nid].LeftChild();
(*num_leaves)++; auto right_nidx = (*p_tree)[c.nid].RightChild();
} auto fewer_right = c.split.right_sum.GetHess() < c.split.left_sum.GetHess();
auto build_nidx = left_nidx;
auto subtract_nidx = right_nidx;
if (fewer_right) {
std::swap(build_nidx, subtract_nidx);
} }
nodes_to_build[n_idx] = CPUExpandEntry{build_nidx, p_tree->GetDepth(build_nidx), {}};
nodes_to_sub[n_idx] = CPUExpandEntry{subtract_nidx, p_tree->GetDepth(subtract_nidx), {}};
n_idx++;
} }
// Split nodes to 2 sets depending on amount of rows in each node size_t page_id{0};
// Histograms for small nodes will be built explicitly auto space = ConstructHistSpace(partitioner_, nodes_to_build);
// Histograms for big nodes will be built by 'Subtraction Trick' for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
// Exception: in distributed setting, we always build the histogram for the left child node histogram_builder_->BuildHist(page_id, space, gidx, p_tree,
// and use 'Subtraction Trick' to built the histogram for the right child node. partitioner_.at(page_id).Partitions(), nodes_to_build,
// This ensures that the workers operate on the same set of tree nodes. nodes_to_sub, gpair);
template <typename GradientSumT> ++page_id;
void QuantileHistMaker::Builder<GradientSumT>::SplitSiblings(
const std::vector<CPUExpandEntry> &nodes_for_apply_split,
std::vector<CPUExpandEntry> *nodes_to_evaluate, RegTree *p_tree) {
builder_monitor_->Start("SplitSiblings");
auto const& row_set_collection = this->partitioner_.front().Partitions();
for (auto const& entry : nodes_for_apply_split) {
int nid = entry.nid;
const int cleft = (*p_tree)[nid].LeftChild();
const int cright = (*p_tree)[nid].RightChild();
const CPUExpandEntry left_node = CPUExpandEntry(cleft, p_tree->GetDepth(cleft), 0.0);
const CPUExpandEntry right_node = CPUExpandEntry(cright, p_tree->GetDepth(cright), 0.0);
nodes_to_evaluate->push_back(left_node);
nodes_to_evaluate->push_back(right_node);
if (row_set_collection[cleft].Size() < row_set_collection[cright].Size()) {
nodes_for_explicit_hist_build_.push_back(left_node);
nodes_for_subtraction_trick_.push_back(right_node);
} else {
nodes_for_explicit_hist_build_.push_back(right_node);
nodes_for_subtraction_trick_.push_back(left_node);
} }
} }
CHECK_EQ(nodes_for_subtraction_trick_.size(), nodes_for_explicit_hist_build_.size());
builder_monitor_->Stop("SplitSiblings");
}
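// Note on the "Subtraction Trick" used by BuildHistogram above: every row of the parent
// node lands in exactly one child, so the parent histogram is the element-wise sum of the
// two child histograms. A minimal sketch of the idea (illustrative only, not the actual
// HistogramBuilder API):
//
//   for (size_t bin = 0; bin < n_bins; ++bin) {
//     hist_larger_child[bin] = hist_parent[bin] - hist_smaller_child[bin];
//   }
//
// The child with the smaller Hessian sum (see fewer_right above) is built explicitly and
// its sibling is derived by subtraction, roughly halving the histogram work per split.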
template <typename GradientSumT> template <typename GradientSumT>
template <bool any_missing>
void QuantileHistMaker::Builder<GradientSumT>::ExpandTree( void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
const GHistIndexMatrix& gmat, DMatrix *p_fmat, RegTree *p_tree, const std::vector<GradientPair> &gpair_h) {
const common::ColumnMatrix& column_matrix, monitor_->Start(__func__);
DMatrix* p_fmat,
RegTree* p_tree,
const std::vector<GradientPair>& gpair_h) {
builder_monitor_->Start("ExpandTree");
int num_leaves = 0;
Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy)); Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy));
std::vector<CPUExpandEntry> expand; driver.Push(this->InitRoot(p_fmat, p_tree, gpair_h));
InitRoot<any_missing>(p_fmat, p_tree, gpair_h, &num_leaves, &expand); bst_node_t num_leaves{1};
driver.Push(expand[0]); auto expand_set = driver.Pop();
int32_t depth = 0; while (!expand_set.empty()) {
while (!driver.IsEmpty()) { // candidates that can be split further.
expand = driver.Pop(); std::vector<CPUExpandEntry> valid_candidates;
depth = expand[0].depth + 1; // candidates that can be applied.
std::vector<CPUExpandEntry> nodes_for_apply_split; std::vector<CPUExpandEntry> applied;
std::vector<CPUExpandEntry> nodes_to_evaluate; int32_t depth = expand_set.front().depth + 1;
nodes_for_explicit_hist_build_.clear(); for (auto const& candidate : expand_set) {
nodes_for_subtraction_trick_.clear(); if (!candidate.IsValid(param_, num_leaves)) {
continue;
AddSplitsToTree(expand, p_tree, &num_leaves, &nodes_for_apply_split); }
evaluator_->ApplyTreeSplit(candidate, p_tree);
if (nodes_for_apply_split.size() != 0) { applied.push_back(candidate);
HistRowPartitioner &partitioner = this->partitioner_.front(); num_leaves++;
if (gmat.cut.HasCategorical()) { if (CPUExpandEntry::ChildIsValid(param_, depth, num_leaves)) {
partitioner.UpdatePosition<any_missing, true>(this->ctx_, gmat, column_matrix, valid_candidates.emplace_back(candidate);
nodes_for_apply_split, p_tree); }
} else {
partitioner.UpdatePosition<any_missing, false>(this->ctx_, gmat, column_matrix,
nodes_for_apply_split, p_tree);
} }
SplitSiblings(nodes_for_apply_split, &nodes_to_evaluate, p_tree); monitor_->Start("UpdatePosition");
size_t page_id{0};
if (param_.max_depth == 0 || depth < param_.max_depth) { for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
size_t i = 0; partitioner_.at(page_id).UpdatePosition(ctx_, page, applied, p_tree);
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) { ++page_id;
this->histogram_builder_->BuildHist(i, gidx, p_tree, partitioner_.front().Partitions(),
nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_, gpair_h);
++i;
}
} else {
int starting_index = std::numeric_limits<int>::max();
int sync_count = 0;
this->histogram_builder_->AddHistRows(
&starting_index, &sync_count, nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_, p_tree);
} }
monitor_->Stop("UpdatePosition");
builder_monitor_->Start("EvaluateSplits"); std::vector<CPUExpandEntry> best_splits;
if (!valid_candidates.empty()) {
this->BuildHistogram(p_fmat, p_tree, valid_candidates, gpair_h);
auto const &tree = *p_tree;
for (auto const &candidate : valid_candidates) {
int left_child_nidx = tree[candidate.nid].LeftChild();
int right_child_nidx = tree[candidate.nid].RightChild();
CPUExpandEntry l_best{left_child_nidx, depth, 0.0};
CPUExpandEntry r_best{right_child_nidx, depth, 0.0};
best_splits.push_back(l_best);
best_splits.push_back(r_best);
}
auto const &histograms = histogram_builder_->Histogram();
auto ft = p_fmat->Info().feature_types.ConstHostSpan(); auto ft = p_fmat->Info().feature_types.ConstHostSpan();
evaluator_->EvaluateSplits(this->histogram_builder_->Histogram(), for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
gmat.cut, ft, *p_tree, &nodes_to_evaluate); evaluator_->EvaluateSplits(histograms, gmat.cut, ft, *p_tree, &best_splits);
builder_monitor_->Stop("EvaluateSplits"); break;
}
}
driver.Push(best_splits.begin(), best_splits.end());
expand_set = driver.Pop();
}
for (size_t i = 0; i < nodes_for_apply_split.size(); ++i) { monitor_->Stop(__func__);
CPUExpandEntry left_node = nodes_to_evaluate.at(i * 2 + 0);
CPUExpandEntry right_node = nodes_to_evaluate.at(i * 2 + 1);
driver.Push(left_node);
driver.Push(right_node);
}
}
}
builder_monitor_->Stop("ExpandTree");
} }
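// A note on the external-memory layout of ExpandTree: p_fmat->GetBatches<GHistIndexMatrix>()
// may yield several pages, so the builder keeps one HistRowPartitioner per page (indexed by
// page_id) and lets histogram_builder_ accumulate per-page partial histograms, instead of
// assuming a single in-core gradient index as the previous implementation did.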
template <typename GradientSumT> template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::Update( void QuantileHistMaker::Builder<GradientSumT>::UpdateTree(HostDeviceVector<GradientPair> *gpair,
const GHistIndexMatrix &gmat,
const common::ColumnMatrix &column_matrix,
HostDeviceVector<GradientPair> *gpair,
DMatrix *p_fmat, RegTree *p_tree) { DMatrix *p_fmat, RegTree *p_tree) {
builder_monitor_->Start("Update"); monitor_->Start(__func__);
std::vector<GradientPair> *gpair_ptr = &(gpair->HostVector()); std::vector<GradientPair> *gpair_ptr = &(gpair->HostVector());
// in case 'num_parallel_trees != 1' the initial gpair must not be changed // in case 'num_parallel_trees != 1' the initial gpair must not be changed
@ -315,18 +243,12 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
gpair_local_ = *gpair_ptr; gpair_local_ = *gpair_ptr;
gpair_ptr = &gpair_local_; gpair_ptr = &gpair_local_;
} }
p_last_fmat_mutable_ = p_fmat;
this->InitData(gmat, p_fmat, *p_tree, gpair_ptr); this->InitData(p_fmat, *p_tree, gpair_ptr);
if (column_matrix.AnyMissing()) { ExpandTree(p_fmat, p_tree, *gpair_ptr);
ExpandTree<true>(gmat, column_matrix, p_fmat, p_tree, *gpair_ptr);
} else {
ExpandTree<false>(gmat, column_matrix, p_fmat, p_tree, *gpair_ptr);
}
pruner_->Update(gpair, p_fmat, std::vector<RegTree*>{p_tree});
builder_monitor_->Stop("Update"); monitor_->Stop(__func__);
} }
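// The separate pruning pass (pruner_->Update) has been dropped from this routine: depth and
// leaf limits are now enforced while growing, via the IsValid/ChildIsValid checks in
// ExpandTree, matching how the approx and gpu_hist updaters work. This rationale is
// inferred from the surrounding changes rather than stated explicitly in the code.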
template <typename GradientSumT> template <typename GradientSumT>
@ -334,20 +256,20 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
DMatrix const *data, linalg::VectorView<float> out_preds) const { DMatrix const *data, linalg::VectorView<float> out_preds) const {
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
// conjunction with Update(). // conjunction with Update().
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_ || if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) {
p_last_fmat_ != p_last_fmat_mutable_) {
return false; return false;
} }
builder_monitor_->Start(__func__); monitor_->Start(__func__);
CHECK_EQ(out_preds.Size(), data->Info().num_row_); CHECK_EQ(out_preds.Size(), data->Info().num_row_);
UpdatePredictionCacheImpl(ctx_, p_last_tree_, partitioner_, *evaluator_, param_, out_preds); UpdatePredictionCacheImpl(ctx_, p_last_tree_, partitioner_, *evaluator_, param_, out_preds);
builder_monitor_->Stop(__func__); monitor_->Stop(__func__);
return true; return true;
} }
template <typename GradientSumT> template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const DMatrix &fmat, void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const DMatrix &fmat,
std::vector<GradientPair> *gpair) { std::vector<GradientPair> *gpair) {
monitor_->Start(__func__);
const auto &info = fmat.Info(); const auto &info = fmat.Info();
auto& rnd = common::GlobalRandom(); auto& rnd = common::GlobalRandom();
std::vector<GradientPair>& gpair_ref = *gpair; std::vector<GradientPair>& gpair_ref = *gpair;
@ -380,6 +302,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const DMatrix &fmat,
} }
exc.Rethrow(); exc.Rethrow();
#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG #endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG
monitor_->Stop(__func__);
} }
template<typename GradientSumT> template<typename GradientSumT>
size_t QuantileHistMaker::Builder<GradientSumT>::GetNumberOfTrees() { size_t QuantileHistMaker::Builder<GradientSumT>::GetNumberOfTrees() {
@ -387,10 +310,9 @@ size_t QuantileHistMaker::Builder<GradientSumT>::GetNumberOfTrees() {
} }
template <typename GradientSumT> template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix &gmat, DMatrix *fmat, void QuantileHistMaker::Builder<GradientSumT>::InitData(DMatrix *fmat, const RegTree &tree,
const RegTree &tree,
std::vector<GradientPair> *gpair) { std::vector<GradientPair> *gpair) {
builder_monitor_->Start("InitData"); monitor_->Start(__func__);
const auto& info = fmat->Info(); const auto& info = fmat->Info();
{ {
@ -406,18 +328,14 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix &
partitioner_.emplace_back(page.Size(), page.base_rowid, this->ctx_->Threads()); partitioner_.emplace_back(page.Size(), page.base_rowid, this->ctx_->Threads());
++page_id; ++page_id;
} }
histogram_builder_->Reset(n_total_bins, BatchParam{param_.max_bin, param_.sparse_threshold}, histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
ctx_->Threads(), page_id, rabit::IsDistributed()); rabit::IsDistributed());
if (param_.subsample < 1.0f) { if (param_.subsample < 1.0f) {
CHECK_EQ(param_.sampling_method, TrainParam::kUniform) CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
<< "Only uniform sampling is supported, " << "Only uniform sampling is supported, "
<< "gradient-based sampling is only support by GPU Hist."; << "gradient-based sampling is only support by GPU Hist.";
builder_monitor_->Start("InitSampling");
InitSampling(*fmat, gpair); InitSampling(*fmat, gpair);
builder_monitor_->Stop("InitSampling");
// We should check that the partitioning was done correctly
// and each row of the dataset fell into exactly one of the categories
} }
} }
@ -426,7 +344,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix &
evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{ evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
param_, info, this->ctx_->Threads(), column_sampler_, task_}); param_, info, this->ctx_->Threads(), column_sampler_, task_});
builder_monitor_->Stop("InitData"); monitor_->Stop(__func__);
} }
void HistRowPartitioner::FindSplitConditions(const std::vector<CPUExpandEntry> &nodes, void HistRowPartitioner::FindSplitConditions(const std::vector<CPUExpandEntry> &nodes,
@ -470,21 +388,8 @@ void HistRowPartitioner::AddSplitsToRowSet(const std::vector<CPUExpandEntry> &no
template struct QuantileHistMaker::Builder<float>; template struct QuantileHistMaker::Builder<float>;
template struct QuantileHistMaker::Builder<double>; template struct QuantileHistMaker::Builder<double>;
XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker")
.describe("(Deprecated, use grow_quantile_histmaker instead.)"
" Grow tree using quantized histogram.")
.set_body(
[](ObjInfo task) {
LOG(WARNING) << "grow_fast_histmaker is deprecated, "
<< "use grow_quantile_histmaker instead.";
return new QuantileHistMaker(task);
});
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker") XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
.describe("Grow tree using quantized histogram.") .describe("Grow tree using quantized histogram.")
.set_body( .set_body([](ObjInfo task) { return new QuantileHistMaker(task); });
[](ObjInfo task) {
return new QuantileHistMaker(task);
});
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost
@ -7,7 +7,6 @@
#ifndef XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_ #ifndef XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
#define XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_ #define XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
#include <dmlc/timer.h>
#include <rabit/rabit.h> #include <rabit/rabit.h>
#include <xgboost/tree_updater.h> #include <xgboost/tree_updater.h>
@ -29,7 +28,6 @@
#include "constraints.h" #include "constraints.h"
#include "./param.h" #include "./param.h"
#include "./driver.h" #include "./driver.h"
#include "./split_evaluator.h"
#include "../common/random.h" #include "../common/random.h"
#include "../common/timer.h" #include "../common/timer.h"
#include "../common/hist_util.h" #include "../common/hist_util.h"
@ -194,6 +192,24 @@ class HistRowPartitioner {
AddSplitsToRowSet(nodes, p_tree); AddSplitsToRowSet(nodes, p_tree);
} }
void UpdatePosition(GenericParameter const* ctx, GHistIndexMatrix const& page,
std::vector<CPUExpandEntry> const& applied, RegTree const* p_tree) {
auto const& column_matrix = page.Transpose();
if (page.cut.HasCategorical()) {
if (column_matrix.AnyMissing()) {
this->template UpdatePosition<true, true>(ctx, page, column_matrix, applied, p_tree);
} else {
this->template UpdatePosition<false, true>(ctx, page, column_matrix, applied, p_tree);
}
} else {
if (column_matrix.AnyMissing()) {
this->template UpdatePosition<true, false>(ctx, page, column_matrix, applied, p_tree);
} else {
this->template UpdatePosition<false, false>(ctx, page, column_matrix, applied, p_tree);
}
}
}
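// The wrapper above picks one of four UpdatePosition<any_missing, has_categorical>
// instantiations at compile time, so the per-row partitioning loop does not re-test for
// missing values or categorical splits; the two flags come from the page's column matrix
// and cut matrix respectively.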
auto const& Partitions() const { return row_set_collection_; } auto const& Partitions() const { return row_set_collection_; }
size_t Size() const { size_t Size() const {
return std::distance(row_set_collection_.begin(), row_set_collection_.end()); return std::distance(row_set_collection_.begin(), row_set_collection_.end());
@ -209,9 +225,7 @@ inline BatchParam HistBatch(TrainParam const& param) {
/*! \brief construct a tree using quantized feature values */ /*! \brief construct a tree using quantized feature values */
class QuantileHistMaker: public TreeUpdater { class QuantileHistMaker: public TreeUpdater {
public: public:
explicit QuantileHistMaker(ObjInfo task) : task_{task} { explicit QuantileHistMaker(ObjInfo task) : task_{task} {}
updater_monitor_.Init("QuantileHistMaker");
}
void Configure(const Args& args) override; void Configure(const Args& args) override;
void Update(HostDeviceVector<GradientPair>* gpair, void Update(HostDeviceVector<GradientPair>* gpair,
@ -256,10 +270,6 @@ class QuantileHistMaker: public TreeUpdater {
CPUHistMakerTrainParam hist_maker_param_; CPUHistMakerTrainParam hist_maker_param_;
// training parameter // training parameter
TrainParam param_; TrainParam param_;
// column accessor
common::ColumnMatrix column_matrix_;
DMatrix const* p_last_dmat_ {nullptr};
bool is_gmat_initialized_ {false};
// actual builder that runs the algorithm // actual builder that runs the algorithm
template<typename GradientSumT> template<typename GradientSumT>
@ -267,60 +277,40 @@ class QuantileHistMaker: public TreeUpdater {
public: public:
using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>; using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;
// constructor // constructor
explicit Builder(const size_t n_trees, const TrainParam& param, explicit Builder(const size_t n_trees, const TrainParam& param, DMatrix const* fmat,
std::unique_ptr<TreeUpdater> pruner, DMatrix const* fmat, ObjInfo task, ObjInfo task, GenericParameter const* ctx)
GenericParameter const* ctx)
: n_trees_(n_trees), : n_trees_(n_trees),
param_(param), param_(param),
pruner_(std::move(pruner)),
p_last_fmat_(fmat), p_last_fmat_(fmat),
histogram_builder_{new HistogramBuilder<GradientSumT, CPUExpandEntry>}, histogram_builder_{new HistogramBuilder<GradientSumT, CPUExpandEntry>},
task_{task}, task_{task},
ctx_{ctx}, ctx_{ctx},
builder_monitor_{std::make_unique<common::Monitor>()} { monitor_{std::make_unique<common::Monitor>()} {
builder_monitor_->Init("Quantile::Builder"); monitor_->Init("Quantile::Builder");
} }
// update one tree, growing // update one tree, growing
void Update(const GHistIndexMatrix& gmat, const common::ColumnMatrix& column_matrix, void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree);
HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree);
bool UpdatePredictionCache(DMatrix const* data, linalg::VectorView<float> out_preds) const; bool UpdatePredictionCache(DMatrix const* data, linalg::VectorView<float> out_preds) const;
protected: private:
// initialize temp data structure // initialize temp data structure
void InitData(const GHistIndexMatrix& gmat, DMatrix* fmat, const RegTree& tree, void InitData(DMatrix* fmat, const RegTree& tree, std::vector<GradientPair>* gpair);
std::vector<GradientPair>* gpair);
size_t GetNumberOfTrees(); size_t GetNumberOfTrees();
void InitSampling(const DMatrix& fmat, std::vector<GradientPair>* gpair); void InitSampling(const DMatrix& fmat, std::vector<GradientPair>* gpair);
template <bool any_missing> CPUExpandEntry InitRoot(DMatrix* p_fmat, RegTree* p_tree,
void InitRoot(DMatrix* p_fmat,
RegTree *p_tree,
const std::vector<GradientPair> &gpair_h,
int *num_leaves, std::vector<CPUExpandEntry> *expand);
// Split nodes into 2 sets depending on the number of rows in each node
// Histograms for small nodes will be built explicitly
// Histograms for big nodes will be built by 'Subtraction Trick'
void SplitSiblings(const std::vector<CPUExpandEntry>& nodes,
std::vector<CPUExpandEntry>* nodes_to_evaluate,
RegTree *p_tree);
void AddSplitsToTree(const std::vector<CPUExpandEntry>& expand,
RegTree *p_tree,
int *num_leaves,
std::vector<CPUExpandEntry>* nodes_for_apply_split);
template <bool any_missing>
void ExpandTree(const GHistIndexMatrix& gmat,
const common::ColumnMatrix& column_matrix,
DMatrix* p_fmat,
RegTree* p_tree,
const std::vector<GradientPair>& gpair_h); const std::vector<GradientPair>& gpair_h);
// --data fields-- void BuildHistogram(DMatrix* p_fmat, RegTree* p_tree,
std::vector<CPUExpandEntry> const& valid_candidates,
std::vector<GradientPair> const& gpair);
void ExpandTree(DMatrix* p_fmat, RegTree* p_tree, const std::vector<GradientPair>& gpair_h);
private:
const size_t n_trees_; const size_t n_trees_;
const TrainParam& param_; const TrainParam& param_;
std::shared_ptr<common::ColumnSampler> column_sampler_{ std::shared_ptr<common::ColumnSampler> column_sampler_{
@ -328,48 +318,24 @@ class QuantileHistMaker: public TreeUpdater {
std::vector<GradientPair> gpair_local_; std::vector<GradientPair> gpair_local_;
std::unique_ptr<TreeUpdater> pruner_;
std::unique_ptr<HistEvaluator<GradientSumT, CPUExpandEntry>> evaluator_; std::unique_ptr<HistEvaluator<GradientSumT, CPUExpandEntry>> evaluator_;
// Right now there's only 1 partitioner in this vector, when external memory is fully
// supported we will have number of partitioners equal to number of pages.
std::vector<HistRowPartitioner> partitioner_; std::vector<HistRowPartitioner> partitioner_;
// back pointers to tree and data matrix // back pointers to tree and data matrix
const RegTree* p_last_tree_{nullptr}; const RegTree* p_last_tree_{nullptr};
DMatrix const* const p_last_fmat_; DMatrix const* const p_last_fmat_;
DMatrix* p_last_fmat_mutable_;
// key is the node id which should be calculated by Subtraction Trick, value is the node which
// provides the evidence for subtraction
std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
// list of nodes whose histograms would be built explicitly.
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
enum class DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
std::unique_ptr<HistogramBuilder<GradientSumT, CPUExpandEntry>> histogram_builder_; std::unique_ptr<HistogramBuilder<GradientSumT, CPUExpandEntry>> histogram_builder_;
ObjInfo task_; ObjInfo task_;
// Context for number of threads // Context for number of threads
GenericParameter const* ctx_; GenericParameter const* ctx_;
std::unique_ptr<common::Monitor> builder_monitor_; std::unique_ptr<common::Monitor> monitor_;
}; };
common::Monitor updater_monitor_;
template<typename GradientSumT>
void SetBuilder(const size_t n_trees, std::unique_ptr<Builder<GradientSumT>>*, DMatrix *dmat);
template<typename GradientSumT>
void CallBuilderUpdate(const std::unique_ptr<Builder<GradientSumT>>& builder,
HostDeviceVector<GradientPair> *gpair,
DMatrix *dmat,
GHistIndexMatrix const& gmat,
const std::vector<RegTree *> &trees);
protected: protected:
std::unique_ptr<Builder<float>> float_builder_; std::unique_ptr<Builder<float>> float_builder_;
std::unique_ptr<Builder<double>> double_builder_; std::unique_ptr<Builder<double>> double_builder_;
std::unique_ptr<TreeUpdater> pruner_;
ObjInfo task_; ObjInfo task_;
}; };
} // namespace tree } // namespace tree
@ -21,7 +21,9 @@ TEST(DenseColumn, Test) {
GHistIndexMatrix gmat{dmat.get(), max_num_bin, sparse_thresh, false, GHistIndexMatrix gmat{dmat.get(), max_num_bin, sparse_thresh, false,
common::OmpGetNumThreads(0)}; common::OmpGetNumThreads(0)};
ColumnMatrix column_matrix; ColumnMatrix column_matrix;
column_matrix.Init(gmat, 0.2, common::OmpGetNumThreads(0)); for (auto const& page : dmat->GetBatches<SparsePage>()) {
column_matrix.Init(page, gmat, sparse_thresh, common::OmpGetNumThreads(0));
}
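// ColumnMatrix::Init now requires the SparsePage alongside the GHistIndexMatrix, so the
// tests build the column matrix inside a SparsePage batch loop; with the single in-memory
// page used here the loop runs exactly once.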
for (auto i = 0ull; i < dmat->Info().num_row_; i++) { for (auto i = 0ull; i < dmat->Info().num_row_; i++) {
for (auto j = 0ull; j < dmat->Info().num_col_; j++) { for (auto j = 0ull; j < dmat->Info().num_col_; j++) {
@ -68,7 +70,9 @@ TEST(SparseColumn, Test) {
auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix(); auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix();
GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.5f, false, common::OmpGetNumThreads(0)}; GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.5f, false, common::OmpGetNumThreads(0)};
ColumnMatrix column_matrix; ColumnMatrix column_matrix;
column_matrix.Init(gmat, 0.5, common::OmpGetNumThreads(0)); for (auto const& page : dmat->GetBatches<SparsePage>()) {
column_matrix.Init(page, gmat, 1.0, common::OmpGetNumThreads(0));
}
switch (column_matrix.GetTypeSize()) { switch (column_matrix.GetTypeSize()) {
case kUint8BinsTypeSize: { case kUint8BinsTypeSize: {
auto col = column_matrix.GetColumn<uint8_t, true>(0); auto col = column_matrix.GetColumn<uint8_t, true>(0);
@ -106,9 +110,11 @@ TEST(DenseColumnWithMissing, Test) {
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2}; static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
for (int32_t max_num_bin : max_num_bins) { for (int32_t max_num_bin : max_num_bins) {
auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix(); auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix();
GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.2, false, common::OmpGetNumThreads(0)}; GHistIndexMatrix gmat(dmat.get(), max_num_bin, 0.2, false, common::OmpGetNumThreads(0));
ColumnMatrix column_matrix; ColumnMatrix column_matrix;
column_matrix.Init(gmat, 0.2, common::OmpGetNumThreads(0)); for (auto const& page : dmat->GetBatches<SparsePage>()) {
column_matrix.Init(page, gmat, 0.2, common::OmpGetNumThreads(0));
}
switch (column_matrix.GetTypeSize()) { switch (column_matrix.GetTypeSize()) {
case kUint8BinsTypeSize: { case kUint8BinsTypeSize: {
auto col = column_matrix.GetColumn<uint8_t, true>(0); auto col = column_matrix.GetColumn<uint8_t, true>(0);
@ -3,6 +3,7 @@
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "../../../src/common/column_matrix.h"
#include "../../../src/data/gradient_index.h" #include "../../../src/data/gradient_index.h"
#include "../../../src/data/sparse_page_source.h" #include "../../../src/data/sparse_page_source.h"
#include "../helpers.h" #include "../helpers.h"
@ -15,33 +16,31 @@ TEST(GHistIndexPageRawFormat, IO) {
auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix(); auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix();
dmlc::TemporaryDirectory tmpdir; dmlc::TemporaryDirectory tmpdir;
std::string path = tmpdir.path + "/ghistindex.page"; std::string path = tmpdir.path + "/ghistindex.page";
auto batch = BatchParam{256, 0.5};
{ {
std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")}; std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")};
for (auto const &index : for (auto const &index : m->GetBatches<GHistIndexMatrix>(batch)) {
m->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 256})) {
format->Write(index, fo.get()); format->Write(index, fo.get());
} }
} }
GHistIndexMatrix page; GHistIndexMatrix page;
std::unique_ptr<dmlc::SeekStream> fi{ std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(path.c_str())};
dmlc::SeekStream::CreateForRead(path.c_str())};
format->Read(&page, fi.get()); format->Read(&page, fi.get());
for (auto const &gidx : for (auto const &gidx : m->GetBatches<GHistIndexMatrix>(batch)) {
m->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 256})) {
auto const &loaded = gidx; auto const &loaded = gidx;
ASSERT_EQ(loaded.cut.Ptrs(), page.cut.Ptrs()); ASSERT_EQ(loaded.cut.Ptrs(), page.cut.Ptrs());
ASSERT_EQ(loaded.cut.MinValues(), page.cut.MinValues()); ASSERT_EQ(loaded.cut.MinValues(), page.cut.MinValues());
ASSERT_EQ(loaded.cut.Values(), page.cut.Values()); ASSERT_EQ(loaded.cut.Values(), page.cut.Values());
ASSERT_EQ(loaded.base_rowid, page.base_rowid); ASSERT_EQ(loaded.base_rowid, page.base_rowid);
ASSERT_EQ(loaded.IsDense(), page.IsDense()); ASSERT_EQ(loaded.IsDense(), page.IsDense());
ASSERT_TRUE(std::equal(loaded.index.begin(), loaded.index.end(), ASSERT_TRUE(std::equal(loaded.index.begin(), loaded.index.end(), page.index.begin()));
page.index.begin())); ASSERT_TRUE(std::equal(loaded.index.Offset(), loaded.index.Offset() + loaded.index.OffsetSize(),
ASSERT_TRUE(std::equal(loaded.index.Offset(),
loaded.index.Offset() + loaded.index.OffsetSize(),
page.index.Offset())); page.index.Offset()));
ASSERT_EQ(loaded.Transpose().GetTypeSize(), loaded.Transpose().GetTypeSize());
} }
} }
} // namespace data } // namespace data
@ -446,6 +446,12 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx) {
TEST(CPUHistogram, ExternalMemory) { TEST(CPUHistogram, ExternalMemory) {
int32_t constexpr kBins = 256; int32_t constexpr kBins = 256;
TestHistogramExternalMemory(BatchParam{kBins, common::Span<float>{}, false}, true); TestHistogramExternalMemory(BatchParam{kBins, common::Span<float>{}, false}, true);
float sparse_thresh{0.5};
TestHistogramExternalMemory({kBins, sparse_thresh}, false);
sparse_thresh = std::numeric_limits<float>::quiet_NaN();
TestHistogramExternalMemory({kBins, sparse_thresh}, false);
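// The hist path is exercised with two sparse-threshold settings: an explicit 0.5 and NaN.
// The NaN case appears intended to cover the configuration where no threshold is supplied,
// so both branches of the external-memory histogram build are tested; this reading is an
// inference from the test parameters rather than a documented contract.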
} }
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost
@ -18,138 +18,6 @@
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
class QuantileHistMock : public QuantileHistMaker {
static double constexpr kEps = 1e-6;
template <typename GradientSumT>
struct BuilderMock : public QuantileHistMaker::Builder<GradientSumT> {
using RealImpl = QuantileHistMaker::Builder<GradientSumT>;
BuilderMock(const TrainParam &param, std::unique_ptr<TreeUpdater> pruner,
DMatrix const *fmat, GenericParameter const* ctx)
: RealImpl(1, param, std::move(pruner), fmat, ObjInfo{ObjInfo::kRegression}, ctx) {}
public:
void TestInitData(const GHistIndexMatrix& gmat,
std::vector<GradientPair>* gpair,
DMatrix* p_fmat,
const RegTree& tree) {
RealImpl::InitData(gmat, p_fmat, tree, gpair);
/* The creation of HistCutMatrix and GHistIndexMatrix are not technically
* part of QuantileHist updater logic, but we include it here because
* QuantileHist updater object currently stores GHistIndexMatrix
* internally. According to https://github.com/dmlc/xgboost/pull/3803,
* we should eventually move GHistIndexMatrix out of the QuantileHist
* updater. */
const size_t num_row = p_fmat->Info().num_row_;
const size_t num_col = p_fmat->Info().num_col_;
/* Validate HistCutMatrix */
ASSERT_EQ(gmat.cut.Ptrs().size(), num_col + 1);
for (size_t fid = 0; fid < num_col; ++fid) {
const size_t ibegin = gmat.cut.Ptrs()[fid];
const size_t iend = gmat.cut.Ptrs()[fid + 1];
// Ordered, but empty feature is allowed.
ASSERT_LE(ibegin, iend);
for (size_t i = ibegin; i < iend - 1; ++i) {
// Quantile points must be sorted in ascending order
// No duplicates allowed
ASSERT_LT(gmat.cut.Values()[i], gmat.cut.Values()[i + 1])
<< "ibegin: " << ibegin << ", "
<< "iend: " << iend;
}
}
/* Validate GHistIndexMatrix */
ASSERT_EQ(gmat.row_ptr.size(), num_row + 1);
ASSERT_LT(*std::max_element(gmat.index.begin(), gmat.index.end()),
gmat.cut.Ptrs().back());
for (const auto& batch : p_fmat->GetBatches<xgboost::SparsePage>()) {
auto page = batch.GetView();
for (size_t i = 0; i < batch.Size(); ++i) {
const size_t rid = batch.base_rowid + i;
ASSERT_LT(rid, num_row);
const size_t gmat_row_offset = gmat.row_ptr[rid];
ASSERT_LT(gmat_row_offset, gmat.index.Size());
SparsePage::Inst inst = page[i];
ASSERT_EQ(gmat.row_ptr[rid] + inst.size(), gmat.row_ptr[rid + 1]);
for (size_t j = 0; j < inst.size(); ++j) {
// Each entry of GHistIndexMatrix represents a bin ID
const size_t bin_id = gmat.index[gmat_row_offset + j];
const size_t fid = inst[j].index;
// The bin ID must correspond to correct feature
ASSERT_GE(bin_id, gmat.cut.Ptrs()[fid]);
ASSERT_LT(bin_id, gmat.cut.Ptrs()[fid + 1]);
// The bin ID must correspond to a region between two
// suitable quantile points
ASSERT_LT(inst[j].fvalue, gmat.cut.Values()[bin_id]);
if (bin_id > gmat.cut.Ptrs()[fid]) {
ASSERT_GE(inst[j].fvalue, gmat.cut.Values()[bin_id - 1]);
} else {
ASSERT_GE(inst[j].fvalue, gmat.cut.MinValues()[fid]);
}
}
}
}
}
};
int static constexpr kNRows = 8, kNCols = 16;
std::shared_ptr<xgboost::DMatrix> dmat_;
GenericParameter ctx_;
const std::vector<std::pair<std::string, std::string> > cfg_;
std::shared_ptr<BuilderMock<float> > float_builder_;
std::shared_ptr<BuilderMock<double> > double_builder_;
public:
explicit QuantileHistMock(
const std::vector<std::pair<std::string, std::string> >& args,
const bool single_precision_histogram = false, bool batch = true) :
QuantileHistMaker{ObjInfo{ObjInfo::kRegression}}, cfg_{args} {
QuantileHistMaker::Configure(args);
dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
ctx_.UpdateAllowUnknown(Args{});
if (single_precision_histogram) {
float_builder_.reset(new BuilderMock<float>(param_, std::move(pruner_), dmat_.get(), &ctx_));
} else {
double_builder_.reset(
new BuilderMock<double>(param_, std::move(pruner_), dmat_.get(), &ctx_));
}
}
~QuantileHistMock() override = default;
static size_t GetNumColumns() { return kNCols; }
void TestInitData() {
int32_t constexpr kMaxBins = 4;
GHistIndexMatrix gmat{dmat_.get(), kMaxBins, 0.0f, false, common::OmpGetNumThreads(0)};
RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);
std::vector<GradientPair> gpair =
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
if (double_builder_) {
double_builder_->TestInitData(gmat, &gpair, dmat_.get(), tree);
} else {
float_builder_->TestInitData(gmat, &gpair, dmat_.get(), tree);
}
}
};
TEST(QuantileHist, InitData) {
std::vector<std::pair<std::string, std::string>> cfg
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
QuantileHistMock maker(cfg);
maker.TestInitData();
const bool single_precision_histogram = true;
QuantileHistMock maker_float(cfg, single_precision_histogram);
maker_float.TestInitData();
}
TEST(QuantileHist, Partitioner) { TEST(QuantileHist, Partitioner) {
size_t n_samples = 1024, n_features = 1, base_rowid = 0; size_t n_samples = 1024, n_features = 1, base_rowid = 0;
GenericParameter ctx; GenericParameter ctx;
@ -163,45 +31,44 @@ TEST(QuantileHist, Partitioner) {
auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}}; std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};
auto grad = GenerateRandomGradients(n_samples); auto cuts = common::SketchOnDMatrix(Xy.get(), 64, ctx.Threads());
std::vector<float> hess(grad.Size());
std::transform(grad.HostVector().cbegin(), grad.HostVector().cend(), hess.begin(),
[](auto gpair) { return gpair.GetHess(); });
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>({64, 0.5})) { for (auto const& page : Xy->GetBatches<SparsePage>()) {
GHistIndexMatrix gmat;
gmat.Init(page, {}, cuts, 64, false, 0.5, ctx.Threads());
bst_feature_t const split_ind = 0; bst_feature_t const split_ind = 0;
common::ColumnMatrix column_indices; common::ColumnMatrix column_indices;
column_indices.Init(page, 0.5, ctx.Threads()); column_indices.Init(page, gmat, 0.5, ctx.Threads());
{ {
auto min_value = page.cut.MinValues()[split_ind]; auto min_value = gmat.cut.MinValues()[split_ind];
RegTree tree; RegTree tree;
HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()}; HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()};
GetSplit(&tree, min_value, &candidates); GetSplit(&tree, min_value, &candidates);
partitioner.UpdatePosition<false, true>(&ctx, page, column_indices, candidates, &tree); partitioner.UpdatePosition<false, true>(&ctx, gmat, column_indices, candidates, &tree);
ASSERT_EQ(partitioner.Size(), 3); ASSERT_EQ(partitioner.Size(), 3);
ASSERT_EQ(partitioner[1].Size(), 0); ASSERT_EQ(partitioner[1].Size(), 0);
ASSERT_EQ(partitioner[2].Size(), n_samples); ASSERT_EQ(partitioner[2].Size(), n_samples);
} }
{ {
HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()}; HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()};
auto ptr = page.cut.Ptrs()[split_ind + 1]; auto ptr = gmat.cut.Ptrs()[split_ind + 1];
float split_value = page.cut.Values().at(ptr / 2); float split_value = gmat.cut.Values().at(ptr / 2);
RegTree tree; RegTree tree;
GetSplit(&tree, split_value, &candidates); GetSplit(&tree, split_value, &candidates);
auto left_nidx = tree[RegTree::kRoot].LeftChild(); auto left_nidx = tree[RegTree::kRoot].LeftChild();
partitioner.UpdatePosition<false, true>(&ctx, page, column_indices, candidates, &tree); partitioner.UpdatePosition<false, true>(&ctx, gmat, column_indices, candidates, &tree);
auto elem = partitioner[left_nidx]; auto elem = partitioner[left_nidx];
ASSERT_LT(elem.Size(), n_samples); ASSERT_LT(elem.Size(), n_samples);
ASSERT_GT(elem.Size(), 1); ASSERT_GT(elem.Size(), 1);
for (auto it = elem.begin; it != elem.end; ++it) { for (auto it = elem.begin; it != elem.end; ++it) {
auto value = page.cut.Values().at(page.index[*it]); auto value = gmat.cut.Values().at(gmat.index[*it]);
ASSERT_LE(value, split_value); ASSERT_LE(value, split_value);
} }
auto right_nidx = tree[RegTree::kRoot].RightChild(); auto right_nidx = tree[RegTree::kRoot].RightChild();
elem = partitioner[right_nidx]; elem = partitioner[right_nidx];
for (auto it = elem.begin; it != elem.end; ++it) { for (auto it = elem.begin; it != elem.end; ++it) {
auto value = page.cut.Values().at(page.index[*it]); auto value = gmat.cut.Values().at(gmat.index[*it]);
ASSERT_GT(value, split_value) << *it; ASSERT_GT(value, split_value) << *it;
} }
} }
@ -1,7 +1,7 @@
import xgboost as xgb import xgboost as xgb
from xgboost.data import SingleBatchInternalIter as SingleBatch from xgboost.data import SingleBatchInternalIter as SingleBatch
import numpy as np import numpy as np
from testing import IteratorForTest from testing import IteratorForTest, non_increasing
from typing import Tuple, List from typing import Tuple, List
import pytest import pytest
from hypothesis import given, strategies, settings from hypothesis import given, strategies, settings
@ -108,7 +108,7 @@ def run_data_iterator(
evals_result=results_from_it, evals_result=results_from_it,
verbose_eval=False, verbose_eval=False,
) )
it_predt = from_it.predict(Xy) assert non_increasing(results_from_it["Train"]["rmse"])
X, y = it.as_arrays() X, y = it.as_arrays()
Xy = xgb.DMatrix(X, y) Xy = xgb.DMatrix(X, y)
@ -125,13 +125,13 @@ def run_data_iterator(
verbose_eval=False, verbose_eval=False,
) )
arr_predt = from_arrays.predict(Xy) arr_predt = from_arrays.predict(Xy)
assert non_increasing(results_from_arrays["Train"]["rmse"])
if tree_method != "gpu_hist": rtol = 1e-2
rtol = 1e-1 # flaky # CPU sketching is more memory efficient but less consistent due to small chunks
else: it_predt = from_it.predict(Xy)
# Model can be sensitive to quantiles, use 1e-2 to relax the test. arr_predt = from_arrays.predict(Xy)
np.testing.assert_allclose(it_predt, arr_predt, rtol=1e-2) np.testing.assert_allclose(it_predt, arr_predt, rtol=rtol)
rtol = 1e-6
np.testing.assert_allclose( np.testing.assert_allclose(
results_from_it["Train"]["rmse"], results_from_it["Train"]["rmse"],