Reduce 'InitSampling' complexity and set gradients to zero (#6922)

Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
This commit is contained in:
ShvetsKS 2021-05-28 23:52:23 +03:00 committed by GitHub
parent 89a49cf30e
commit 55b823b27d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 189 additions and 197 deletions

View File

@ -48,9 +48,11 @@ void QuantileHistMaker::Configure(const Args& args) {
}
template<typename GradientSumT>
void QuantileHistMaker::SetBuilder(std::unique_ptr<Builder<GradientSumT>>* builder,
void QuantileHistMaker::SetBuilder(const size_t n_trees,
std::unique_ptr<Builder<GradientSumT>>* builder,
DMatrix *dmat) {
builder->reset(new Builder<GradientSumT>(
n_trees,
param_,
std::move(pruner_),
int_constraint_, dmat));
@ -92,14 +94,15 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
param_.learning_rate = lr / trees.size();
int_constraint_.Configure(param_, dmat->Info().num_col_);
// build tree
const size_t n_trees = trees.size();
if (hist_maker_param_.single_precision_histogram) {
if (!float_builder_) {
SetBuilder(&float_builder_, dmat);
SetBuilder(n_trees, &float_builder_, dmat);
}
CallBuilderUpdate(float_builder_, gpair, dmat, trees);
} else {
if (!double_builder_) {
SetBuilder(&double_builder_, dmat);
SetBuilder(n_trees, &double_builder_, dmat);
}
CallBuilderUpdate(double_builder_, gpair, dmat, trees);
}
@ -545,7 +548,6 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandWithLossGuide(
e.best.DefaultLeft(), e.weight, left_leaf_weight,
right_leaf_weight, e.best.loss_chg, e.stats.GetHess(),
e.best.left_sum.GetHess(), e.best.right_sum.GetHess());
this->ApplySplit({candidate}, gmat, column_matrix, hist_, p_tree);
const int cleft = (*p_tree)[nid].LeftChild();
@ -589,18 +591,23 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
DMatrix *p_fmat, RegTree *p_tree) {
builder_monitor_.Start("Update");
const std::vector<GradientPair>& gpair_h = gpair->ConstHostVector();
std::vector<GradientPair>* gpair_ptr = &(gpair->HostVector());
// in case 'num_parallel_trees != 1' there is no possibility to change the initial gpair
if (GetNumberOfTrees() != 1) {
gpair_local_.resize(gpair_ptr->size());
gpair_local_ = *gpair_ptr;
gpair_ptr = &gpair_local_;
}
tree_evaluator_ =
TreeEvaluator(param_, p_fmat->Info().num_col_, GenericParameter::kCpuId);
interaction_constraints_.Reset();
p_last_fmat_mutable_ = p_fmat;
this->InitData(gmat, gpair_h, *p_fmat, *p_tree);
this->InitData(gmat, *p_fmat, *p_tree, gpair_ptr);
if (param_.grow_policy == TrainParam::kLossGuide) {
ExpandWithLossGuide(gmat, gmatb, column_matrix, p_fmat, p_tree, gpair_h);
ExpandWithLossGuide(gmat, gmatb, column_matrix, p_fmat, p_tree, *gpair_ptr);
} else {
ExpandWithDepthWise(gmat, gmatb, column_matrix, p_fmat, p_tree, gpair_h);
ExpandWithDepthWise(gmat, gmatb, column_matrix, p_fmat, p_tree, *gpair_ptr);
}
for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) {
@ -654,69 +661,33 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
}
});
if (param_.subsample < 1.0f) {
// Making a real prediction for the remaining rows
size_t fvecs_size = feat_vecs_.size();
feat_vecs_.resize(omp_get_max_threads(), RegTree::FVec());
while (fvecs_size < feat_vecs_.size()) {
feat_vecs_[fvecs_size++].Init(data->Info().num_col_);
}
for (auto&& batch : p_last_fmat_mutable_->GetBatches<SparsePage>()) {
HostSparsePageView page_view = batch.GetView();
const auto num_parallel_ops = static_cast<bst_omp_uint>(unused_rows_.size());
common::ParallelFor(num_parallel_ops, [&](bst_omp_uint block_id) {
RegTree::FVec &feats = feat_vecs_[omp_get_thread_num()];
const SparsePage::Inst inst = page_view[unused_rows_[block_id]];
feats.Fill(inst);
const size_t row_num = unused_rows_[block_id] + batch.base_rowid;
const int lid = feats.HasMissing() ? p_last_tree_->GetLeafIndex<true>(feats) :
p_last_tree_->GetLeafIndex<false>(feats);
out_preds[row_num] += (*p_last_tree_)[lid].LeafValue();
feats.Drop(inst);
});
}
}
builder_monitor_.Stop("UpdatePredictionCache");
return true;
}
template<typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const std::vector<GradientPair>& gpair,
const DMatrix& fmat,
void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const DMatrix& fmat,
std::vector<GradientPair>* gpair,
std::vector<size_t>* row_indices) {
const auto& info = fmat.Info();
auto& rnd = common::GlobalRandom();
unused_rows_.resize(info.num_row_);
size_t* p_row_indices_used = row_indices->data();
size_t* p_row_indices_unused = unused_rows_.data();
std::vector<GradientPair>& gpair_ref = *gpair;
#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG
std::bernoulli_distribution coin_flip(param_.subsample);
size_t used = 0, unused = 0;
for (size_t i = 0; i < info.num_row_; ++i) {
if (gpair[i].GetHess() >= 0.0f && coin_flip(rnd)) {
p_row_indices_used[used++] = i;
} else {
p_row_indices_unused[unused++] = i;
if (!(gpair_ref[i].GetHess() >= 0.0f && coin_flip(rnd)) || gpair_ref[i].GetGrad() == 0.0f) {
gpair_ref[i] = GradientPair(0);
}
}
/* resize row_indices to reduce memory */
row_indices->resize(used);
unused_rows_.resize(unused);
#else
const size_t nthread = this->nthread_;
std::vector<size_t> row_offsets_used(nthread, 0);
std::vector<size_t> row_offsets_unused(nthread, 0);
/* usage of mt19937_64 give 2x speed up for subsampling */
std::vector<std::mt19937> rnds(nthread);
/* create engine for each thread */
for (std::mt19937& r : rnds) {
r = rnd;
}
uint64_t initial_seed = rnd();
const size_t discard_size = info.num_row_ / nthread;
auto upper_border = static_cast<float>(std::numeric_limits<uint32_t>::max());
uint32_t coin_flip_border = static_cast<uint32_t>(upper_border * param_.subsample);
std::bernoulli_distribution coin_flip(param_.subsample);
dmlc::OMPException exc;
#pragma omp parallel num_threads(nthread)
{
@ -725,60 +696,24 @@ void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const std::vector<Gr
const size_t ibegin = tid * discard_size;
const size_t iend = (tid == (nthread - 1)) ?
info.num_row_ : ibegin + discard_size;
rnds[tid].discard(discard_size * tid);
for (size_t i = ibegin; i < iend; ++i) {
if (gpair[i].GetHess() >= 0.0f && rnds[tid]() < coin_flip_border) {
p_row_indices_used[ibegin + row_offsets_used[tid]++] = i;
} else {
p_row_indices_unused[ibegin + row_offsets_unused[tid]++] = i;
}
}
#pragma omp barrier
if (tid == 0ul) {
size_t prefix_sum_used = row_offsets_used[0];
for (size_t i = 1; i < nthread; ++i) {
const size_t ibegin = i * discard_size;
for (size_t k = 0; k < row_offsets_used[i]; ++k) {
p_row_indices_used[prefix_sum_used + k] = p_row_indices_used[ibegin + k];
}
prefix_sum_used += row_offsets_used[i];
}
/* resize row_indices to reduce memory */
row_indices->resize(prefix_sum_used);
}
if (nthread == 1ul || tid == 1ul) {
size_t prefix_sum_unused = row_offsets_unused[0];
for (size_t i = 1; i < nthread; ++i) {
const size_t ibegin = i * discard_size;
for (size_t k = 0; k < row_offsets_unused[i]; ++k) {
p_row_indices_unused[prefix_sum_unused + k] = p_row_indices_unused[ibegin + k];
}
prefix_sum_unused += row_offsets_unused[i];
}
/* resize row_indices to reduce memory */
unused_rows_.resize(prefix_sum_unused);
}
RandomReplace::MakeIf([&](size_t i, RandomReplace::EngineT& eng) {
return !(gpair_ref[i].GetHess() >= 0.0f && coin_flip(eng));
}, GradientPair(0), initial_seed, ibegin, iend, &gpair_ref);
});
}
exc.Rethrow();
/* discard global engine */
rnd = rnds[nthread - 1];
#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG
}
/*!
 * \brief Report the tree count stored in this builder.
 * \return the current value of the n_trees_ member.
 */
