Reduce 'InitSampling' complexity and set gradients to zero (#6922)
Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
This commit is contained in:
parent
89a49cf30e
commit
55b823b27d
@ -48,9 +48,11 @@ void QuantileHistMaker::Configure(const Args& args) {
|
||||
}
|
||||
|
||||
template<typename GradientSumT>
|
||||
void QuantileHistMaker::SetBuilder(std::unique_ptr<Builder<GradientSumT>>* builder,
|
||||
void QuantileHistMaker::SetBuilder(const size_t n_trees,
|
||||
std::unique_ptr<Builder<GradientSumT>>* builder,
|
||||
DMatrix *dmat) {
|
||||
builder->reset(new Builder<GradientSumT>(
|
||||
n_trees,
|
||||
param_,
|
||||
std::move(pruner_),
|
||||
int_constraint_, dmat));
|
||||
@ -92,14 +94,15 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
|
||||
param_.learning_rate = lr / trees.size();
|
||||
int_constraint_.Configure(param_, dmat->Info().num_col_);
|
||||
// build tree
|
||||
const size_t n_trees = trees.size();
|
||||
if (hist_maker_param_.single_precision_histogram) {
|
||||
if (!float_builder_) {
|
||||
SetBuilder(&float_builder_, dmat);
|
||||
SetBuilder(n_trees, &float_builder_, dmat);
|
||||
}
|
||||
CallBuilderUpdate(float_builder_, gpair, dmat, trees);
|
||||
} else {
|
||||
if (!double_builder_) {
|
||||
SetBuilder(&double_builder_, dmat);
|
||||
SetBuilder(n_trees, &double_builder_, dmat);
|
||||
}
|
||||
CallBuilderUpdate(double_builder_, gpair, dmat, trees);
|
||||
}
|
||||
@ -545,7 +548,6 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandWithLossGuide(
|
||||
e.best.DefaultLeft(), e.weight, left_leaf_weight,
|
||||
right_leaf_weight, e.best.loss_chg, e.stats.GetHess(),
|
||||
e.best.left_sum.GetHess(), e.best.right_sum.GetHess());
|
||||
|
||||
this->ApplySplit({candidate}, gmat, column_matrix, hist_, p_tree);
|
||||
|
||||
const int cleft = (*p_tree)[nid].LeftChild();
|
||||
@ -589,18 +591,23 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
|
||||
DMatrix *p_fmat, RegTree *p_tree) {
|
||||
builder_monitor_.Start("Update");
|
||||
|
||||
const std::vector<GradientPair>& gpair_h = gpair->ConstHostVector();
|
||||
|
||||
std::vector<GradientPair>* gpair_ptr = &(gpair->HostVector());
|
||||
// in case 'num_parallel_trees != 1' there is no possibility to change the initial gpair
|
||||
if (GetNumberOfTrees() != 1) {
|
||||
gpair_local_.resize(gpair_ptr->size());
|
||||
gpair_local_ = *gpair_ptr;
|
||||
gpair_ptr = &gpair_local_;
|
||||
}
|
||||
tree_evaluator_ =
|
||||
TreeEvaluator(param_, p_fmat->Info().num_col_, GenericParameter::kCpuId);
|
||||
interaction_constraints_.Reset();
|
||||
p_last_fmat_mutable_ = p_fmat;
|
||||
|
||||
this->InitData(gmat, gpair_h, *p_fmat, *p_tree);
|
||||
this->InitData(gmat, *p_fmat, *p_tree, gpair_ptr);
|
||||
if (param_.grow_policy == TrainParam::kLossGuide) {
|
||||
ExpandWithLossGuide(gmat, gmatb, column_matrix, p_fmat, p_tree, gpair_h);
|
||||
ExpandWithLossGuide(gmat, gmatb, column_matrix, p_fmat, p_tree, *gpair_ptr);
|
||||
} else {
|
||||
ExpandWithDepthWise(gmat, gmatb, column_matrix, p_fmat, p_tree, gpair_h);
|
||||
ExpandWithDepthWise(gmat, gmatb, column_matrix, p_fmat, p_tree, *gpair_ptr);
|
||||
}
|
||||
|
||||
for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) {
|
||||
@ -654,69 +661,33 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
|
||||
}
|
||||
});
|
||||
|
||||
if (param_.subsample < 1.0f) {
|
||||
// Making a real prediction for the remaining rows
|
||||
size_t fvecs_size = feat_vecs_.size();
|
||||
feat_vecs_.resize(omp_get_max_threads(), RegTree::FVec());
|
||||
while (fvecs_size < feat_vecs_.size()) {
|
||||
feat_vecs_[fvecs_size++].Init(data->Info().num_col_);
|
||||
}
|
||||
for (auto&& batch : p_last_fmat_mutable_->GetBatches<SparsePage>()) {
|
||||
HostSparsePageView page_view = batch.GetView();
|
||||
const auto num_parallel_ops = static_cast<bst_omp_uint>(unused_rows_.size());
|
||||
common::ParallelFor(num_parallel_ops, [&](bst_omp_uint block_id) {
|
||||
RegTree::FVec &feats = feat_vecs_[omp_get_thread_num()];
|
||||
const SparsePage::Inst inst = page_view[unused_rows_[block_id]];
|
||||
feats.Fill(inst);
|
||||
|
||||
const size_t row_num = unused_rows_[block_id] + batch.base_rowid;
|
||||
const int lid = feats.HasMissing() ? p_last_tree_->GetLeafIndex<true>(feats) :
|
||||
p_last_tree_->GetLeafIndex<false>(feats);
|
||||
out_preds[row_num] += (*p_last_tree_)[lid].LeafValue();
|
||||
|
||||
feats.Drop(inst);
|
||||
});
|
||||
}
|
||||
}
|
||||
builder_monitor_.Stop("UpdatePredictionCache");
|
||||
return true;
|
||||
}
|
||||
|
||||
template<typename GradientSumT>
|
||||
void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const std::vector<GradientPair>& gpair,
|
||||
const DMatrix& fmat,
|
||||
void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const DMatrix& fmat,
|
||||
std::vector<GradientPair>* gpair,
|
||||
std::vector<size_t>* row_indices) {
|
||||
const auto& info = fmat.Info();
|
||||
auto& rnd = common::GlobalRandom();
|
||||
unused_rows_.resize(info.num_row_);
|
||||
size_t* p_row_indices_used = row_indices->data();
|
||||
size_t* p_row_indices_unused = unused_rows_.data();
|
||||
std::vector<GradientPair>& gpair_ref = *gpair;
|
||||
|
||||
#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG
|
||||
std::bernoulli_distribution coin_flip(param_.subsample);
|
||||
size_t used = 0, unused = 0;
|
||||
for (size_t i = 0; i < info.num_row_; ++i) {
|
||||
if (gpair[i].GetHess() >= 0.0f && coin_flip(rnd)) {
|
||||
p_row_indices_used[used++] = i;
|
||||
} else {
|
||||
p_row_indices_unused[unused++] = i;
|
||||
if (!(gpair_ref[i].GetHess() >= 0.0f && coin_flip(rnd)) || gpair_ref[i].GetGrad() == 0.0f) {
|
||||
gpair_ref[i] = GradientPair(0);
|
||||
}
|
||||
}
|
||||
/* resize row_indices to reduce memory */
|
||||
row_indices->resize(used);
|
||||
unused_rows_.resize(unused);
|
||||
#else
|
||||
const size_t nthread = this->nthread_;
|
||||
std::vector<size_t> row_offsets_used(nthread, 0);
|
||||
std::vector<size_t> row_offsets_unused(nthread, 0);
|
||||
/* usage of mt19937_64 give 2x speed up for subsampling */
|
||||
std::vector<std::mt19937> rnds(nthread);
|
||||
/* create engine for each thread */
|
||||
for (std::mt19937& r : rnds) {
|
||||
r = rnd;
|
||||
}
|
||||
uint64_t initial_seed = rnd();
|
||||
|
||||
const size_t discard_size = info.num_row_ / nthread;
|
||||
auto upper_border = static_cast<float>(std::numeric_limits<uint32_t>::max());
|
||||
uint32_t coin_flip_border = static_cast<uint32_t>(upper_border * param_.subsample);
|
||||
std::bernoulli_distribution coin_flip(param_.subsample);
|
||||
|
||||
dmlc::OMPException exc;
|
||||
#pragma omp parallel num_threads(nthread)
|
||||
{
|
||||
@ -725,60 +696,24 @@ void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const std::vector<Gr
|
||||
const size_t ibegin = tid * discard_size;
|
||||
const size_t iend = (tid == (nthread - 1)) ?
|
||||
info.num_row_ : ibegin + discard_size;
|
||||
|
||||
rnds[tid].discard(discard_size * tid);
|
||||
for (size_t i = ibegin; i < iend; ++i) {
|
||||
if (gpair[i].GetHess() >= 0.0f && rnds[tid]() < coin_flip_border) {
|
||||
p_row_indices_used[ibegin + row_offsets_used[tid]++] = i;
|
||||
} else {
|
||||
p_row_indices_unused[ibegin + row_offsets_unused[tid]++] = i;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma omp barrier
|
||||
|
||||
if (tid == 0ul) {
|
||||
size_t prefix_sum_used = row_offsets_used[0];
|
||||
for (size_t i = 1; i < nthread; ++i) {
|
||||
const size_t ibegin = i * discard_size;
|
||||
|
||||
for (size_t k = 0; k < row_offsets_used[i]; ++k) {
|
||||
p_row_indices_used[prefix_sum_used + k] = p_row_indices_used[ibegin + k];
|
||||
}
|
||||
|
||||
prefix_sum_used += row_offsets_used[i];
|
||||
}
|
||||
/* resize row_indices to reduce memory */
|
||||
row_indices->resize(prefix_sum_used);
|
||||
}
|
||||
|
||||
if (nthread == 1ul || tid == 1ul) {
|
||||
size_t prefix_sum_unused = row_offsets_unused[0];
|
||||
for (size_t i = 1; i < nthread; ++i) {
|
||||
const size_t ibegin = i * discard_size;
|
||||
|
||||
for (size_t k = 0; k < row_offsets_unused[i]; ++k) {
|
||||
p_row_indices_unused[prefix_sum_unused + k] = p_row_indices_unused[ibegin + k];
|
||||
}
|
||||
|
||||
prefix_sum_unused += row_offsets_unused[i];
|
||||
}
|
||||
/* resize row_indices to reduce memory */
|
||||
unused_rows_.resize(prefix_sum_unused);
|
||||
}
|
||||
RandomReplace::MakeIf([&](size_t i, RandomReplace::EngineT& eng) {
|
||||
return !(gpair_ref[i].GetHess() >= 0.0f && coin_flip(eng));
|
||||
}, GradientPair(0), initial_seed, ibegin, iend, &gpair_ref);
|
||||
});
|
||||
}
|
||||
exc.Rethrow();
|
||||
/* discard global engine */
|
||||
rnd = rnds[nthread - 1];
|
||||
#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG
|
||||
}
|
||||
template<typename GradientSumT>
|
||||
size_t QuantileHistMaker::Builder<GradientSumT>::GetNumberOfTrees() {
|
||||
return n_trees_;
|
||||
}
|
||||
|
||||
template<typename GradientSumT>
|
||||
void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix& gmat,
|
||||
const std::vector<GradientPair>& gpair,
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree) {
|
||||
const RegTree& tree,
|
||||
std::vector<GradientPair>* gpair) {
|
||||
CHECK((param_.max_depth > 0 || param_.max_leaves > 0))
|
||||
<< "max_depth or max_leaves cannot be both 0 (unlimited); "
|
||||
<< "at least one should be a positive quantity.";
|
||||
@ -818,17 +753,53 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
|
||||
CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
|
||||
<< "Only uniform sampling is supported, "
|
||||
<< "gradient-based sampling is only support by GPU Hist.";
|
||||
InitSampling(gpair, fmat, &row_indices);
|
||||
builder_monitor_.Start("InitSampling");
|
||||
InitSampling(fmat, gpair, &row_indices);
|
||||
builder_monitor_.Stop("InitSampling");
|
||||
CHECK_EQ(row_indices.size(), info.num_row_);
|
||||
// We should check that the partitioning was done correctly
|
||||
// and each row of the dataset fell into exactly one of the categories
|
||||
CHECK_EQ(row_indices.size() + unused_rows_.size(), info.num_row_);
|
||||
}
|
||||
MemStackAllocator<bool, 128> buff(this->nthread_);
|
||||
bool* p_buff = buff.Get();
|
||||
std::fill(p_buff, p_buff + this->nthread_, false);
|
||||
|
||||
const size_t block_size = info.num_row_ / this->nthread_ + !!(info.num_row_ % this->nthread_);
|
||||
|
||||
#pragma omp parallel num_threads(this->nthread_)
|
||||
{
|
||||
exc.Run([&]() {
|
||||
const size_t tid = omp_get_thread_num();
|
||||
const size_t ibegin = tid * block_size;
|
||||
const size_t iend = std::min(static_cast<size_t>(ibegin + block_size),
|
||||
static_cast<size_t>(info.num_row_));
|
||||
|
||||
for (size_t i = ibegin; i < iend; ++i) {
|
||||
if ((*gpair)[i].GetHess() < 0.0f) {
|
||||
p_buff[tid] = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
exc.Rethrow();
|
||||
|
||||
bool has_neg_hess = false;
|
||||
for (int32_t tid = 0; tid < this->nthread_; ++tid) {
|
||||
if (p_buff[tid]) {
|
||||
has_neg_hess = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (has_neg_hess) {
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < info.num_row_; ++i) {
|
||||
if ((*gpair)[i].GetHess() >= 0.0f) {
|
||||
p_row_indices[j++] = i;
|
||||
}
|
||||
}
|
||||
row_indices.resize(j);
|
||||
} else {
|
||||
MemStackAllocator<bool, 128> buff(this->nthread_);
|
||||
bool* p_buff = buff.Get();
|
||||
std::fill(p_buff, p_buff + this->nthread_, false);
|
||||
|
||||
const size_t block_size = info.num_row_ / this->nthread_ + !!(info.num_row_ % this->nthread_);
|
||||
|
||||
#pragma omp parallel num_threads(this->nthread_)
|
||||
{
|
||||
exc.Run([&]() {
|
||||
@ -836,47 +807,12 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
|
||||
const size_t ibegin = tid * block_size;
|
||||
const size_t iend = std::min(static_cast<size_t>(ibegin + block_size),
|
||||
static_cast<size_t>(info.num_row_));
|
||||
|
||||
for (size_t i = ibegin; i < iend; ++i) {
|
||||
if (gpair[i].GetHess() < 0.0f) {
|
||||
p_buff[tid] = true;
|
||||
break;
|
||||
}
|
||||
p_row_indices[i] = i;
|
||||
}
|
||||
});
|
||||
}
|
||||
exc.Rethrow();
|
||||
|
||||
bool has_neg_hess = false;
|
||||
for (int32_t tid = 0; tid < this->nthread_; ++tid) {
|
||||
if (p_buff[tid]) {
|
||||
has_neg_hess = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (has_neg_hess) {
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < info.num_row_; ++i) {
|
||||
if (gpair[i].GetHess() >= 0.0f) {
|
||||
p_row_indices[j++] = i;
|
||||
}
|
||||
}
|
||||
row_indices.resize(j);
|
||||
} else {
|
||||
#pragma omp parallel num_threads(this->nthread_)
|
||||
{
|
||||
exc.Run([&]() {
|
||||
const size_t tid = omp_get_thread_num();
|
||||
const size_t ibegin = tid * block_size;
|
||||
const size_t iend = std::min(static_cast<size_t>(ibegin + block_size),
|
||||
static_cast<size_t>(info.num_row_));
|
||||
for (size_t i = ibegin; i < iend; ++i) {
|
||||
p_row_indices[i] = i;
|
||||
}
|
||||
});
|
||||
}
|
||||
exc.Rethrow();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1074,9 +1010,9 @@ inline std::pair<size_t, size_t> PartitionDenseKernel(const common::DenseColumn<
|
||||
// Handle sparse columns
|
||||
template<bool default_left, typename BinIdxType>
|
||||
inline std::pair<size_t, size_t> PartitionSparseKernel(
|
||||
const common::SparseColumn<BinIdxType>& column,
|
||||
common::Span<const size_t> rid_span, const int32_t split_cond,
|
||||
const common::SparseColumn<BinIdxType>& column, common::Span<size_t> left_part,
|
||||
common::Span<size_t> right_part) {
|
||||
common::Span<size_t> left_part, common::Span<size_t> right_part) {
|
||||
size_t* p_left_part = left_part.data();
|
||||
size_t* p_right_part = right_part.data();
|
||||
|
||||
@ -1131,7 +1067,7 @@ inline std::pair<size_t, size_t> PartitionSparseKernel(
|
||||
template <typename GradientSumT>
|
||||
template <typename BinIdxType>
|
||||
void QuantileHistMaker::Builder<GradientSumT>::PartitionKernel(
|
||||
const size_t node_in_set, const size_t nid, common::Range1d range,
|
||||
const size_t node_in_set, const size_t nid, const common::Range1d range,
|
||||
const int32_t split_cond, const ColumnMatrix& column_matrix, const RegTree& tree) {
|
||||
const size_t* rid = row_set_collection_[nid].begin;
|
||||
|
||||
@ -1151,28 +1087,30 @@ void QuantileHistMaker::Builder<GradientSumT>::PartitionKernel(
|
||||
static_cast<const common::DenseColumn<BinIdxType>& >(*(column_ptr.get()));
|
||||
if (default_left) {
|
||||
if (column_matrix.AnyMissing()) {
|
||||
child_nodes_sizes = PartitionDenseKernel<true, true>(column, rid_span, split_cond,
|
||||
left, right);
|
||||
child_nodes_sizes = PartitionDenseKernel<true, true>(column, rid_span,
|
||||
split_cond, left, right);
|
||||
} else {
|
||||
child_nodes_sizes = PartitionDenseKernel<true, false>(column, rid_span, split_cond,
|
||||
left, right);
|
||||
child_nodes_sizes = PartitionDenseKernel<true, false>(column, rid_span,
|
||||
split_cond, left, right);
|
||||
}
|
||||
} else {
|
||||
if (column_matrix.AnyMissing()) {
|
||||
child_nodes_sizes = PartitionDenseKernel<false, true>(column, rid_span, split_cond,
|
||||
left, right);
|
||||
child_nodes_sizes = PartitionDenseKernel<false, true>(column, rid_span,
|
||||
split_cond, left, right);
|
||||
} else {
|
||||
child_nodes_sizes = PartitionDenseKernel<false, false>(column, rid_span, split_cond,
|
||||
left, right);
|
||||
child_nodes_sizes = PartitionDenseKernel<false, false>(column, rid_span,
|
||||
split_cond, left, right);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const common::SparseColumn<BinIdxType>& column
|
||||
= static_cast<const common::SparseColumn<BinIdxType>& >(*(column_ptr.get()));
|
||||
if (default_left) {
|
||||
child_nodes_sizes = PartitionSparseKernel<true>(rid_span, split_cond, column, left, right);
|
||||
child_nodes_sizes = PartitionSparseKernel<true>(column, rid_span,
|
||||
split_cond, left, right);
|
||||
} else {
|
||||
child_nodes_sizes = PartitionSparseKernel<false>(rid_span, split_cond, column, left, right);
|
||||
child_nodes_sizes = PartitionSparseKernel<false>(column, rid_span,
|
||||
split_cond, left, right);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -31,6 +31,48 @@
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
|
||||
struct RandomReplace {
|
||||
public:
|
||||
// similar value as for minstd_rand
|
||||
static constexpr uint64_t kBase = 16807;
|
||||
static constexpr uint64_t kMod = static_cast<uint64_t>(1) << 63;
|
||||
|
||||
using EngineT = std::linear_congruential_engine<uint64_t, kBase, 0, kMod>;
|
||||
|
||||
/*
|
||||
Right-to-left binary method: https://en.wikipedia.org/wiki/Modular_exponentiation
|
||||
*/
|
||||
static uint64_t SimpleSkip(uint64_t exponent, uint64_t initial_seed,
|
||||
uint64_t base, uint64_t mod) {
|
||||
CHECK_LE(exponent, mod);
|
||||
uint64_t result = 1;
|
||||
while (exponent > 0) {
|
||||
if (exponent % 2 == 1) {
|
||||
result = (result * base) % mod;
|
||||
}
|
||||
base = (base * base) % mod;
|
||||
exponent = exponent >> 1;
|
||||
}
|
||||
// with result we can now find the new seed
|
||||
return (result * initial_seed) % mod;
|
||||
}
|
||||
|
||||
template<typename Condition, typename ContainerData>
|
||||
static void MakeIf(Condition condition, const typename ContainerData::value_type replace_value,
|
||||
const uint64_t initial_seed, const size_t ibegin,
|
||||
const size_t iend, ContainerData* gpair) {
|
||||
ContainerData& gpair_ref = *gpair;
|
||||
const uint64_t displaced_seed = SimpleSkip(ibegin, initial_seed, kBase, kMod);
|
||||
EngineT eng(displaced_seed);
|
||||
for (size_t i = ibegin; i < iend; ++i) {
|
||||
if (condition(i, eng)) {
|
||||
gpair_ref[i] = replace_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief A C-style array with in-stack allocation. As long as the array is smaller than MaxStackSize, it will be allocated inside the stack. Otherwise, it will be heap-allocated.
|
||||
*/
|
||||
@ -201,11 +243,13 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
using GHistRowT = GHistRow<GradientSumT>;
|
||||
using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;
|
||||
// constructor
|
||||
explicit Builder(const TrainParam& param,
|
||||
explicit Builder(const size_t n_trees,
|
||||
const TrainParam& param,
|
||||
std::unique_ptr<TreeUpdater> pruner,
|
||||
FeatureInteractionConstraintHost int_constraints_,
|
||||
DMatrix const* fmat)
|
||||
: param_(param),
|
||||
: n_trees_(n_trees),
|
||||
param_(param),
|
||||
tree_evaluator_(param, fmat->Info().num_col_, GenericParameter::kCpuId),
|
||||
pruner_(std::move(pruner)),
|
||||
interaction_constraints_{std::move(int_constraints_)},
|
||||
@ -279,12 +323,15 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
|
||||
// initialize temp data structure
|
||||
void InitData(const GHistIndexMatrix& gmat,
|
||||
const std::vector<GradientPair>& gpair,
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree);
|
||||
const RegTree& tree,
|
||||
std::vector<GradientPair>* gpair);
|
||||
|
||||
void InitSampling(const std::vector<GradientPair>& gpair,
|
||||
const DMatrix& fmat, std::vector<size_t>* row_indices);
|
||||
size_t GetNumberOfTrees();
|
||||
|
||||
void InitSampling(const DMatrix& fmat,
|
||||
std::vector<GradientPair>* gpair,
|
||||
std::vector<size_t>* row_indices);
|
||||
|
||||
void EvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
|
||||
const GHistIndexMatrix& gmat,
|
||||
@ -298,7 +345,7 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
RegTree* p_tree);
|
||||
|
||||
template <typename BinIdxType>
|
||||
void PartitionKernel(const size_t node_in_set, const size_t nid, common::Range1d range,
|
||||
void PartitionKernel(const size_t node_in_set, const size_t nid, const common::Range1d range,
|
||||
const int32_t split_cond,
|
||||
const ColumnMatrix& column_matrix, const RegTree& tree);
|
||||
|
||||
@ -398,6 +445,7 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
}
|
||||
}
|
||||
// --data fields--
|
||||
const size_t n_trees_;
|
||||
const TrainParam& param_;
|
||||
// number of omp thread used during training
|
||||
int nthread_;
|
||||
@ -413,6 +461,7 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
std::vector<SplitEntry> best_split_tloc_;
|
||||
/*! \brief TreeNode Data: statistics for each constructed node */
|
||||
std::vector<NodeEntry> snode_;
|
||||
std::vector<GradientPair> gpair_local_;
|
||||
/*! \brief cumulative histogram of gradients. */
|
||||
HistCollection<GradientSumT> hist_;
|
||||
/*! \brief cumulative local parent histogram of gradients. */
|
||||
@ -458,7 +507,7 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
common::Monitor updater_monitor_;
|
||||
|
||||
template<typename GradientSumT>
|
||||
void SetBuilder(std::unique_ptr<Builder<GradientSumT>>*, DMatrix *dmat);
|
||||
void SetBuilder(const size_t n_trees, std::unique_ptr<Builder<GradientSumT>>*, DMatrix *dmat);
|
||||
|
||||
template<typename GradientSumT>
|
||||
void CallBuilderUpdate(const std::unique_ptr<Builder<GradientSumT>>& builder,
|
||||
|
||||
@ -31,15 +31,15 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
std::unique_ptr<TreeUpdater> pruner,
|
||||
FeatureInteractionConstraintHost int_constraint,
|
||||
DMatrix const* fmat)
|
||||
: RealImpl(param, std::move(pruner),
|
||||
: RealImpl(1, param, std::move(pruner),
|
||||
std::move(int_constraint), fmat) {}
|
||||
|
||||
public:
|
||||
void TestInitData(const GHistIndexMatrix& gmat,
|
||||
const std::vector<GradientPair>& gpair,
|
||||
std::vector<GradientPair>* gpair,
|
||||
DMatrix* p_fmat,
|
||||
const RegTree& tree) {
|
||||
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
|
||||
RealImpl::InitData(gmat, *p_fmat, tree, gpair);
|
||||
ASSERT_EQ(this->data_layout_, RealImpl::DataLayout::kSparseData);
|
||||
|
||||
/* The creation of HistCutMatrix and GHistIndexMatrix are not technically
|
||||
@ -101,29 +101,34 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
}
|
||||
|
||||
void TestInitDataSampling(const GHistIndexMatrix& gmat,
|
||||
const std::vector<GradientPair>& gpair,
|
||||
std::vector<GradientPair>* gpair,
|
||||
DMatrix* p_fmat,
|
||||
const RegTree& tree) {
|
||||
// check SimpleSkip
|
||||
size_t initial_seed = 777;
|
||||
std::linear_congruential_engine<std::uint_fast64_t, 16807, 0,
|
||||
static_cast<uint64_t>(1) << 63 > eng_first(initial_seed);
|
||||
for (size_t i = 0; i < 100; ++i) {
|
||||
eng_first();
|
||||
}
|
||||
uint64_t initial_seed_th = RandomReplace::SimpleSkip(100, initial_seed, 16807, RandomReplace::kMod);
|
||||
std::linear_congruential_engine<std::uint_fast64_t, RandomReplace::kBase, 0,
|
||||
RandomReplace::kMod > eng_second(initial_seed_th);
|
||||
ASSERT_EQ(eng_first(), eng_second());
|
||||
|
||||
const size_t nthreads = omp_get_num_threads();
|
||||
// save state of global rng engine
|
||||
auto initial_rnd = common::GlobalRandom();
|
||||
std::vector<size_t> unused_rows_cpy = this->unused_rows_;
|
||||
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
|
||||
RealImpl::InitData(gmat, *p_fmat, tree, gpair);
|
||||
std::vector<size_t> row_indices_initial = *(this->row_set_collection_.Data());
|
||||
std::vector<size_t> unused_row_indices_initial = this->unused_rows_;
|
||||
ASSERT_EQ(row_indices_initial.size(), p_fmat->Info().num_row_);
|
||||
auto check_each_row_occurs_in_one_of_arrays = [](const std::vector<size_t>& first,
|
||||
const std::vector<size_t>& second,
|
||||
size_t nrows) {
|
||||
std::vector<size_t> arr_union(nrows);
|
||||
for (auto&& row_indice : first) {
|
||||
++arr_union[row_indice];
|
||||
}
|
||||
for (auto&& row_indice : second) {
|
||||
++arr_union[row_indice];
|
||||
}
|
||||
for (auto&& row_cnt : arr_union) {
|
||||
ASSERT_EQ(row_cnt, 1ul);
|
||||
}
|
||||
ASSERT_EQ(first.size(), nrows);
|
||||
ASSERT_EQ(second.size(), 0);
|
||||
};
|
||||
check_each_row_occurs_in_one_of_arrays(row_indices_initial, unused_row_indices_initial,
|
||||
p_fmat->Info().num_row_);
|
||||
@ -133,7 +138,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
// return initial state of global rng engine
|
||||
common::GlobalRandom() = initial_rnd;
|
||||
this->unused_rows_ = unused_rows_cpy;
|
||||
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
|
||||
RealImpl::InitData(gmat, *p_fmat, tree, gpair);
|
||||
std::vector<size_t>& row_indices = *(this->row_set_collection_.Data());
|
||||
ASSERT_EQ(row_indices_initial.size(), row_indices.size());
|
||||
for (size_t i = 0; i < row_indices_initial.size(); ++i) {
|
||||
@ -151,10 +156,10 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
}
|
||||
|
||||
void TestAddHistRows(const GHistIndexMatrix& gmat,
|
||||
const std::vector<GradientPair>& gpair,
|
||||
std::vector<GradientPair>* gpair,
|
||||
DMatrix* p_fmat,
|
||||
RegTree* tree) {
|
||||
RealImpl::InitData(gmat, gpair, *p_fmat, *tree);
|
||||
RealImpl::InitData(gmat, *p_fmat, *tree, gpair);
|
||||
|
||||
int starting_index = std::numeric_limits<int>::max();
|
||||
int sync_count = 0;
|
||||
@ -183,11 +188,11 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
|
||||
|
||||
void TestSyncHistograms(const GHistIndexMatrix& gmat,
|
||||
const std::vector<GradientPair>& gpair,
|
||||
std::vector<GradientPair>* gpair,
|
||||
DMatrix* p_fmat,
|
||||
RegTree* tree) {
|
||||
// init
|
||||
RealImpl::InitData(gmat, gpair, *p_fmat, *tree);
|
||||
RealImpl::InitData(gmat, *p_fmat, *tree, gpair);
|
||||
|
||||
int starting_index = std::numeric_limits<int>::max();
|
||||
int sync_count = 0;
|
||||
@ -295,10 +300,10 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
const GHistIndexMatrix& gmat,
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree) {
|
||||
const std::vector<GradientPair> gpair =
|
||||
std::vector<GradientPair> gpair =
|
||||
{ {0.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {0.27f, 0.28f},
|
||||
{0.27f, 0.29f}, {0.37f, 0.39f}, {0.47f, 0.49f}, {0.57f, 0.59f} };
|
||||
RealImpl::InitData(gmat, gpair, fmat, tree);
|
||||
RealImpl::InitData(gmat, fmat, tree, &gpair);
|
||||
GHistIndexBlockMatrix dummy;
|
||||
this->hist_.AddHistRow(nid);
|
||||
this->hist_.AllocateAllData();
|
||||
@ -341,7 +346,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
common::GHistIndexMatrix gmat;
|
||||
gmat.Init(dmat.get(), kMaxBins);
|
||||
|
||||
RealImpl::InitData(gmat, row_gpairs, *dmat, tree);
|
||||
RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
|
||||
this->hist_.AddHistRow(0);
|
||||
this->hist_.AllocateAllData();
|
||||
this->BuildHist(row_gpairs, this->row_set_collection_[0],
|
||||
@ -437,7 +442,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
|
||||
// treat everything as dense, as this is what we intend to test here
|
||||
cm.Init(gmat, 0.0);
|
||||
RealImpl::InitData(gmat, row_gpairs, *dmat, tree);
|
||||
RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
|
||||
this->hist_.AddHistRow(0);
|
||||
this->hist_.AllocateAllData();
|
||||
RealImpl::InitNewNode(0, gmat, row_gpairs, *dmat, tree);
|
||||
@ -548,9 +553,9 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
|
||||
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
|
||||
if (double_builder_) {
|
||||
double_builder_->TestInitData(gmat, gpair, dmat_.get(), tree);
|
||||
double_builder_->TestInitData(gmat, &gpair, dmat_.get(), tree);
|
||||
} else {
|
||||
float_builder_->TestInitData(gmat, gpair, dmat_.get(), tree);
|
||||
float_builder_->TestInitData(gmat, &gpair, dmat_.get(), tree);
|
||||
}
|
||||
}
|
||||
|
||||
@ -566,9 +571,9 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
|
||||
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
|
||||
if (double_builder_) {
|
||||
double_builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree);
|
||||
double_builder_->TestInitDataSampling(gmat, &gpair, dmat_.get(), tree);
|
||||
} else {
|
||||
float_builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree);
|
||||
float_builder_->TestInitDataSampling(gmat, &gpair, dmat_.get(), tree);
|
||||
}
|
||||
}
|
||||
|
||||
@ -583,9 +588,9 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
|
||||
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
|
||||
if (double_builder_) {
|
||||
double_builder_->TestAddHistRows(gmat, gpair, dmat_.get(), &tree);
|
||||
double_builder_->TestAddHistRows(gmat, &gpair, dmat_.get(), &tree);
|
||||
} else {
|
||||
float_builder_->TestAddHistRows(gmat, gpair, dmat_.get(), &tree);
|
||||
float_builder_->TestAddHistRows(gmat, &gpair, dmat_.get(), &tree);
|
||||
}
|
||||
}
|
||||
|
||||
@ -600,9 +605,9 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
|
||||
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
|
||||
if (double_builder_) {
|
||||
double_builder_->TestSyncHistograms(gmat, gpair, dmat_.get(), &tree);
|
||||
double_builder_->TestSyncHistograms(gmat, &gpair, dmat_.get(), &tree);
|
||||
} else {
|
||||
float_builder_->TestSyncHistograms(gmat, gpair, dmat_.get(), &tree);
|
||||
float_builder_->TestSyncHistograms(gmat, &gpair, dmat_.get(), &tree);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user