Prepare external memory support for hist. (#7638)

This PR prepares GHistIndexMatrix to host the column matrix used by the hist tree method by making it accept a sparse_threshold parameter.

Some cleanups are made to ensure the correct batch parameter is passed into DMatrix, and additional tests are added to check the correctness of SimpleDMatrix.
This commit is contained in:
Jiaming Yuan 2022-02-10 16:58:02 +08:00 committed by GitHub
parent 87c01f49d8
commit 2775c2a1ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 368 additions and 201 deletions

View File

@ -18,6 +18,7 @@
#include <xgboost/string_view.h>
#include <algorithm>
#include <limits>
#include <memory>
#include <numeric>
#include <string>
@ -217,24 +218,33 @@ struct BatchParam {
common::Span<float> hess;
/*! \brief Whether DMatrix should regenerate the batch. Only used for GHistIndex. */
bool regen {false};
/*! \brief Parameter used to generate column matrix for hist. */
double sparse_thresh{std::numeric_limits<double>::quiet_NaN()};
BatchParam() = default;
// GPU Hist
BatchParam(int32_t device, int32_t max_bin)
: gpu_id{device}, max_bin{max_bin} {}
// Hist
BatchParam(int32_t max_bin, double sparse_thresh)
: max_bin{max_bin}, sparse_thresh{sparse_thresh} {}
// Approx
/**
* \brief Get batch with sketch weighted by hessian. The batch will be regenerated if
* the span is changed, so caller should keep the span for each iteration.
*/
BatchParam(int32_t device, int32_t max_bin, common::Span<float> hessian,
bool regenerate = false)
: gpu_id{device}, max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
BatchParam(int32_t max_bin, common::Span<float> hessian, bool regenerate)
: max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
bool operator!=(const BatchParam& other) const {
bool operator!=(BatchParam const& other) const {
if (hess.empty() && other.hess.empty()) {
return gpu_id != other.gpu_id || max_bin != other.max_bin;
}
return gpu_id != other.gpu_id || max_bin != other.max_bin || hess.data() != other.hess.data();
}
/*! \brief Equality, defined as the exact negation of operator!=. */
bool operator==(BatchParam const& other) const {
  bool const differs = (*this != other);
  return !differs;
}
};
struct HostSparsePageView {
@ -477,8 +487,10 @@ class DMatrix {
/**
* \brief Gets batches. Use range based for loop over BatchSet to access individual batches.
*/
template<typename T>
BatchSet<T> GetBatches(const BatchParam& param = {});
template <typename T>
BatchSet<T> GetBatches();
template <typename T>
BatchSet<T> GetBatches(const BatchParam& param);
template <typename T>
bool PageExists() const;
@ -592,7 +604,7 @@ class DMatrix {
};
template<>
inline BatchSet<SparsePage> DMatrix::GetBatches(const BatchParam&) {
inline BatchSet<SparsePage> DMatrix::GetBatches() {
return GetRowBatches();
}
@ -607,12 +619,12 @@ inline bool DMatrix::PageExists<SparsePage>() const {
}
template<>
inline BatchSet<CSCPage> DMatrix::GetBatches(const BatchParam&) {
inline BatchSet<CSCPage> DMatrix::GetBatches() {
return GetColumnBatches();
}
template<>
inline BatchSet<SortedCSCPage> DMatrix::GetBatches(const BatchParam&) {
inline BatchSet<SortedCSCPage> DMatrix::GetBatches() {
return GetSortedColumnBatches();
}

View File

@ -379,12 +379,11 @@ class ColumnMatrix {
std::vector<size_t> feature_offsets_;
// index_base_[fid]: least bin id for feature fid
uint32_t* index_base_;
uint32_t const* index_base_;
std::vector<bool> missing_flags_;
BinTypeSize bins_type_size_;
bool any_missing_;
};
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_COLUMN_MATRIX_H_

View File

@ -2,13 +2,27 @@
* Copyright 2017-2022 by XGBoost Contributors
* \brief Data type for fast histogram aggregation.
*/
#include "gradient_index.h"
#include <algorithm>
#include <limits>
#include "gradient_index.h"
#include <memory>
#include "../common/column_matrix.h"
#include "../common/hist_util.h"
namespace xgboost {
// Default constructor: allocates an empty ColumnMatrix so that `columns_` is non-null.
GHistIndexMatrix::GHistIndexMatrix()
    : columns_(std::make_unique<common::ColumnMatrix>()) {}
/*!
 * \brief Build the index from a DMatrix by forwarding every argument, unchanged and in
 *        the same order, to Init().
 */
GHistIndexMatrix::GHistIndexMatrix(DMatrix *x, int32_t max_bin, double sparse_thresh,
                                   bool sorted_sketch, int32_t n_threads,
                                   common::Span<float> hess) {
  // NOTE(review): unlike the default constructor, nothing here allocates `columns_`;
  // it stays null unless Init() assigns it -- confirm this is intended.
  Init(x, max_bin, sparse_thresh, sorted_sketch, n_threads, hess);
}
// Destructor is defaulted here rather than in the class definition.
// NOTE(review): presumably so that std::unique_ptr<common::ColumnMatrix> is destroyed in a
// translation unit where ColumnMatrix is a complete type -- confirm against the header.
GHistIndexMatrix::~GHistIndexMatrix() = default;
void GHistIndexMatrix::PushBatch(SparsePage const &batch,
common::Span<FeatureType const> ft,
size_t rbegin, size_t prev_sum, uint32_t nbins,
@ -90,18 +104,15 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
[offsets](auto idx, auto j) {
return static_cast<uint8_t>(idx - offsets[j]);
});
} else if (curent_bin_size == common::kUint16BinsTypeSize) {
common::Span<uint16_t> index_data_span = {index.data<uint16_t>(),
n_index};
common::Span<uint16_t> index_data_span = {index.data<uint16_t>(), n_index};
SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
[offsets](auto idx, auto j) {
return static_cast<uint16_t>(idx - offsets[j]);
});
} else {
CHECK_EQ(curent_bin_size, common::kUint32BinsTypeSize);
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(),
n_index};
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
[offsets](auto idx, auto j) {
return static_cast<uint32_t>(idx - offsets[j]);
@ -125,8 +136,8 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
});
}
void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, bool sorted_sketch, int32_t n_threads,
common::Span<float> hess) {
void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, double sparse_thresh, bool sorted_sketch,
int32_t n_threads, common::Span<float> hess) {
// We use sorted sketching for approx tree method since it's more efficient in
// computation time (but higher memory usage).
cut = common::SketchOnDMatrix(p_fmat, max_bins, n_threads, sorted_sketch, hess);
@ -158,11 +169,9 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, bool sorted_sketch, i
}
}
void GHistIndexMatrix::Init(SparsePage const &batch,
common::Span<FeatureType const> ft,
common::HistogramCuts const &cuts,
int32_t max_bins_per_feat, bool isDense,
int32_t n_threads) {
void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType const> ft,
common::HistogramCuts const &cuts, int32_t max_bins_per_feat,
bool isDense, double sparse_thresh, int32_t n_threads) {
CHECK_GE(n_threads, 1);
base_rowid = batch.base_rowid;
isDense_ = isDense;
@ -183,13 +192,13 @@ void GHistIndexMatrix::Init(SparsePage const &batch,
this->PushBatch(batch, ft, rbegin, prev_sum, nbins, n_threads);
}
void GHistIndexMatrix::ResizeIndex(const size_t n_index,
const bool isDense) {
void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense) {
index.SetBinTypeSize(common::kUint8BinsTypeSize);
index.Resize((sizeof(uint8_t)) * n_index);
} else if ((max_num_bins - 1 > static_cast<int>(std::numeric_limits<uint8_t>::max()) &&
max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) && isDense) {
max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) &&
isDense) {
index.SetBinTypeSize(common::kUint16BinsTypeSize);
index.Resize((sizeof(uint16_t)) * n_index);
} else {

View File

@ -4,12 +4,14 @@
*/
#ifndef XGBOOST_DATA_GRADIENT_INDEX_H_
#define XGBOOST_DATA_GRADIENT_INDEX_H_
#include <memory>
#include <vector>
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "../common/categorical.h"
#include "../common/hist_util.h"
#include "../common/threading_utils.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
namespace xgboost {
/*!
@ -32,20 +34,22 @@ class GHistIndexMatrix {
/*! \brief The corresponding cuts */
common::HistogramCuts cut;
DMatrix* p_fmat;
/*! \brief max_bin for each feature. */
size_t max_num_bins;
/*! \brief base row index for current page (used by external memory) */
size_t base_rowid{0};
GHistIndexMatrix() = default;
GHistIndexMatrix(DMatrix* x, int32_t max_bin, bool sorted_sketch, int32_t n_threads,
common::Span<float> hess = {}) {
this->Init(x, max_bin, sorted_sketch, n_threads, hess);
}
GHistIndexMatrix();
GHistIndexMatrix(DMatrix* x, int32_t max_bin, double sparse_thresh, bool sorted_sketch,
int32_t n_threads, common::Span<float> hess = {});
~GHistIndexMatrix();
// Create a global histogram matrix, given cut
void Init(DMatrix* p_fmat, int max_num_bins, bool sorted_sketch, int32_t n_threads,
common::Span<float> hess);
void Init(DMatrix* p_fmat, int max_bins, double sparse_thresh, bool sorted_sketch,
int32_t n_threads, common::Span<float> hess);
void Init(SparsePage const& page, common::Span<FeatureType const> ft,
common::HistogramCuts const& cuts, int32_t max_bins_per_feat, bool is_dense,
int32_t n_threads);
double sparse_thresh, int32_t n_threads);
// specific method for sparse data as no possibility to reduce allocated memory
template <typename BinIdxType, typename GetOffset>
@ -74,7 +78,7 @@ class GHistIndexMatrix {
index_data[ibegin + j] = get_offset(bin_idx, j);
++hit_count_tloc_[tid * nbins + bin_idx];
} else {
uint32_t idx = cut.SearchBin(inst[j].fvalue, inst[j].index, ptrs, values);
uint32_t idx = cut.SearchBin(e.fvalue, e.index, ptrs, values);
index_data[ibegin + j] = get_offset(idx, j);
++hit_count_tloc_[tid * nbins + idx];
}
@ -82,10 +86,9 @@ class GHistIndexMatrix {
});
}
void ResizeIndex(const size_t n_index,
const bool isDense);
void ResizeIndex(const size_t n_index, const bool isDense);
inline void GetFeatureCounts(size_t* counts) const {
void GetFeatureCounts(size_t* counts) const {
auto nfeature = cut.Ptrs().size() - 1;
for (unsigned fid = 0; fid < nfeature; ++fid) {
auto ibegin = cut.Ptrs()[fid];
@ -95,7 +98,8 @@ class GHistIndexMatrix {
}
}
}
inline bool IsDense() const {
bool IsDense() const {
return isDense_;
}
void SetDense(bool is_dense) { isDense_ = is_dense; }
@ -105,6 +109,8 @@ class GHistIndexMatrix {
}
private:
// unused at the moment: https://github.com/dmlc/xgboost/pull/7531
std::unique_ptr<common::ColumnMatrix> columns_;
std::vector<size_t> hit_count_tloc_;
bool isDense_;
};
@ -117,7 +123,19 @@ class GHistIndexMatrix {
*/
inline bool RegenGHist(BatchParam old, BatchParam p) {
// parameter is renewed or caller requests a regen
return p.regen || (old.gpu_id != p.gpu_id || old.max_bin != p.max_bin);
if (p == BatchParam{}) {
// empty parameter is passed in, don't regenerate so that we can use gindex in
// predictor, which doesn't have any training parameter.
return false;
}
// Avoid comparing nan values.
bool l_nan = std::isnan(old.sparse_thresh);
bool r_nan = std::isnan(p.sparse_thresh);
// regenerate if parameter is changed.
bool st_chg = (l_nan != r_nan) || (!l_nan && !r_nan && (old.sparse_thresh != p.sparse_thresh));
bool param_chg = old.gpu_id != p.gpu_id || old.max_bin != p.max_bin;
return p.regen || param_chg || st_chg;
}
} // namespace xgboost
#endif // XGBOOST_DATA_GRADIENT_INDEX_H_

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2021 by XGBoost Contributors
* Copyright 2021-2022 by XGBoost Contributors
*/
#include "gradient_index_page_source.h"
@ -11,7 +11,7 @@ void GradientIndexPageSource::Fetch() {
this->page_.reset(new GHistIndexMatrix());
CHECK_NE(cuts_.Values().size(), 0);
this->page_->Init(*csr, feature_types_, cuts_, max_bin_per_feat_, is_dense_,
nthreads_);
sparse_thresh_, nthreads_);
this->WriteCache();
}
}

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2021 by XGBoost Contributors
* Copyright 2021-2022 by XGBoost Contributors
*/
#ifndef XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_
#define XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_
@ -7,8 +7,8 @@
#include <memory>
#include <utility>
#include "sparse_page_source.h"
#include "gradient_index.h"
#include "sparse_page_source.h"
namespace xgboost {
namespace data {
@ -17,17 +17,20 @@ class GradientIndexPageSource : public PageSourceIncMixIn<GHistIndexMatrix> {
bool is_dense_;
int32_t max_bin_per_feat_;
common::Span<FeatureType const> feature_types_;
double sparse_thresh_;
public:
GradientIndexPageSource(float missing, int nthreads, bst_feature_t n_features,
size_t n_batches, std::shared_ptr<Cache> cache,
BatchParam param, common::HistogramCuts cuts,
bool is_dense, int32_t max_bin_per_feat,
GradientIndexPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches,
std::shared_ptr<Cache> cache, BatchParam param,
common::HistogramCuts cuts, bool is_dense, int32_t max_bin_per_feat,
common::Span<FeatureType const> feature_types,
std::shared_ptr<SparsePageSource> source)
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache),
cuts_{std::move(cuts)}, is_dense_{is_dense},
max_bin_per_feat_{max_bin_per_feat}, feature_types_{feature_types} {
cuts_{std::move(cuts)},
is_dense_{is_dense},
max_bin_per_feat_{max_bin_per_feat},
feature_types_{feature_types},
sparse_thresh_{param.sparse_thresh} {
this->source_ = source;
this->Fetch();
}

View File

@ -74,12 +74,18 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
return BatchSet<SortedCSCPage>(begin_iter);
}
namespace {
/*!
 * \brief Sanity check on a pair of batch parameters.
 *
 * When `l` compares equal to a default-constructed BatchParam, `r` must not also compare
 * equal to one; otherwise the CHECK fails with "Batch parameter is not initialized.".
 * Nothing is checked when `l` differs from the default.
 */
void CheckEmpty(BatchParam const& l, BatchParam const& r) {
  bool const l_is_default = (l == BatchParam{});
  if (!l_is_default) {
    return;
  }
  CHECK(r != BatchParam{}) << "Batch parameter is not initialized.";
}
} // anonymous namespace
BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param) {
// ELLPACK page doesn't exist, generate it
if (!(batch_param_ != BatchParam{})) {
CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
}
if (!ellpack_page_ || (batch_param_ != param && param != BatchParam{})) {
CheckEmpty(batch_param_, param);
if (!ellpack_page_ || RegenGHist(batch_param_, param)) {
CHECK_GE(param.gpu_id, 0);
CHECK_GE(param.max_bin, 2);
ellpack_page_.reset(new EllpackPage(this, param));
@ -91,17 +97,15 @@ BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param)
}
BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& param) {
if (!(batch_param_ != BatchParam{})) {
CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
}
CheckEmpty(batch_param_, param);
if (!gradient_index_ || RegenGHist(batch_param_, param)) {
LOG(INFO) << "Generating new Gradient Index.";
CHECK_GE(param.max_bin, 2);
CHECK_EQ(param.gpu_id, -1);
// Used only by approx.
auto sorted_sketch = param.regen;
gradient_index_.reset(
new GHistIndexMatrix(this, param.max_bin, sorted_sketch, this->ctx_.Threads(), param.hess));
gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin, param.sparse_thresh,
sorted_sketch, this->ctx_.Threads(), param.hess));
batch_param_ = param;
CHECK_EQ(batch_param_.hess.data(), param.hess.data());
}