template <typename GradientSumT>
size_t QuantileHistMaker::Builder<GradientSumT>::GetNumberOfTrees() {
  const size_t tree_count = this->n_trees_;
  return tree_count;
}
template<typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix& gmat,
const std::vector<GradientPair>& gpair,
const DMatrix& fmat,
const RegTree& tree) {
const RegTree& tree,
std::vector<GradientPair>* gpair) {
CHECK((param_.max_depth > 0 || param_.max_leaves > 0))
<< "max_depth or max_leaves cannot be both 0 (unlimited); "
<< "at least one should be a positive quantity.";
@ -818,17 +753,53 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
<< "Only uniform sampling is supported, "
<< "gradient-based sampling is only support by GPU Hist.";
InitSampling(gpair, fmat, &row_indices);
builder_monitor_.Start("InitSampling");
InitSampling(fmat, gpair, &row_indices);
builder_monitor_.Stop("InitSampling");
CHECK_EQ(row_indices.size(), info.num_row_);
// We should check that the partitioning was done correctly
// and each row of the dataset fell into exactly one of the categories
CHECK_EQ(row_indices.size() + unused_rows_.size(), info.num_row_);
}
MemStackAllocator<bool, 128> buff(this->nthread_);
bool* p_buff = buff.Get();
std::fill(p_buff, p_buff + this->nthread_, false);
const size_t block_size = info.num_row_ / this->nthread_ + !!(info.num_row_ % this->nthread_);
#pragma omp parallel num_threads(this->nthread_)
{
exc.Run([&]() {
const size_t tid = omp_get_thread_num();
const size_t ibegin = tid * block_size;
const size_t iend = std::min(static_cast<size_t>(ibegin + block_size),
static_cast<size_t>(info.num_row_));
for (size_t i = ibegin; i < iend; ++i) {
if ((*gpair)[i].GetHess() < 0.0f) {
p_buff[tid] = true;
break;
}
}
});
}
exc.Rethrow();
bool has_neg_hess = false;
for (int32_t tid = 0; tid < this->nthread_; ++tid) {
if (p_buff[tid]) {
has_neg_hess = true;
}
}
if (has_neg_hess) {
size_t j = 0;
for (size_t i = 0; i < info.num_row_; ++i) {
if ((*gpair)[i].GetHess() >= 0.0f) {
p_row_indices[j++] = i;
}
}
row_indices.resize(j);
} else {
MemStackAllocator<bool, 128> buff(this->nthread_);
bool* p_buff = buff.Get();
std::fill(p_buff, p_buff + this->nthread_, false);
const size_t block_size = info.num_row_ / this->nthread_ + !!(info.num_row_ % this->nthread_);
#pragma omp parallel num_threads(this->nthread_)
{
exc.Run([&]() {
@ -836,47 +807,12 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
const size_t ibegin = tid * block_size;
const size_t iend = std::min(static_cast<size_t>(ibegin + block_size),
static_cast<size_t>(info.num_row_));
for (size_t i = ibegin; i < iend; ++i) {
if (gpair[i].GetHess() < 0.0f) {
p_buff[tid] = true;
break;
}
p_row_indices[i] = i;
}
});
}
exc.Rethrow();
bool has_neg_hess = false;
for (int32_t tid = 0; tid < this->nthread_; ++tid) {
if (p_buff[tid]) {
has_neg_hess = true;
}
}
if (has_neg_hess) {
size_t j = 0;
for (size_t i = 0; i < info.num_row_; ++i) {
if (gpair[i].GetHess() >= 0.0f) {
p_row_indices[j++] = i;
}
}
row_indices.resize(j);
} else {
#pragma omp parallel num_threads(this->nthread_)
{
exc.Run([&]() {
const size_t tid = omp_get_thread_num();
const size_t ibegin = tid * block_size;
const size_t iend = std::min(static_cast<size_t>(ibegin + block_size),
static_cast<size_t>(info.num_row_));
for (size_t i = ibegin; i < iend; ++i) {
p_row_indices[i] = i;
}
});
}
exc.Rethrow();
}
}
}
@ -1074,9 +1010,9 @@ inline std::pair<size_t, size_t> PartitionDenseKernel(const common::DenseColumn<
// Handle sparse columns
template<bool default_left, typename BinIdxType>
inline std::pair<size_t, size_t> PartitionSparseKernel(
const common::SparseColumn<BinIdxType>& column,
common::Span<const size_t> rid_span, const int32_t split_cond,
const common::SparseColumn<BinIdxType>& column, common::Span<size_t> left_part,
common::Span<size_t> right_part) {
common::Span<size_t> left_part, common::Span<size_t> right_part) {
size_t* p_left_part = left_part.data();
size_t* p_right_part = right_part.data();
@ -1131,7 +1067,7 @@ inline std::pair<size_t, size_t> PartitionSparseKernel(
template <typename GradientSumT>
template <typename BinIdxType>
void QuantileHistMaker::Builder<GradientSumT>::PartitionKernel(
const size_t node_in_set, const size_t nid, common::Range1d range,
const size_t node_in_set, const size_t nid, const common::Range1d range,
const int32_t split_cond, const ColumnMatrix& column_matrix, const RegTree& tree) {
const size_t* rid = row_set_collection_[nid].begin;
@ -1151,28 +1087,30 @@ void QuantileHistMaker::Builder<GradientSumT>::PartitionKernel(
static_cast<const common::DenseColumn<BinIdxType>& >(*(column_ptr.get()));
if (default_left) {
if (column_matrix.AnyMissing()) {
child_nodes_sizes = PartitionDenseKernel<true, true>(column, rid_span, split_cond,
left, right);
child_nodes_sizes = PartitionDenseKernel<true, true>(column, rid_span,
split_cond, left, right);
} else {
child_nodes_sizes = PartitionDenseKernel<true, false>(column, rid_span, split_cond,
left, right);
child_nodes_sizes = PartitionDenseKernel<true, false>(column, rid_span,
split_cond, left, right);
}
} else {
if (column_matrix.AnyMissing()) {
child_nodes_sizes = PartitionDenseKernel<false, true>(column, rid_span, split_cond,
left, right);
child_nodes_sizes = PartitionDenseKernel<false, true>(column, rid_span,
split_cond, left, right);
} else {
child_nodes_sizes = PartitionDenseKernel<false, false>(column, rid_span, split_cond,
left, right);
child_nodes_sizes = PartitionDenseKernel<false, false>(column, rid_span,
split_cond, left, right);
}
}
} else {
const common::SparseColumn<BinIdxType>& column
= static_cast<const common::SparseColumn<BinIdxType>& >(*(column_ptr.get()));
if (default_left) {
child_nodes_sizes = PartitionSparseKernel<true>(rid_span, split_cond, column, left, right);
child_nodes_sizes = PartitionSparseKernel<true>(column, rid_span,
split_cond, left, right);
} else {
child_nodes_sizes = PartitionSparseKernel<false>(rid_span, split_cond, column, left, right);
child_nodes_sizes = PartitionSparseKernel<false>(column, rid_span,
split_cond, left, right);
}
}

View File

@ -31,6 +31,48 @@
namespace xgboost {
/*!
 * \brief Helpers for conditionally overwriting a sub-range of a container
 *        using a congruential engine positioned by jump-ahead.
 *
 * MakeIf seeds its own EngineT from (initial_seed, ibegin) and does not touch
 * any state outside the container it is given, so disjoint [ibegin, iend)
 * ranges can be processed by separate calls.
 */
struct RandomReplace {
 public:
  // similar value as for minstd_rand
  static constexpr uint64_t kBase = 16807;
  static constexpr uint64_t kMod = static_cast<uint64_t>(1) << 63;

  /*! \brief Multiplicative (zero increment) congruential engine built from kBase and kMod. */
  using EngineT = std::linear_congruential_engine<uint64_t, kBase, 0, kMod>;

  /*!
   * \brief Compute (base^exponent * initial_seed) % mod.
   *
   * Right-to-left binary method: https://en.wikipedia.org/wiki/Modular_exponentiation
   *
   * For a congruential generator with zero increment this is the state reached
   * from initial_seed after `exponent` steps, so the result can seed an engine
   * that continues from that point.
   * NOTE(review): std::linear_congruential_engine replaces a seed that is
   * congruent to zero by 1 when the increment is zero; such seeds do not follow
   * the formula above - confirm callers never rely on that case.
   *
   * Every product goes through MulMod, so the result is exact for any non-zero
   * modulus, not only for the power-of-two kMod.
   *
   * \param exponent     number of generator steps to skip; checked to be <= mod
   * \param initial_seed state the skip starts from
   * \param base         multiplier of the generator
   * \param mod          modulus of the generator; must be non-zero
   * \return the displaced seed, in [0, mod)
   */
  static uint64_t SimpleSkip(uint64_t exponent, uint64_t initial_seed,
                             uint64_t base, uint64_t mod) {
    CHECK_LE(exponent, mod);
    uint64_t result = 1;
    while (exponent > 0) {
      if (exponent % 2 == 1) {
        result = MulMod(result, base, mod);
      }
      base = MulMod(base, base, mod);
      exponent = exponent >> 1;
    }
    // with result we can now find the new seed
    return MulMod(result, initial_seed, mod);
  }

  /*!
   * \brief Overwrite gpair[i] with replace_value for every i in [ibegin, iend)
   *        for which condition(i, engine) returns true.
   *
   * The engine handed to `condition` is an EngineT seeded with
   * SimpleSkip(ibegin, initial_seed, kBase, kMod).
   *
   * \param condition     callable invoked as condition(i, eng) with eng of type EngineT&
   * \param replace_value value written into selected elements
   * \param initial_seed  seed shared by all ranges
   * \param ibegin        first index processed (inclusive)
   * \param iend          last index processed (exclusive)
   * \param gpair         container modified in place; must be indexable on [ibegin, iend)
   */
  template<typename Condition, typename ContainerData>
  static void MakeIf(Condition condition, const typename ContainerData::value_type replace_value,
                     const uint64_t initial_seed, const size_t ibegin,
                     const size_t iend, ContainerData* gpair) {
    ContainerData& gpair_ref = *gpair;
    const uint64_t displaced_seed = SimpleSkip(ibegin, initial_seed, kBase, kMod);
    EngineT eng(displaced_seed);
    for (size_t i = ibegin; i < iend; ++i) {
      if (condition(i, eng)) {
        gpair_ref[i] = replace_value;
      }
    }
  }

 private:
  /*!
   * \brief Return (lhs * rhs) % mod without losing bits to 64-bit overflow.
   *
   * When mod is a power of two the wrapping 64-bit product is already exact
   * modulo mod (mod divides 2^64), so the plain product is used. Otherwise the
   * product is accumulated by repeated doubling, keeping every intermediate
   * value below mod.
   */
  static uint64_t MulMod(uint64_t lhs, uint64_t rhs, uint64_t mod) {
    if ((mod & (mod - 1)) == 0) {
      return (lhs * rhs) % mod;
    }
    lhs %= mod;
    rhs %= mod;
    uint64_t acc = 0;
    while (rhs > 0) {
      if ((rhs & 1) != 0) {
        // acc = (acc + lhs) % mod, written so that the sum never exceeds 64 bits
        acc = (acc >= mod - lhs) ? acc - (mod - lhs) : acc + lhs;
      }
      // lhs = (2 * lhs) % mod, same overflow-free form
      lhs = (lhs >= mod - lhs) ? lhs - (mod - lhs) : lhs + lhs;
      rhs >>= 1;
    }
    return acc;
  }
};
/*!
* \brief A C-style array with in-stack allocation. As long as the array is smaller than MaxStackSize, it will be allocated inside the stack. Otherwise, it will be heap-allocated.
*/
@ -201,11 +243,13 @@ class QuantileHistMaker: public TreeUpdater {
using GHistRowT = GHistRow<GradientSumT>;
using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;
// constructor
explicit Builder(const TrainParam& param,
explicit Builder(const size_t n_trees,
const TrainParam& param,
std::unique_ptr<TreeUpdater> pruner,
FeatureInteractionConstraintHost int_constraints_,
DMatrix const* fmat)
: param_(param),
: n_trees_(n_trees),
param_(param),
tree_evaluator_(param, fmat->Info().num_col_, GenericParameter::kCpuId),
pruner_(std::move(pruner)),
interaction_constraints_{std::move(int_constraints_)},
@ -279,12 +323,15 @@ class QuantileHistMaker: public TreeUpdater {
// initialize temp data structure
void InitData(const GHistIndexMatrix& gmat,
const std::vector<GradientPair>& gpair,
const DMatrix& fmat,
const RegTree& tree);
const RegTree& tree,
std::vector<GradientPair>* gpair);
void InitSampling(const std::vector<GradientPair>& gpair,
const DMatrix& fmat, std::vector<size_t>* row_indices);
size_t GetNumberOfTrees();
void InitSampling(const DMatrix& fmat,
std::vector<GradientPair>* gpair,
std::vector<size_t>* row_indices);
void EvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
const GHistIndexMatrix& gmat,
@ -298,7 +345,7 @@ class QuantileHistMaker: public TreeUpdater {
RegTree* p_tree);
template <typename BinIdxType>
void PartitionKernel(const size_t node_in_set, const size_t nid, common::Range1d range,
void PartitionKernel(const size_t node_in_set, const size_t nid, const common::Range1d range,
const int32_t split_cond,
const ColumnMatrix& column_matrix, const RegTree& tree);
@ -398,6 +445,7 @@ class QuantileHistMaker: public TreeUpdater {
}
}
// --data fields--
const size_t n_trees_;
const TrainParam& param_;
// number of omp thread used during training
int nthread_;
@ -413,6 +461,7 @@ class QuantileHistMaker: public TreeUpdater {
std::vector<SplitEntry> best_split_tloc_;
/*! \brief TreeNode Data: statistics for each constructed node */
std::vector<NodeEntry> snode_;
std::vector<GradientPair> gpair_local_;
/*! \brief cumulative histogram of gradients. */
HistCollection<GradientSumT> hist_;
/*! \brief cumulative local parent histogram of gradients. */
@ -458,7 +507,7 @@ class QuantileHistMaker: public TreeUpdater {
common::Monitor updater_monitor_;
template<typename GradientSumT>
void SetBuilder(std::unique_ptr<Builder<GradientSumT>>*, DMatrix *dmat);
void SetBuilder(const size_t n_trees, std::unique_ptr<Builder<GradientSumT>>*, DMatrix *dmat);
template<typename GradientSumT>
void CallBuilderUpdate(const std::unique_ptr<Builder<GradientSumT>>& builder,

View File

@ -31,15 +31,15 @@ class QuantileHistMock : public QuantileHistMaker {
std::unique_ptr<TreeUpdater> pruner,
FeatureInteractionConstraintHost int_constraint,
DMatrix const* fmat)
: RealImpl(param, std::move(pruner),
: RealImpl(1, param, std::move(pruner),
std::move(int_constraint), fmat) {}
public:
void TestInitData(const GHistIndexMatrix& gmat,
const std::vector<GradientPair>& gpair,
std::vector<GradientPair>* gpair,
DMatrix* p_fmat,
const RegTree& tree) {
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
RealImpl::InitData(gmat, *p_fmat, tree, gpair);
ASSERT_EQ(this->data_layout_, RealImpl::DataLayout::kSparseData);
/* The creation of HistCutMatrix and GHistIndexMatrix are not technically
@ -101,29 +101,34 @@ class QuantileHistMock : public QuantileHistMaker {
}
void TestInitDataSampling(const GHistIndexMatrix& gmat,
const std::vector<GradientPair>& gpair,
std::vector<GradientPair>* gpair,
DMatrix* p_fmat,
const RegTree& tree) {
// check SimpleSkip
size_t initial_seed = 777;
std::linear_congruential_engine<std::uint_fast64_t, 16807, 0,
static_cast<uint64_t>(1) << 63 > eng_first(initial_seed);
for (size_t i = 0; i < 100; ++i) {
eng_first();
}
uint64_t initial_seed_th = RandomReplace::SimpleSkip(100, initial_seed, 16807, RandomReplace::kMod);
std::linear_congruential_engine<std::uint_fast64_t, RandomReplace::kBase, 0,
RandomReplace::kMod > eng_second(initial_seed_th);
ASSERT_EQ(eng_first(), eng_second());
const size_t nthreads = omp_get_num_threads();
// save state of global rng engine
auto initial_rnd = common::GlobalRandom();
std::vector<size_t> unused_rows_cpy = this->unused_rows_;
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
RealImpl::InitData(gmat, *p_fmat, tree, gpair);
std::vector<size_t> row_indices_initial = *(this->row_set_collection_.Data());
std::vector<size_t> unused_row_indices_initial = this->unused_rows_;
ASSERT_EQ(row_indices_initial.size(), p_fmat->Info().num_row_);
auto check_each_row_occurs_in_one_of_arrays = [](const std::vector<size_t>& first,
const std::vector<size_t>& second,
size_t nrows) {
std::vector<size_t> arr_union(nrows);
for (auto&& row_indice : first) {
++arr_union[row_indice];
}
for (auto&& row_indice : second) {
++arr_union[row_indice];
}
for (auto&& row_cnt : arr_union) {
ASSERT_EQ(row_cnt, 1ul);
}
ASSERT_EQ(first.size(), nrows);
ASSERT_EQ(second.size(), 0);
};
check_each_row_occurs_in_one_of_arrays(row_indices_initial, unused_row_indices_initial,
p_fmat->Info().num_row_);
@ -133,7 +138,7 @@ class QuantileHistMock : public QuantileHistMaker {
// return initial state of global rng engine
common::GlobalRandom() = initial_rnd;
this->unused_rows_ = unused_rows_cpy;
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
RealImpl::InitData(gmat, *p_fmat, tree, gpair);
std::vector<size_t>& row_indices = *(this->row_set_collection_.Data());
ASSERT_EQ(row_indices_initial.size(), row_indices.size());
for (size_t i = 0; i < row_indices_initial.size(); ++i) {
@ -151,10 +156,10 @@ class QuantileHistMock : public QuantileHistMaker {
}
void TestAddHistRows(const GHistIndexMatrix& gmat,
const std::vector<GradientPair>& gpair,
std::vector<GradientPair>* gpair,
DMatrix* p_fmat,
RegTree* tree) {
RealImpl::InitData(gmat, gpair, *p_fmat, *tree);
RealImpl::InitData(gmat, *p_fmat, *tree, gpair);
int starting_index = std::numeric_limits<int>::max();
int sync_count = 0;
@ -183,11 +188,11 @@ class QuantileHistMock : public QuantileHistMaker {
void TestSyncHistograms(const GHistIndexMatrix& gmat,
const std::vector<GradientPair>& gpair,
std::vector<GradientPair>* gpair,
DMatrix* p_fmat,
RegTree* tree) {
// init
RealImpl::InitData(gmat, gpair, *p_fmat, *tree);
RealImpl::InitData(gmat, *p_fmat, *tree, gpair);
int starting_index = std::numeric_limits<int>::max();
int sync_count = 0;
@ -295,10 +300,10 @@ class QuantileHistMock : public QuantileHistMaker {
const GHistIndexMatrix& gmat,
const DMatrix& fmat,
const RegTree& tree) {
const std::vector<GradientPair> gpair =
std::vector<GradientPair> gpair =
{ {0.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {0.27f, 0.28f},
{0.27f, 0.29f}, {0.37f, 0.39f}, {0.47f, 0.49f}, {0.57f, 0.59f} };
RealImpl::InitData(gmat, gpair, fmat, tree);
RealImpl::InitData(gmat, fmat, tree, &gpair);
GHistIndexBlockMatrix dummy;
this->hist_.AddHistRow(nid);
this->hist_.AllocateAllData();
@ -341,7 +346,7 @@ class QuantileHistMock : public QuantileHistMaker {
common::GHistIndexMatrix gmat;
gmat.Init(dmat.get(), kMaxBins);
RealImpl::InitData(gmat, row_gpairs, *dmat, tree);
RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
this->hist_.AddHistRow(0);
this->hist_.AllocateAllData();
this->BuildHist(row_gpairs, this->row_set_collection_[0],
@ -437,7 +442,7 @@ class QuantileHistMock : public QuantileHistMaker {
// treat everything as dense, as this is what we intend to test here
cm.Init(gmat, 0.0);
RealImpl::InitData(gmat, row_gpairs, *dmat, tree);
RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
this->hist_.AddHistRow(0);
this->hist_.AllocateAllData();
RealImpl::InitNewNode(0, gmat, row_gpairs, *dmat, tree);
@ -548,9 +553,9 @@ class QuantileHistMock : public QuantileHistMaker {
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
if (double_builder_) {
double_builder_->TestInitData(gmat, gpair, dmat_.get(), tree);
double_builder_->TestInitData(gmat, &gpair, dmat_.get(), tree);
} else {
float_builder_->TestInitData(gmat, gpair, dmat_.get(), tree);
float_builder_->TestInitData(gmat, &gpair, dmat_.get(), tree);
}
}
@ -566,9 +571,9 @@ class QuantileHistMock : public QuantileHistMaker {
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
if (double_builder_) {
double_builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree);
double_builder_->TestInitDataSampling(gmat, &gpair, dmat_.get(), tree);
} else {
float_builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree);
float_builder_->TestInitDataSampling(gmat, &gpair, dmat_.get(), tree);
}
}
@ -583,9 +588,9 @@ class QuantileHistMock : public QuantileHistMaker {
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
if (double_builder_) {
double_builder_->TestAddHistRows(gmat, gpair, dmat_.get(), &tree);
double_builder_->TestAddHistRows(gmat, &gpair, dmat_.get(), &tree);
} else {
float_builder_->TestAddHistRows(gmat, gpair, dmat_.get(), &tree);
float_builder_->TestAddHistRows(gmat, &gpair, dmat_.get(), &tree);
}
}
@ -600,9 +605,9 @@ class QuantileHistMock : public QuantileHistMaker {
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
if (double_builder_) {
double_builder_->TestSyncHistograms(gmat, gpair, dmat_.get(), &tree);
double_builder_->TestSyncHistograms(gmat, &gpair, dmat_.get(), &tree);
} else {
float_builder_->TestSyncHistograms(gmat, gpair, dmat_.get(), &tree);
float_builder_->TestSyncHistograms(gmat, &gpair, dmat_.get(), &tree);
}
}