View File

@ -39,7 +39,7 @@ class SimpleDMatrix : public DMatrix {
/*! \brief magic number used to identify SimpleDMatrix binary files */
static const int kMagic = 0xffffab01;
private:
protected:
BatchSet<SparsePage> GetRowBatches() override;
BatchSet<CSCPage> GetColumnBatches() override;
BatchSet<SortedCSCPage> GetSortedColumnBatches() override;

View File

@ -164,8 +164,8 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam
// all index here.
if (!ghist_index_page_ || (param != batch_param_ && param != BatchParam{})) {
this->InitializeSparsePage();
ghist_index_page_.reset(
new GHistIndexMatrix{this, param.max_bin, param.regen, ctx_.Threads()});
ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin, param.sparse_thresh,
param.regen, ctx_.Threads()});
this->InitializeSparsePage();
batch_param_ = param;
}

View File

@ -708,7 +708,7 @@ class GPUPredictor : public xgboost::Predictor {
}
} else {
size_t batch_offset = 0;
for (auto const& page : dmat->GetBatches<EllpackPage>()) {
for (auto const& page : dmat->GetBatches<EllpackPage>(BatchParam{})) {
dmat->Info().feature_types.SetDevice(ctx_->gpu_id);
auto feature_types = dmat->Info().feature_types.ConstDeviceSpan();
this->PredictInternal(
@ -984,7 +984,7 @@ class GPUPredictor : public xgboost::Predictor {
batch_offset += batch.Size();
}
} else {
for (auto const& batch : p_fmat->GetBatches<EllpackPage>()) {
for (auto const& batch : p_fmat->GetBatches<EllpackPage>(BatchParam{})) {
bst_row_t batch_offset = 0;
EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_->gpu_id)};
size_t num_rows = batch.Size();

View File

@ -31,11 +31,11 @@ namespace {
template <typename GradientSumT>
auto BatchSpec(TrainParam const &p, common::Span<float> hess,
HistEvaluator<GradientSumT, CPUExpandEntry> const &evaluator) {
return BatchParam{GenericParameter::kCpuId, p.max_bin, hess, !evaluator.Task().const_hess};
return BatchParam{p.max_bin, hess, !evaluator.Task().const_hess};
}
auto BatchSpec(TrainParam const &p, common::Span<float> hess) {
return BatchParam{GenericParameter::kCpuId, p.max_bin, hess, false};
return BatchParam{p.max_bin, hess, false};
}
} // anonymous namespace

View File

@ -68,9 +68,7 @@ void QuantileHistMaker::CallBuilderUpdate(const std::unique_ptr<Builder<Gradient
void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
DMatrix *dmat,
const std::vector<RegTree *> &trees) {
auto it = dmat->GetBatches<GHistIndexMatrix>(
BatchParam{GenericParameter::kCpuId, param_.max_bin})
.begin();
auto it = dmat->GetBatches<GHistIndexMatrix>(HistBatch(param_)).begin();
auto p_gmat = it.Page();
if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) {
updater_monitor_.Start("GmatInitialization");
@ -127,8 +125,8 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
nodes_for_explicit_hist_build_.push_back(node);
size_t page_id = 0;
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(
{GenericParameter::kCpuId, param_.max_bin})) {
for (auto const& gidx :
p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
this->histogram_builder_->BuildHist(
page_id, gidx, p_tree, row_set_collection_,
nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, gpair_h);
@ -141,10 +139,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
GradientPairT grad_stat;
if (data_layout_ == DataLayout::kDenseDataZeroBased ||
data_layout_ == DataLayout::kDenseDataOneBased) {
auto const &gmat = *(p_fmat
->GetBatches<GHistIndexMatrix>(BatchParam{
GenericParameter::kCpuId, param_.max_bin})
.begin());
auto const& gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_)).begin());
const std::vector<uint32_t> &row_ptr = gmat.cut.Ptrs();
const uint32_t ibegin = row_ptr[fid_least_bins_];
const uint32_t iend = row_ptr[fid_least_bins_ + 1];
@ -170,8 +165,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
std::vector<CPUExpandEntry> entries{node};
builder_monitor_.Start("EvaluateSplits");
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(
BatchParam{GenericParameter::kCpuId, param_.max_bin})) {
for (auto const& gmat : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft,
*p_tree, &entries);
break;
@ -264,8 +258,7 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
if (param_.max_depth == 0 || depth < param_.max_depth) {
size_t i = 0;
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(
{GenericParameter::kCpuId, param_.max_bin})) {
for (auto const& gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
this->histogram_builder_->BuildHist(
i, gidx, p_tree, row_set_collection_,
nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_,

View File

@ -92,6 +92,10 @@ using xgboost::common::GHistBuilder;
using xgboost::common::ColumnMatrix;
using xgboost::common::Column;
/*! \brief Build a BatchParam from the `max_bin` and `sparse_threshold` fields of `param`. */
inline BatchParam HistBatch(TrainParam const& param) {
  return BatchParam{param.max_bin, param.sparse_threshold};
}
/*! \brief construct a tree using quantized feature values */
class QuantileHistMaker: public TreeUpdater {
public:

View File

@ -12,12 +12,14 @@ namespace xgboost {
namespace common {
TEST(DenseColumn, Test) {
uint64_t max_num_bins[] = {static_cast<uint64_t>(std::numeric_limits<uint8_t>::max()) + 1,
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 1,
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2};
for (size_t max_num_bin : max_num_bins) {
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
for (int32_t max_num_bin : max_num_bins) {
auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix();
GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0));
auto sparse_thresh = 0.2;
GHistIndexMatrix gmat{dmat.get(), max_num_bin, sparse_thresh, false,
common::OmpGetNumThreads(0)};
ColumnMatrix column_matrix;
column_matrix.Init(gmat, 0.2, common::OmpGetNumThreads(0));
@ -59,12 +61,12 @@ inline void CheckSparseColumn(const Column<BinIdxType>& col_input, const GHistIn
}
TEST(SparseColumn, Test) {
uint64_t max_num_bins[] = {static_cast<uint64_t>(std::numeric_limits<uint8_t>::max()) + 1,
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 1,
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2};
for (size_t max_num_bin : max_num_bins) {
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
for (int32_t max_num_bin : max_num_bins) {
auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix();
GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0));
GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.5f, false, common::OmpGetNumThreads(0)};
ColumnMatrix column_matrix;
column_matrix.Init(gmat, 0.5, common::OmpGetNumThreads(0));
switch (column_matrix.GetTypeSize()) {
@ -99,12 +101,12 @@ inline void CheckColumWithMissingValue(const Column<BinIdxType>& col_input,
}
TEST(DenseColumnWithMissing, Test) {
uint64_t max_num_bins[] = { static_cast<uint64_t>(std::numeric_limits<uint8_t>::max()) + 1,
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 1,
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2 };
for (size_t max_num_bin : max_num_bins) {
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
for (int32_t max_num_bin : max_num_bins) {
auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix();
GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0));
GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.2, false, common::OmpGetNumThreads(0)};
ColumnMatrix column_matrix;
column_matrix.Init(gmat, 0.2, common::OmpGetNumThreads(0));
switch (column_matrix.GetTypeSize()) {
@ -131,9 +133,8 @@ void TestGHistIndexMatrixCreation(size_t nthreads) {
size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
/* This should create multiple sparse pages */
std::unique_ptr<DMatrix> dmat{ CreateSparsePageDMatrix(kEntries) };
omp_set_num_threads(nthreads);
GHistIndexMatrix gmat(dmat.get(), 256, false, common::OmpGetNumThreads(0));
std::unique_ptr<DMatrix> dmat{CreateSparsePageDMatrix(kEntries)};
GHistIndexMatrix gmat(dmat.get(), 256, 0.5f, false, common::OmpGetNumThreads(nthreads));
}
TEST(HistIndexCreationWithExternalMemory, Test) {

View File

@ -299,7 +299,7 @@ TEST(HistUtil, IndexBinBound) {
for (auto max_bin : bin_sizes) {
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
GHistIndexMatrix hmat(p_fmat.get(), max_bin, false, common::OmpGetNumThreads(0));
GHistIndexMatrix hmat(p_fmat.get(), max_bin, 0.5, false, common::OmpGetNumThreads(0));
EXPECT_EQ(hmat.index.Size(), kRows*kCols);
EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.GetBinTypeSize());
}
@ -322,7 +322,7 @@ TEST(HistUtil, IndexBinData) {
for (auto max_bin : kBinSizes) {
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
GHistIndexMatrix hmat(p_fmat.get(), max_bin, false, common::OmpGetNumThreads(0));
GHistIndexMatrix hmat(p_fmat.get(), max_bin, 0.5, false, common::OmpGetNumThreads(0));
uint32_t* offsets = hmat.index.Offset();
EXPECT_EQ(hmat.index.Size(), kRows*kCols);
switch (max_bin) {

View File

@ -81,13 +81,13 @@ TEST(EllpackPage, BuildGidxSparse) {
TEST(EllpackPage, FromCategoricalBasic) {
using common::AsCat;
size_t constexpr kRows = 1000, kCats = 13, kCols = 1;
size_t max_bins = 8;
int32_t max_bins = 8;
auto x = GenerateRandomCategoricalSingleColumn(kRows, kCats);
auto m = GetDMatrixFromData(x, kRows, 1);
auto& h_ft = m->Info().feature_types.HostVector();
h_ft.resize(kCols, FeatureType::kCategorical);
BatchParam p(0, max_bins);
BatchParam p{0, max_bins};
auto ellpack = EllpackPage(m.get(), p);
auto accessor = ellpack.Impl()->GetDeviceAccessor(0);
ASSERT_EQ(kCats, accessor.NumBins());

View File

@ -4,8 +4,8 @@
#include <gtest/gtest.h>
#include <xgboost/data.h>
#include "../helpers.h"
#include "../../../src/data/gradient_index.h"
#include "../helpers.h"
namespace xgboost {
namespace data {
@ -13,12 +13,22 @@ TEST(GradientIndex, ExternalMemory) {
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(10000);
std::vector<size_t> base_rowids;
std::vector<float> hessian(dmat->Info().num_row_, 1);
for (auto const &page : dmat->GetBatches<GHistIndexMatrix>(
{GenericParameter::kCpuId, 64, hessian})) {
for (auto const &page : dmat->GetBatches<GHistIndexMatrix>({64, hessian, true})) {
base_rowids.push_back(page.base_rowid);
}
size_t i = 0;
for (auto const& page : dmat->GetBatches<SparsePage>()) {
for (auto const &page : dmat->GetBatches<SparsePage>()) {
ASSERT_EQ(base_rowids[i], page.base_rowid);
++i;
}
base_rowids.clear();
for (auto const &page : dmat->GetBatches<GHistIndexMatrix>({64, hessian, false})) {
base_rowids.push_back(page.base_rowid);
}
i = 0;
for (auto const &page : dmat->GetBatches<SparsePage>()) {
ASSERT_EQ(base_rowids[i], page.base_rowid);
++i;
}
@ -33,10 +43,10 @@ TEST(GradientIndex, FromCategoricalBasic) {
auto &h_ft = m->Info().feature_types.HostVector();
h_ft.resize(kCols, FeatureType::kCategorical);
BatchParam p(0, max_bins);
BatchParam p(max_bins, 0.8);
GHistIndexMatrix gidx;
gidx.Init(m.get(), max_bins, false, common::OmpGetNumThreads(0), {});
gidx.Init(m.get(), max_bins, p.sparse_thresh, false, common::OmpGetNumThreads(0), {});
auto x_copy = x;
std::sort(x_copy.begin(), x_copy.end());

View File

@ -21,7 +21,7 @@ void TestEquivalent(float sparsity) {
std::unique_ptr<EllpackPageImpl> page_concatenated {
new EllpackPageImpl(0, first->Cuts(), first->is_dense,
first->row_stride, 1000 * 100)};
for (auto& batch : m.GetBatches<EllpackPage>()) {
for (auto& batch : m.GetBatches<EllpackPage>({})) {
auto page = batch.Impl();
size_t num_elements = page_concatenated->Copy(0, page, offset);
offset += num_elements;
@ -93,7 +93,7 @@ TEST(IterativeDeviceDMatrix, RowMajor) {
0, 256);
size_t n_batches = 0;
std::string interface_str = iter.AsArray();
for (auto& ellpack : m.GetBatches<EllpackPage>()) {
for (auto& ellpack : m.GetBatches<EllpackPage>({})) {
n_batches ++;
auto impl = ellpack.Impl();
common::CompressedIterator<uint32_t> iterator(

View File

@ -68,6 +68,7 @@ TEST(GPUPredictor, EllpackBasic) {
.Bins(bins)
.Device(0)
.GenerateDeviceDMatrix(true);
ASSERT_FALSE(p_m->PageExists<SparsePage>());
TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", rows, kCols, p_m);
TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", bins, kCols, p_m);
}

View File

@ -31,8 +31,7 @@ template <typename GradientSumT> void TestEvaluateSplits() {
size_t constexpr kMaxBins = 4;
// dense, no missing values
GHistIndexMatrix gmat(dmat.get(), kMaxBins, false, common::OmpGetNumThreads(0));
GHistIndexMatrix gmat(dmat.get(), kMaxBins, 0.5, false, common::OmpGetNumThreads(0));
common::RowSetCollection row_set_collection;
std::vector<size_t> &row_indices = *row_set_collection.Data();
row_indices.resize(kRows);
@ -127,7 +126,7 @@ TEST(HistEvaluator, CategoricalPartition) {
auto evaluator = HistEvaluator<GradientSumT, CPUExpandEntry>{
param, dmat->Info(), n_threads, sampler, ObjInfo{ObjInfo::kRegression}};
for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 32})) {
for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({32, param.sparse_threshold})) {
common::HistCollection<GradientSumT> hist;
std::vector<CPUExpandEntry> entries(1);
@ -212,7 +211,7 @@ auto CompareOneHotAndPartition(bool onehot) {
param, dmat->Info(), n_threads, sampler, ObjInfo{ObjInfo::kRegression}};
std::vector<CPUExpandEntry> entries(1);
for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 32})) {
for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({32, param.sparse_threshold})) {
common::HistCollection<GradientSumT> hist;
entries.front().nid = 0;

View File

@ -1,20 +1,20 @@
/*!
* Copyright 2018-2021 by Contributors
* Copyright 2018-2022 by Contributors
*/
#include <gtest/gtest.h>
#include "../../helpers.h"
#include "../../categorical_helpers.h"
#include <limits>
#include "../../../../src/common/categorical.h"
#include "../../../../src/tree/hist/histogram.h"
#include "../../../../src/tree/updater_quantile_hist.h"
#include "../../categorical_helpers.h"
#include "../../helpers.h"
namespace xgboost {
namespace tree {
namespace {
void InitRowPartitionForTest(RowSetCollection *row_set, size_t n_samples,
size_t base_rowid = 0) {
void InitRowPartitionForTest(RowSetCollection *row_set, size_t n_samples, size_t base_rowid = 0) {
auto &row_indices = *row_set->Data();
row_indices.resize(n_samples);
std::iota(row_indices.begin(), row_indices.end(), base_rowid);
@ -33,10 +33,7 @@ void TestAddHistRows(bool is_distributed) {
int32_t constexpr kMaxBins = 4;
auto p_fmat =
RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
auto const &gmat = *(p_fmat
->GetBatches<GHistIndexMatrix>(
BatchParam{GenericParameter::kCpuId, kMaxBins})
.begin());
auto const &gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(BatchParam{kMaxBins, 0.5}).begin());
RegTree tree;
@ -49,9 +46,8 @@ void TestAddHistRows(bool is_distributed) {
nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6), 0.0f);
HistogramBuilder<GradientSumT, CPUExpandEntry> histogram_builder;
histogram_builder.Reset(gmat.cut.TotalBins(),
{GenericParameter::kCpuId, kMaxBins},
omp_get_max_threads(), 1, is_distributed);
histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1,
is_distributed);
histogram_builder.AddHistRows(&starting_index, &sync_count,
nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_, &tree);
@ -89,15 +85,11 @@ void TestSyncHist(bool is_distributed) {
auto p_fmat =
RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
auto const &gmat = *(p_fmat
->GetBatches<GHistIndexMatrix>(
BatchParam{GenericParameter::kCpuId, kMaxBins})
.begin());
auto const &gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(BatchParam{kMaxBins, 0.5}).begin());
HistogramBuilder<GradientSumT, CPUExpandEntry> histogram;
uint32_t total_bins = gmat.cut.Ptrs().back();
histogram.Reset(total_bins, {GenericParameter::kCpuId, kMaxBins},
omp_get_max_threads(), 1, is_distributed);
histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed);
RowSetCollection row_set_collection_;
{
@ -250,10 +242,7 @@ void TestBuildHistogram(bool is_distributed) {
int32_t constexpr kMaxBins = 4;
auto p_fmat =
RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
auto const &gmat = *(p_fmat
->GetBatches<GHistIndexMatrix>(
BatchParam{GenericParameter::kCpuId, kMaxBins})
.begin());
auto const &gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(BatchParam{kMaxBins, 0.5}).begin());
uint32_t total_bins = gmat.cut.Ptrs().back();
static double constexpr kEps = 1e-6;
@ -263,8 +252,7 @@ void TestBuildHistogram(bool is_distributed) {
bst_node_t nid = 0;
HistogramBuilder<GradientSumT, CPUExpandEntry> histogram;
histogram.Reset(total_bins, {GenericParameter::kCpuId, kMaxBins},
omp_get_max_threads(), 1, is_distributed);
histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed);
RegTree tree;
@ -278,8 +266,7 @@ void TestBuildHistogram(bool is_distributed) {
CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f);
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build;
nodes_for_explicit_hist_build.push_back(node);
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(
{GenericParameter::kCpuId, kMaxBins})) {
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>({kMaxBins, 0.5})) {
histogram.BuildHist(0, gidx, &tree, row_set_collection,
nodes_for_explicit_hist_build, {}, gpair);
}
@ -342,11 +329,9 @@ void TestHistogramCategorical(size_t n_categories) {
* Generate hist with cat data.
*/
HistogramBuilder<double, CPUExpandEntry> cat_hist;
for (auto const &gidx : cat_m->GetBatches<GHistIndexMatrix>(
{GenericParameter::kCpuId, kBins})) {
for (auto const &gidx : cat_m->GetBatches<GHistIndexMatrix>({kBins, 0.5})) {
auto total_bins = gidx.cut.TotalBins();
cat_hist.Reset(total_bins, {GenericParameter::kCpuId, kBins},
omp_get_max_threads(), 1, false);
cat_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false);
cat_hist.BuildHist(0, gidx, &tree, row_set_collection,
nodes_for_explicit_hist_build, {}, gpair.HostVector());
}
@ -357,13 +342,10 @@ void TestHistogramCategorical(size_t n_categories) {
auto x_encoded = OneHotEncodeFeature(x, n_categories);
auto encode_m = GetDMatrixFromData(x_encoded, kRows, n_categories);
HistogramBuilder<double, CPUExpandEntry> onehot_hist;
for (auto const &gidx : encode_m->GetBatches<GHistIndexMatrix>(
{GenericParameter::kCpuId, kBins})) {
for (auto const &gidx : encode_m->GetBatches<GHistIndexMatrix>({kBins, 0.5})) {
auto total_bins = gidx.cut.TotalBins();
onehot_hist.Reset(total_bins, {GenericParameter::kCpuId, kBins},
omp_get_max_threads(), 1, false);
onehot_hist.BuildHist(0, gidx, &tree, row_set_collection,
nodes_for_explicit_hist_build, {},
onehot_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false);
onehot_hist.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {},
gpair.HostVector());
}
@ -378,11 +360,16 @@ TEST(CPUHistogram, Categorical) {
TestHistogramCategorical(n_categories);
}
}
TEST(CPUHistogram, ExternalMemory) {
namespace {
void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx) {
size_t constexpr kEntries = 1 << 16;
int32_t constexpr kBins = 32;
auto m = CreateSparsePageDMatrix(kEntries, "cache");
std::vector<float> hess(m->Info().num_row_, 1.0);
if (is_approx) {
batch_param.hess = hess;
}
std::vector<size_t> partition_size(1, 0);
size_t total_bins{0};
size_t n_samples{0};
@ -401,9 +388,7 @@ TEST(CPUHistogram, ExternalMemory) {
* Multi page
*/
std::vector<RowSetCollection> rows_set;
std::vector<float> hess(m->Info().num_row_, 1.0);
for (auto const &page : m->GetBatches<GHistIndexMatrix>(
{GenericParameter::kCpuId, kBins, hess})) {
for (auto const &page : m->GetBatches<GHistIndexMatrix>(batch_param)) {
CHECK_LT(page.base_rowid, m->Info().num_row_);
auto n_rows_in_node = page.Size();
partition_size[0] = std::max(partition_size[0], n_rows_in_node);
@ -419,12 +404,10 @@ TEST(CPUHistogram, ExternalMemory) {
1, [&](size_t nidx_in_set) { return partition_size.at(nidx_in_set); },
256};
multi_build.Reset(total_bins, {GenericParameter::kCpuId, kBins},
omp_get_max_threads(), rows_set.size(), false);
multi_build.Reset(total_bins, batch_param, common::OmpGetNumThreads(0), rows_set.size(), false);
size_t page_idx{0};
for (auto const &page : m->GetBatches<GHistIndexMatrix>(
{GenericParameter::kCpuId, kBins, hess})) {
for (auto const &page : m->GetBatches<GHistIndexMatrix>(batch_param)) {
multi_build.BuildHist(page_idx, space, page, &tree, rows_set.at(page_idx), nodes, {},
h_gpair);
++page_idx;
@ -442,16 +425,13 @@ TEST(CPUHistogram, ExternalMemory) {
RowSetCollection row_set_collection;
InitRowPartitionForTest(&row_set_collection, n_samples);
single_build.Reset(total_bins, {GenericParameter::kCpuId, kBins},
omp_get_max_threads(), 1, false);
size_t n_batches{0};
for (auto const &page :
m->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, kBins})) {
single_build.BuildHist(0, page, &tree, row_set_collection, nodes, {},
h_gpair);
n_batches ++;
}
ASSERT_EQ(n_batches, 1);
single_build.Reset(total_bins, batch_param, common::OmpGetNumThreads(0), 1, false);
SparsePage concat;
GHistIndexMatrix gmat;
std::vector<float> hess(m->Info().num_row_, 1.0f);
gmat.Init(m.get(), batch_param.max_bin, std::numeric_limits<double>::quiet_NaN(), false,
common::OmpGetNumThreads(0), hess);
single_build.BuildHist(0, gmat, &tree, row_set_collection, nodes, {}, h_gpair);
single_page = single_build.Histogram()[0];
}
@ -460,5 +440,11 @@ TEST(CPUHistogram, ExternalMemory) {
ASSERT_NEAR(single_page[i].GetHess(), multi_page[i].GetHess(), kRtEps);
}
}
} // anonymous namespace
TEST(CPUHistogram, ExternalMemory) {
  int32_t constexpr kBins = 256;
  // The hessian span is left empty and regeneration is disabled here; the helper is asked to
  // run in approx mode through its second argument.
  BatchParam const approx_param{kBins, common::Span<float>{}, false};
  TestHistogramExternalMemory(approx_param, /*is_approx=*/true);
}
} // namespace tree
} // namespace xgboost

View File

@ -8,6 +8,19 @@
namespace xgboost {
namespace tree {
namespace {
/**
 * \brief Split the root of a tree on feature 0 and mirror that split into the first candidate.
 *
 * \param tree        Tree whose root node is expanded.
 * \param split_value Threshold used for the root split.
 * \param candidates  Expansion entries; only the first one is updated. Must not be empty,
 *                    since `front()` is accessed unconditionally.
 */
void GetSplit(RegTree *tree, float split_value, std::vector<CPUExpandEntry> *candidates) {
  tree->ExpandNode(/*nid=*/RegTree::kRoot,
                   /*split_index=*/0,
                   /*split_value=*/split_value,
                   /*default_left=*/true,
                   0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
                   /*left_sum=*/0.0f,
                   /*right_sum=*/0.0f);

  auto &root_split = candidates->front().split;
  root_split.split_value = split_value;
  // Feature index 0 with the highest bit set.
  // NOTE(review): the high bit presumably encodes "default left", matching the
  // default_left=true passed to ExpandNode above -- confirm against the split entry type.
  root_split.sindex = (1U << 31);
}
} // anonymous namespace
TEST(Approx, Partitioner) {
size_t n_samples = 1024, n_features = 1, base_rowid = 0;
ApproxRowPartitioner partitioner{n_samples, base_rowid};
@ -20,20 +33,18 @@ TEST(Approx, Partitioner) {
ctx.InitAllowUnknown(Args{});
std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};
for (auto const &page : Xy->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 64})) {
bst_feature_t split_ind = 0;
auto grad = GenerateRandomGradients(n_samples);
std::vector<float> hess(grad.Size());
std::transform(grad.HostVector().cbegin(), grad.HostVector().cend(), hess.begin(),
[](auto gpair) { return gpair.GetHess(); });
for (auto const &page : Xy->GetBatches<GHistIndexMatrix>({64, hess, true})) {
bst_feature_t const split_ind = 0;
{
auto min_value = page.cut.MinValues()[split_ind];
RegTree tree;
tree.ExpandNode(
/*nid=*/0, /*split_index=*/0, /*split_value=*/min_value,
/*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f,
/*right_sum=*/0.0f);
ApproxRowPartitioner partitioner{n_samples, base_rowid};
candidates.front().split.split_value = min_value;
candidates.front().split.sindex = 0;
candidates.front().split.sindex |= (1U << 31);
GetSplit(&tree, min_value, &candidates);
partitioner.UpdatePosition(&ctx, page, candidates, &tree);
ASSERT_EQ(partitioner.Size(), 3);
ASSERT_EQ(partitioner[1].Size(), 0);
@ -44,16 +55,8 @@ TEST(Approx, Partitioner) {
auto ptr = page.cut.Ptrs()[split_ind + 1];
float split_value = page.cut.Values().at(ptr / 2);
RegTree tree;
tree.ExpandNode(
/*nid=*/RegTree::kRoot, /*split_index=*/split_ind,
/*split_value=*/split_value,
/*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f,
/*right_sum=*/0.0f);
GetSplit(&tree, split_value, &candidates);
auto left_nidx = tree[RegTree::kRoot].LeftChild();
candidates.front().split.split_value = split_value;
candidates.front().split.sindex = 0;
candidates.front().split.sindex |= (1U << 31);
partitioner.UpdatePosition(&ctx, page, candidates, &tree);
auto elem = partitioner[left_nidx];

View File

@ -155,18 +155,19 @@ class QuantileHistMock : public QuantileHistMaker {
std::vector<GradientPair> row_gpairs =
{ {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
{0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f} };
size_t constexpr kMaxBins = 4;
int32_t constexpr kMaxBins = 4;
// try out different sparsity to get different number of missing values
for (double sparsity : {0.0, 0.1, 0.2}) {
// kNRows samples with kNCols features
auto dmat = RandomDataGenerator(kNRows, kNCols, sparsity).Seed(3).GenerateDMatrix();
GHistIndexMatrix gmat(dmat.get(), kMaxBins, false, common::OmpGetNumThreads(0));
float sparse_th = 0.0;
GHistIndexMatrix gmat{dmat.get(), kMaxBins, sparse_th, false, common::OmpGetNumThreads(0)};
ColumnMatrix cm;
// treat everything as dense, as this is what we intend to test here
cm.Init(gmat, 0.0, common::OmpGetNumThreads(0));
cm.Init(gmat, sparse_th, common::OmpGetNumThreads(0));
RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
const size_t num_row = dmat->Info().num_row_;
// split by feature 0
@ -247,8 +248,8 @@ class QuantileHistMock : public QuantileHistMaker {
static size_t GetNumColumns() { return kNCols; }
void TestInitData() {
size_t constexpr kMaxBins = 4;
GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false, common::OmpGetNumThreads(0));
int32_t constexpr kMaxBins = 4;
GHistIndexMatrix gmat{dmat_.get(), kMaxBins, 0.0f, false, common::OmpGetNumThreads(0)};
RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);
@ -264,8 +265,8 @@ class QuantileHistMock : public QuantileHistMaker {
}
void TestInitDataSampling() {
size_t constexpr kMaxBins = 4;
GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false, common::OmpGetNumThreads(0));
int32_t constexpr kMaxBins = 4;
GHistIndexMatrix gmat{dmat_.get(), kMaxBins, 0.0f, false, common::OmpGetNumThreads(0)};
RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);

View File

@ -0,0 +1,124 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include <gtest/gtest.h>
#include "../../../src/data/adapter.h"
#include "../../../src/data/simple_dmatrix.h"
#include "../helpers.h"
namespace xgboost {
namespace {
/**
 * \brief SimpleDMatrix that counts how many times a cached page handle is replaced.
 *
 * Each getter snapshots the cached handle, forwards to the base implementation and bumps
 * the counter when the handle differs afterwards.
 */
class DMatrixForTest : public data::SimpleDMatrix {
  // Number of observed page replacements since construction or the last Reset().
  size_t n_regen_{0};

 public:
  using SimpleDMatrix::SimpleDMatrix;

  BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) override {
    // Keep a copy of the handle from before the call so it can be compared afterwards.
    auto const before = this->gradient_index_;
    auto batches = SimpleDMatrix::GetGradientIndex(param);
    if (before != this->gradient_index_) {
      ++n_regen_;
    }
    return batches;
  }

  BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override {
    // Same bookkeeping as GetGradientIndex, but for the ellpack page.
    auto const before = this->ellpack_page_;
    auto batches = SimpleDMatrix::GetEllpackBatches(param);
    if (before != this->ellpack_page_) {
      ++n_regen_;
    }
    return batches;
  }

  /** \brief Number of page replacements counted so far. */
  auto NumRegen() const { return n_regen_; }

  /** \brief Drop both cached pages and zero the counter. */
  void Reset() {
    n_regen_ = 0;
    this->gradient_index_.reset();
    this->ellpack_page_.reset();
  }
};
/**
* \brief Test for whether the gradient index is correctly regenerated.
*/
class RegenTest : public ::testing::Test {
 protected:
  // Matrix under test; always holds a DMatrixForTest after SetUp().
  std::shared_ptr<DMatrix> p_fmat_;

  /** \brief Create a DMatrixForTest from random features and attach random labels. */
  void SetUp() override {
    size_t constexpr kRows = 256, kCols = 10;
    HostDeviceVector<float> storage;
    auto dense = RandomDataGenerator{kRows, kCols, 0.5}.GenerateArrayInterface(&storage);
    auto adapter = data::ArrayAdapter(StringView{dense});
    p_fmat_ = std::shared_ptr<DMatrix>(new DMatrixForTest{
        &adapter, std::numeric_limits<float>::quiet_NaN(), common::OmpGetNumThreads(0)});

    // Shape the labels from kRows so they cannot drift from the feature matrix if the
    // row count above is ever changed.
    p_fmat_->Info().labels.Reshape(kRows, 1);
    auto labels = p_fmat_->Info().labels.Data();
    RandomDataGenerator{kRows, 1, 0}.GenerateDense(labels);
  }

  /** \brief Number of boosting iterations run by TestTreeMethod. */
  auto constexpr Iter() const { return 4; }

  /**
   * \brief Train for Iter() rounds and report how many page regenerations were counted.
   *
   * \tparam Page       Page type requested once more after training.
   * \param tree_method Value for the learner's "tree_method" parameter.
   * \param obj         Value for the learner's "objective" parameter.
   * \param reset       Whether to clear the cached pages and the counter before returning.
   *
   * \return The regeneration count observed right after training, i.e. before any reset.
   *         When a previous call used reset=false, that call's count is included.
   */
  template <typename Page>
  size_t TestTreeMethod(std::string tree_method, std::string obj, bool reset = true) const {
    auto learner = std::unique_ptr<Learner>{Learner::Create({p_fmat_})};
    learner->SetParam("tree_method", tree_method);
    learner->SetParam("objective", obj);
    learner->Configure();

    for (auto i = 0; i < Iter(); ++i) {
      learner->UpdateOneIter(i, p_fmat_);
    }

    auto for_test = dynamic_cast<DMatrixForTest*>(p_fmat_.get());
    CHECK(for_test);
    auto backup = for_test->NumRegen();
    // Requesting the page with a default-constructed parameter must not count as a
    // regeneration.
    for_test->GetBatches<Page>(BatchParam{});
    CHECK_EQ(for_test->NumRegen(), backup);

    if (reset) {
      for_test->Reset();
    }
    return backup;
  }
};
} // anonymous namespace
TEST_F(RegenTest, Approx) {
  // Squared error: a single regeneration is expected across all iterations.
  auto const n_squared_error =
      this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:squarederror");
  ASSERT_EQ(n_squared_error, 1);

  // Logistic: one regeneration per boosting iteration is expected.
  auto const n_logistic = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:logistic");
  ASSERT_EQ(n_logistic, this->Iter());
}
TEST_F(RegenTest, Hist) {
  // For hist, a single regeneration is expected with either objective.
  auto const n_squared_error =
      this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:squarederror");
  ASSERT_EQ(n_squared_error, 1);

  auto const n_logistic = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:logistic");
  ASSERT_EQ(n_logistic, 1);
}
TEST_F(RegenTest, Mixed) {
  // The calls below are order dependent: whenever reset is false, the counter and the
  // cached pages are carried over into the next call.

  // hist on a fresh matrix, state kept for the next call.
  auto const n_hist_first =
      this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:squarederror", false);
  ASSERT_EQ(n_hist_first, 1);

  // approx on top of the carried-over count of 1; state cleared afterwards.
  auto const n_approx_after_hist =
      this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:logistic", true);
  ASSERT_EQ(n_approx_after_hist, this->Iter() + 1);

  // approx from a clean state, state kept for the next call.
  auto const n_approx_first =
      this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:logistic", false);
  ASSERT_EQ(n_approx_first, this->Iter());

  // hist on top of the carried-over count of Iter(); state cleared afterwards.
  auto const n_hist_after_approx =
      this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:squarederror", true);
  ASSERT_EQ(n_hist_after_approx, this->Iter() + 1);
}
#if defined(XGBOOST_USE_CUDA)
TEST_F(RegenTest, GpuHist) {
  // gpu_hist with squared error: one regeneration, then the state is cleared.
  auto const n_squared_error =
      this->TestTreeMethod<EllpackPage>("gpu_hist", "reg:squarederror");
  ASSERT_EQ(n_squared_error, 1);

  // gpu_hist with logistic: one regeneration, state kept for the next call.
  auto const n_logistic = this->TestTreeMethod<EllpackPage>("gpu_hist", "reg:logistic", false);
  ASSERT_EQ(n_logistic, 1);

  // Switching to hist: the counter still holds the 1 from the previous call, so a total
  // of 2 means exactly one additional regeneration.
  auto const n_after_switch = this->TestTreeMethod<EllpackPage>("hist", "reg:logistic");
  ASSERT_EQ(n_after_switch, 2);
}
#endif // defined(XGBOOST_USE_CUDA)
} // namespace xgboost