Extract evaluate splits from CPU hist. (#7079)
Other than modularizing the split evaluation function, this PR also removes some more functions, including `InitNewNode` and `BuildNodeStats`, along with some now-unused variables. Scattered code, such as setting leaf weights, is grouped into the split evaluator, and `NodeEntry` is simplified and made private. Another subtle difference from the original implementation is that the modified code doesn't call `tree[nidx].Parent()` to traverse upward.
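For orientation, the new evaluator's lifecycle as the builder drives it after this change looks roughly like the following. This is a condensed sketch based on the new unit test further down (tests/cpp/tree/hist/test_evaluate_splits.cc), not code from the commit itself; `param`, `dmat`, `hist`, `gmat`, `tree`, `n_threads`, and the root gradient sum `root_sum` are assumed to be prepared as in that test:

    auto sampler = std::make_shared<common::ColumnSampler>();
    HistEvaluator<double, CPUExpandEntry> evaluator{param, dmat->Info(), n_threads, sampler};

    // Fill the root NodeEntry and get the root weight.
    float weight = evaluator.InitRoot(root_sum);
    // Find the best split for each queued node (here: just the root).
    std::vector<CPUExpandEntry> entries{CPUExpandEntry{0, 0, 0.0f}};
    evaluator.EvaluateSplits(hist, gmat, tree, &entries);
    // Expand the tree and record child statistics and constraints.
    evaluator.ApplyTreeSplit(entries.front(), &tree);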
This commit is contained in:
parent d22b293f2f
commit 615ab2b03e

src/tree/hist/evaluate_splits.h (new file, 268 lines)
@@ -0,0 +1,268 @@
/*!
 * Copyright 2021 by XGBoost Contributors
 */
#ifndef XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
#define XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_

#include <algorithm>
#include <memory>
#include <limits>
#include <utility>
#include <vector>

#include "../param.h"
#include "../constraints.h"
#include "../split_evaluator.h"
#include "../../common/random.h"
#include "../../common/hist_util.h"
#include "../../data/gradient_index.h"

namespace xgboost {
namespace tree {

template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
 private:
  struct NodeEntry {
    /*! \brief statistics for the node entry */
    GradStats stats;
    /*! \brief loss of this node, without split */
    bst_float root_gain{0.0f};
  };

 private:
  TrainParam param_;
  std::shared_ptr<common::ColumnSampler> column_sampler_;
  TreeEvaluator tree_evaluator_;
  int32_t n_threads_ {0};
  FeatureInteractionConstraintHost interaction_constraints_;
  std::vector<NodeEntry> snode_;

  // If the sum of statistics for non-missing values in the node equals the
  // sum of statistics for all values, there are no missing values; otherwise
  // the node contains missing values.
  bool static SplitContainsMissingValues(const GradStats e,
                                         const NodeEntry &snode) {
    if (e.GetGrad() == snode.stats.GetGrad() &&
        e.GetHess() == snode.stats.GetHess()) {
      return false;
    } else {
      return true;
    }
  }

  // Enumerate/scan the split values of a specific feature.
  // Returns the sum of gradients corresponding to the data points that
  // contain a non-missing value for the particular feature fid.
  template <int d_step>
  GradStats EnumerateSplit(
      const GHistIndexMatrix &gmat, const common::GHistRow<GradientSumT> &hist,
      const NodeEntry &snode, SplitEntry *p_best, bst_feature_t fidx,
      bst_node_t nidx,
      TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator) const {
    static_assert(d_step == +1 || d_step == -1, "Invalid step.");

    // aliases
    const std::vector<uint32_t> &cut_ptr = gmat.cut.Ptrs();
    const std::vector<bst_float> &cut_val = gmat.cut.Values();

    // statistics on both sides of split
    GradStats c;
    GradStats e;
    // best split so far
    SplitEntry best;

    // bin boundaries
    CHECK_LE(cut_ptr[fidx],
             static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
    CHECK_LE(cut_ptr[fidx + 1],
             static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
    // imin: index (offset) of the minimum value for feature fid
    // need this for backward enumeration
    const auto imin = static_cast<int32_t>(cut_ptr[fidx]);
    // ibegin, iend: smallest/largest cut points for feature fid
    // use int to allow for value -1
    int32_t ibegin, iend;
    if (d_step > 0) {
      ibegin = static_cast<int32_t>(cut_ptr[fidx]);
      iend = static_cast<int32_t>(cut_ptr.at(fidx + 1));
    } else {
      ibegin = static_cast<int32_t>(cut_ptr[fidx + 1]) - 1;
      iend = static_cast<int32_t>(cut_ptr[fidx]) - 1;
    }

    for (int32_t i = ibegin; i != iend; i += d_step) {
      // try to find a split
      e.Add(hist[i].GetGrad(), hist[i].GetHess());
      if (e.GetHess() >= param_.min_child_weight) {
        c.SetSubstract(snode.stats, e);
        if (c.GetHess() >= param_.min_child_weight) {
          bst_float loss_chg;
          bst_float split_pt;
          if (d_step > 0) {
            // forward enumeration: split at right bound of each bin
            loss_chg = static_cast<bst_float>(
                evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{e},
                                        GradStats{c}) -
                snode.root_gain);
            split_pt = cut_val[i];
            best.Update(loss_chg, fidx, split_pt, d_step == -1, e, c);
          } else {
            // backward enumeration: split at left bound of each bin
            loss_chg = static_cast<bst_float>(
                evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{c},
                                        GradStats{e}) -
                snode.root_gain);
            if (i == imin) {
              // for leftmost bin, left bound is the smallest feature value
              split_pt = gmat.cut.MinValues()[fidx];
            } else {
              split_pt = cut_val[i - 1];
            }
            best.Update(loss_chg, fidx, split_pt, d_step == -1, c, e);
          }
        }
      }
    }
    p_best->Update(best);

    return e;
  }

 public:
  void EvaluateSplits(const common::HistCollection<GradientSumT> &hist,
                      GHistIndexMatrix const &gidx, const RegTree &tree,
                      std::vector<ExpandEntry>* p_entries) {
    auto& entries = *p_entries;
    // All nodes are on the same level, so we can store the shared ptr.
    std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(
        entries.size());
    for (size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
      auto nidx = entries[nidx_in_set].nid;
      features[nidx_in_set] =
          column_sampler_->GetFeatureSet(tree.GetDepth(nidx));
    }
    CHECK(!features.empty());
    const size_t grain_size =
        std::max<size_t>(1, features.front()->Size() / n_threads_);
    common::BlockedSpace2d space(entries.size(), [&](size_t nidx_in_set) {
      return features[nidx_in_set]->Size();
    }, grain_size);

    std::vector<ExpandEntry> tloc_candidates(omp_get_max_threads() * entries.size());
    for (size_t i = 0; i < entries.size(); ++i) {
      for (decltype(n_threads_) j = 0; j < n_threads_; ++j) {
        tloc_candidates[i * n_threads_ + j] = entries[i];
      }
    }
    auto evaluator = tree_evaluator_.GetEvaluator();

    common::ParallelFor2d(space, n_threads_, [&](size_t nidx_in_set, common::Range1d r) {
      auto tidx = omp_get_thread_num();
      auto entry = &tloc_candidates[n_threads_ * nidx_in_set + tidx];
      auto best = &entry->split;
      auto nidx = entry->nid;
      auto histogram = hist[nidx];
      auto features_set = features[nidx_in_set]->ConstHostSpan();
      for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
        auto fidx = features_set[fidx_in_set];
        if (interaction_constraints_.Query(nidx, fidx)) {
          auto grad_stats = EnumerateSplit<+1>(gidx, histogram, snode_[nidx],
                                               best, fidx, nidx, evaluator);
          if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
            EnumerateSplit<-1>(gidx, histogram, snode_[nidx], best, fidx, nidx,
                               evaluator);
          }
        }
      }
    });

    for (unsigned nidx_in_set = 0; nidx_in_set < entries.size();
         ++nidx_in_set) {
      for (auto tidx = 0; tidx < n_threads_; ++tidx) {
        entries[nidx_in_set].split.Update(
            tloc_candidates[n_threads_ * nidx_in_set + tidx].split);
      }
    }
  }
  // Add splits to tree, handling all statistics
  void ApplyTreeSplit(ExpandEntry candidate, RegTree *p_tree) {
    auto evaluator = tree_evaluator_.GetEvaluator();
    RegTree &tree = *p_tree;

    GradStats parent_sum = candidate.split.left_sum;
    parent_sum.Add(candidate.split.right_sum);
    auto base_weight =
        evaluator.CalcWeight(candidate.nid, param_, GradStats{parent_sum});

    auto left_weight = evaluator.CalcWeight(
        candidate.nid, param_, GradStats{candidate.split.left_sum});
    auto right_weight = evaluator.CalcWeight(
        candidate.nid, param_, GradStats{candidate.split.right_sum});

    tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(),
                    candidate.split.split_value, candidate.split.DefaultLeft(),
                    base_weight, left_weight * param_.learning_rate,
                    right_weight * param_.learning_rate,
                    candidate.split.loss_chg, parent_sum.GetHess(),
                    candidate.split.left_sum.GetHess(),
                    candidate.split.right_sum.GetHess());

    // Set up child constraints
    auto left_child = tree[candidate.nid].LeftChild();
    auto right_child = tree[candidate.nid].RightChild();
    tree_evaluator_.AddSplit(candidate.nid, left_child, right_child,
                             tree[candidate.nid].SplitIndex(), left_weight,
                             right_weight);

    auto max_node = std::max(left_child, tree[candidate.nid].RightChild());
    max_node = std::max(candidate.nid, max_node);
    snode_.resize(tree.GetNodes().size());
    snode_.at(left_child).stats = candidate.split.left_sum;
    snode_.at(left_child).root_gain = evaluator.CalcGain(
        candidate.nid, param_, GradStats{candidate.split.left_sum});
    snode_.at(right_child).stats = candidate.split.right_sum;
    snode_.at(right_child).root_gain = evaluator.CalcGain(
        candidate.nid, param_, GradStats{candidate.split.right_sum});

    interaction_constraints_.Split(candidate.nid,
                                   tree[candidate.nid].SplitIndex(), left_child,
                                   right_child);
  }

  auto Evaluator() const { return tree_evaluator_.GetEvaluator(); }
  auto const& Stats() const { return snode_; }

  float InitRoot(GradStats const& root_sum) {
    snode_.resize(1);
    auto root_evaluator = tree_evaluator_.GetEvaluator();

    snode_[0].stats = GradStats{root_sum.GetGrad(), root_sum.GetHess()};
    snode_[0].root_gain = root_evaluator.CalcGain(RegTree::kRoot, param_,
                                                  GradStats{snode_[0].stats});
    auto weight = root_evaluator.CalcWeight(RegTree::kRoot, param_,
                                            GradStats{snode_[0].stats});
    return weight;
  }

 public:
  // The column sampler must be constructed by the caller since we need to
  // preserve the rng for the entire training session.
  explicit HistEvaluator(TrainParam const &param, MetaInfo const &info,
                         int32_t n_threads,
                         std::shared_ptr<common::ColumnSampler> sampler,
                         bool skip_0_index = false)
      : param_{param}, column_sampler_{std::move(sampler)},
        tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_),
                        GenericParameter::kCpuId},
        n_threads_{n_threads} {
    interaction_constraints_.Configure(param, info.num_col_);
    column_sampler_->Init(info.num_col_, info.feature_weigths.HostVector(),
                          param_.colsample_bynode, param_.colsample_bylevel,
                          param_.colsample_bytree, skip_0_index);
  }
};
}  // namespace tree
}  // namespace xgboost
#endif  // XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
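The loss_chg computed inside EnumerateSplit above is the usual second-order gain. A minimal sketch of the arithmetic, simplified to L2 regularization only (reg_alpha and max_delta_step ignored; the helper names are illustrative, not the library's API):

    // Gain of keeping a set of rows in one leaf: G^2 / (H + lambda),
    // where G and H are the summed gradients and hessians.
    double CalcGain(double G, double H, double lambda) {
      return G * G / (H + lambda);
    }

    // What `evaluator.CalcSplitGain(...) - snode.root_gain` amounts to:
    double LossChange(double GL, double HL, double GR, double HR, double lambda) {
      double root_gain = CalcGain(GL + GR, HL + HR, lambda);  // snode.root_gain
      return CalcGain(GL, HL, lambda) + CalcGain(GR, HR, lambda) - root_gain;
    }

The two enumeration directions exist because of missing values: in the forward pass `e` accumulates the scanned bins as the left child, so rows with missing values implicitly default to the right; the backward pass does the opposite. EvaluateSplits only runs the backward pass when SplitContainsMissingValues detects that the histogram sums differ from the node totals.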
src/tree/updater_quantile_hist.cc

@@ -51,11 +51,8 @@ template<typename GradientSumT>
 void QuantileHistMaker::SetBuilder(const size_t n_trees,
                                    std::unique_ptr<Builder<GradientSumT>>* builder,
                                    DMatrix *dmat) {
-  builder->reset(new Builder<GradientSumT>(
-      n_trees,
-      param_,
-      std::move(pruner_),
-      int_constraint_, dmat));
+  builder->reset(
+      new Builder<GradientSumT>(n_trees, param_, std::move(pruner_), dmat));
   if (rabit::IsDistributed()) {
     (*builder)->SetHistSynchronizer(new DistributedHistSynchronizer<GradientSumT>());
     (*builder)->SetHistRowsAdder(new DistributedHistRowsAdder<GradientSumT>());
@@ -75,6 +72,7 @@ void QuantileHistMaker::CallBuilderUpdate(const std::unique_ptr<Builder<Gradient
     builder->Update(gmat, column_matrix_, gpair, dmat, tree);
   }
 }
+
 void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
                                DMatrix *dmat,
                                const std::vector<RegTree *> &trees) {
@@ -93,7 +91,7 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
   // rescale learning rate according to size of trees
   float lr = param_.learning_rate;
   param_.learning_rate = lr / trees.size();
-  int_constraint_.Configure(param_, dmat->Info().num_col_);
+
   // build tree
   const size_t n_trees = trees.size();
   if (hist_maker_param_.single_precision_histogram) {
@@ -296,12 +294,9 @@ void QuantileHistMaker::Builder<GradientSumT>::SetHistRowsAdder(
 template <typename GradientSumT>
 template <bool any_missing>
 void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
-    const GHistIndexMatrix &gmat,
-    const DMatrix& fmat,
-    RegTree *p_tree,
-    const std::vector<GradientPair> &gpair_h,
-    int *num_leaves, std::vector<CPUExpandEntry> *expand) {
+    const GHistIndexMatrix &gmat, const DMatrix &fmat, RegTree *p_tree,
+    const std::vector<GradientPair> &gpair_h, int *num_leaves,
+    std::vector<CPUExpandEntry> *expand) {
   CPUExpandEntry node(CPUExpandEntry::kRootNid, p_tree->GetDepth(0), 0.0f);
 
   nodes_for_explicit_hist_build_.clear();
@@ -315,10 +310,40 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
   BuildLocalHistograms<any_missing>(gmat, p_tree, gpair_h);
   hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, p_tree);
 
-  this->InitNewNode(CPUExpandEntry::kRootNid, gmat, gpair_h, fmat, *p_tree);
+  {
+    auto nid = CPUExpandEntry::kRootNid;
+    GHistRowT hist = hist_[nid];
+    GradientPairT grad_stat;
+    if (data_layout_ == DataLayout::kDenseDataZeroBased ||
+        data_layout_ == DataLayout::kDenseDataOneBased) {
+      const std::vector<uint32_t> &row_ptr = gmat.cut.Ptrs();
+      const uint32_t ibegin = row_ptr[fid_least_bins_];
+      const uint32_t iend = row_ptr[fid_least_bins_ + 1];
+      auto begin = hist.data();
+      for (uint32_t i = ibegin; i < iend; ++i) {
+        const GradientPairT et = begin[i];
+        grad_stat.Add(et.GetGrad(), et.GetHess());
+      }
+    } else {
+      const RowSetCollection::Elem e = row_set_collection_[nid];
+      for (const size_t *it = e.begin; it < e.end; ++it) {
+        grad_stat.Add(gpair_h[*it].GetGrad(), gpair_h[*it].GetHess());
+      }
+    }
+    histred_.Allreduce(&grad_stat, 1);
+
+    auto weight = evaluator_->InitRoot(GradStats{grad_stat});
+    p_tree->Stat(RegTree::kRoot).sum_hess = grad_stat.GetHess();
+    p_tree->Stat(RegTree::kRoot).base_weight = weight;
+    (*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight);
+
+    std::vector<CPUExpandEntry> entries{node};
+    builder_monitor_.Start("EvaluateSplits");
+    evaluator_->EvaluateSplits(hist_, gmat, *p_tree, &entries);
+    builder_monitor_.Stop("EvaluateSplits");
+    node = entries.front();
+  }
 
-  this->EvaluateSplits({node}, gmat, hist_, *p_tree);
-  node.loss_chg = snode_[CPUExpandEntry::kRootNid].best.loss_chg;
   expand->push_back(node);
   ++(*num_leaves);
 }
@@ -369,25 +394,10 @@ void QuantileHistMaker::Builder<GradientSumT>::AddSplitsToTree(
     RegTree *p_tree,
     int *num_leaves,
     std::vector<CPUExpandEntry>* nodes_for_apply_split) {
-  auto evaluator = tree_evaluator_.GetEvaluator();
   for (auto const& entry : expand) {
-    int nid = entry.nid;
-
     if (entry.IsValid(param_, *num_leaves)) {
-      (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate);
-    } else {
       nodes_for_apply_split->push_back(entry);
-      NodeEntry& e = snode_[nid];
-      bst_float left_leaf_weight =
-          evaluator.CalcWeight(nid, param_, GradStats{e.best.left_sum}) * param_.learning_rate;
-      bst_float right_leaf_weight =
-          evaluator.CalcWeight(nid, param_, GradStats{e.best.right_sum}) * param_.learning_rate;
-      p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
-                         e.best.DefaultLeft(), e.weight, left_leaf_weight,
-                         right_leaf_weight, e.best.loss_chg, e.stats.GetHess(),
-                         e.best.left_sum.GetHess(), e.best.right_sum.GetHess());
-      // - 1 parent + 2 new children
+      evaluator_->ApplyTreeSplit(entry, p_tree);
       (*num_leaves)++;
     }
   }
 }
@@ -425,26 +435,6 @@ void QuantileHistMaker::Builder<GradientSumT>::SplitSiblings(
   builder_monitor_.Stop("SplitSiblings");
 }
 
-template <typename GradientSumT>
-void QuantileHistMaker::Builder<GradientSumT>::BuildNodeStats(
-    const GHistIndexMatrix &gmat,
-    const DMatrix& fmat,
-    const std::vector<GradientPair> &gpair_h,
-    const std::vector<CPUExpandEntry>& nodes_for_apply_split, RegTree *p_tree) {
-  for (auto const& candidate : nodes_for_apply_split) {
-    const int nid = candidate.nid;
-    const int cleft = (*p_tree)[nid].LeftChild();
-    const int cright = (*p_tree)[nid].RightChild();
-
-    InitNewNode(cleft, gmat, gpair_h, fmat, *p_tree);
-    InitNewNode(cright, gmat, gpair_h, fmat, *p_tree);
-    bst_uint featureid = snode_[nid].best.SplitIndex();
-    tree_evaluator_.AddSplit(nid, cleft, cright, featureid,
-                             snode_[cleft].weight, snode_[cright].weight);
-    interaction_constraints_.Split(nid, featureid, cleft, cright);
-  }
-}
-
 template<typename GradientSumT>
 template <bool any_missing>
 void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
@@ -484,20 +474,13 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
       hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, p_tree);
     }
 
-    BuildNodeStats(gmat, *p_fmat, gpair_h, nodes_for_apply_split, p_tree);
-    EvaluateSplits(nodes_to_evaluate, gmat, hist_, *p_tree);
+    builder_monitor_.Start("EvaluateSplits");
+    evaluator_->EvaluateSplits(hist_, gmat, *p_tree, &nodes_to_evaluate);
+    builder_monitor_.Stop("EvaluateSplits");
 
     for (size_t i = 0; i < nodes_for_apply_split.size(); ++i) {
-      const CPUExpandEntry candidate = nodes_for_apply_split[i];
-      const int nid = candidate.nid;
-      const int cleft = (*p_tree)[nid].LeftChild();
-      const int cright = (*p_tree)[nid].RightChild();
-      CPUExpandEntry left_node = nodes_to_evaluate[i*2 + 0];
-      CPUExpandEntry right_node = nodes_to_evaluate[i*2 + 1];
-
-      left_node.loss_chg = snode_[cleft].best.loss_chg;
-      right_node.loss_chg = snode_[cright].best.loss_chg;
-
+      CPUExpandEntry left_node = nodes_to_evaluate.at(i * 2 + 0);
+      CPUExpandEntry right_node = nodes_to_evaluate.at(i * 2 + 1);
       driver.Push(left_node);
       driver.Push(right_node);
     }
@@ -521,9 +504,6 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
     gpair_local_ = *gpair_ptr;
     gpair_ptr = &gpair_local_;
   }
-  tree_evaluator_ =
-      TreeEvaluator(param_, p_fmat->Info().num_col_, GenericParameter::kCpuId);
-  interaction_constraints_.Reset();
   p_last_fmat_mutable_ = p_fmat;
 
   this->InitData(gmat, *p_fmat, *p_tree, gpair_ptr);
@@ -533,11 +513,6 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
   } else {
     ExpandTree<false>(gmat, column_matrix, p_fmat, p_tree, *gpair_ptr);
   }
-  for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) {
-    p_tree->Stat(nid).loss_chg = snode_[nid].best.loss_chg;
-    p_tree->Stat(nid).base_weight = snode_[nid].weight;
-    p_tree->Stat(nid).sum_hess = static_cast<float>(snode_[nid].stats.GetHess());
-  }
   pruner_->Update(gpair, p_fmat, std::vector<RegTree*>{p_tree});
 
   builder_monitor_.Stop("Update");
@@ -761,14 +736,13 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
   // store a pointer to the tree
   p_last_tree_ = &tree;
   if (data_layout_ == DataLayout::kDenseDataOneBased) {
-    column_sampler_.Init(info.num_col_, info.feature_weigths.ConstHostVector(),
-                         param_.colsample_bynode, param_.colsample_bylevel,
-                         param_.colsample_bytree, true);
+    evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
+        param_, info, this->nthread_, column_sampler_, true});
   } else {
-    column_sampler_.Init(info.num_col_, info.feature_weigths.ConstHostVector(),
-                         param_.colsample_bynode, param_.colsample_bylevel,
-                         param_.colsample_bytree, false);
+    evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
+        param_, info, this->nthread_, column_sampler_, false});
   }
 
   if (data_layout_ == DataLayout::kDenseDataZeroBased
       || data_layout_ == DataLayout::kDenseDataOneBased) {
     /* specialized code for dense data:
@@ -789,95 +763,10 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
     }
     CHECK_GT(min_nbins_per_feature, 0U);
   }
-  {
-    snode_.reserve(256);
-    snode_.clear();
-  }
-
   builder_monitor_.Stop("InitData");
 }
 
-// if sum of statistics for non-missing values in the node
-// is equal to sum of statistics for all values:
-// then - there are no missing values
-// else - there are missing values
-template <typename GradientSumT>
-bool QuantileHistMaker::Builder<GradientSumT>::SplitContainsMissingValues(
-    const GradStats e, const NodeEntry &snode) {
-  if (e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess()) {
-    return false;
-  } else {
-    return true;
-  }
-}
-
-// nodes_set - set of nodes to be processed in parallel
-template<typename GradientSumT>
-void QuantileHistMaker::Builder<GradientSumT>::EvaluateSplits(
-    const std::vector<CPUExpandEntry>& nodes_set,
-    const GHistIndexMatrix& gmat,
-    const HistCollection<GradientSumT>& hist,
-    const RegTree& tree) {
-  builder_monitor_.Start("EvaluateSplits");
-
-  const size_t n_nodes_in_set = nodes_set.size();
-  const size_t nthread = std::max(1, this->nthread_);
-
-  using FeatureSetType = std::shared_ptr<HostDeviceVector<bst_feature_t>>;
-  std::vector<FeatureSetType> features_sets(n_nodes_in_set);
-  best_split_tloc_.resize(nthread * n_nodes_in_set);
-
-  // Generate feature set for each tree node
-  for (size_t nid_in_set = 0; nid_in_set < n_nodes_in_set; ++nid_in_set) {
-    const int32_t nid = nodes_set[nid_in_set].nid;
-    features_sets[nid_in_set] = column_sampler_.GetFeatureSet(tree.GetDepth(nid));
-
-    for (unsigned tid = 0; tid < nthread; ++tid) {
-      best_split_tloc_[nthread*nid_in_set + tid] = snode_[nid].best;
-    }
-  }
-
-  // Create 2D space (# of nodes to process x # of features to process)
-  // to process them in parallel
-  const size_t grain_size = std::max<size_t>(1, features_sets[0]->Size() / nthread);
-  common::BlockedSpace2d space(n_nodes_in_set, [&](size_t nid_in_set) {
-    return features_sets[nid_in_set]->Size();
-  }, grain_size);
-
-  auto evaluator = tree_evaluator_.GetEvaluator();
-  // Start parallel enumeration for all tree nodes in the set and all features
-  common::ParallelFor2d(space, this->nthread_, [&](size_t nid_in_set, common::Range1d r) {
-    const int32_t nid = nodes_set[nid_in_set].nid;
-    const auto tid = static_cast<unsigned>(omp_get_thread_num());
-    GHistRowT node_hist = hist[nid];
-
-    for (auto idx_in_feature_set = r.begin(); idx_in_feature_set < r.end(); ++idx_in_feature_set) {
-      const auto fid = features_sets[nid_in_set]->ConstHostVector()[idx_in_feature_set];
-      if (interaction_constraints_.Query(nid, fid)) {
-        auto grad_stats = this->EnumerateSplit<+1>(
-            gmat, node_hist, snode_[nid],
-            &best_split_tloc_[nthread * nid_in_set + tid], fid, nid, evaluator);
-        if (SplitContainsMissingValues(grad_stats, snode_[nid])) {
-          this->EnumerateSplit<-1>(
-              gmat, node_hist, snode_[nid],
-              &best_split_tloc_[nthread * nid_in_set + tid], fid, nid,
-              evaluator);
-        }
-      }
-    }
-  });
-
-  // Find Best Split across threads for each node in nodes set
-  for (unsigned nid_in_set = 0; nid_in_set < n_nodes_in_set; ++nid_in_set) {
-    const int32_t nid = nodes_set[nid_in_set].nid;
-    for (unsigned tid = 0; tid < nthread; ++tid) {
-      snode_[nid].best.Update(best_split_tloc_[nthread*nid_in_set + tid]);
-    }
-  }
-
-  builder_monitor_.Stop("EvaluateSplits");
-}
-
 template <typename GradientSumT>
 void QuantileHistMaker::Builder<GradientSumT>::FindSplitConditions(
     const std::vector<CPUExpandEntry>& nodes,
@@ -988,139 +877,6 @@ void QuantileHistMaker::Builder<GradientSumT>::ApplySplit(const std::vector<CPUE
   AddSplitsToRowSet(nodes, p_tree);
   builder_monitor_.Stop("ApplySplit");
 }
-template <typename GradientSumT>
-void QuantileHistMaker::Builder<GradientSumT>::InitNewNode(int nid,
-                                                           const GHistIndexMatrix& gmat,
-                                                           const std::vector<GradientPair>& gpair,
-                                                           const DMatrix& fmat,
-                                                           const RegTree& tree) {
-  builder_monitor_.Start("InitNewNode");
-  {
-    snode_.resize(tree.param.num_nodes, NodeEntry(param_));
-  }
-
-  {
-    GHistRowT hist = hist_[nid];
-    GradientPairT grad_stat;
-    if (tree[nid].IsRoot()) {
-      if (data_layout_ == DataLayout::kDenseDataZeroBased
-          || data_layout_ == DataLayout::kDenseDataOneBased) {
-        const std::vector<uint32_t>& row_ptr = gmat.cut.Ptrs();
-        const uint32_t ibegin = row_ptr[fid_least_bins_];
-        const uint32_t iend = row_ptr[fid_least_bins_ + 1];
-        auto begin = hist.data();
-        for (uint32_t i = ibegin; i < iend; ++i) {
-          const GradientPairT et = begin[i];
-          grad_stat.Add(et.GetGrad(), et.GetHess());
-        }
-      } else {
-        const RowSetCollection::Elem e = row_set_collection_[nid];
-        for (const size_t* it = e.begin; it < e.end; ++it) {
-          grad_stat.Add(gpair[*it].GetGrad(), gpair[*it].GetHess());
-        }
-      }
-      histred_.Allreduce(&grad_stat, 1);
-      snode_[nid].stats = tree::GradStats(grad_stat.GetGrad(), grad_stat.GetHess());
-    } else {
-      int parent_id = tree[nid].Parent();
-      if (tree[nid].IsLeftChild()) {
-        snode_[nid].stats = snode_[parent_id].best.left_sum;
-      } else {
-        snode_[nid].stats = snode_[parent_id].best.right_sum;
-      }
-    }
-  }
-
-  // calculating the weights
-  {
-    auto evaluator = tree_evaluator_.GetEvaluator();
-    bst_uint parentid = tree[nid].Parent();
-    snode_[nid].weight = static_cast<float>(
-        evaluator.CalcWeight(parentid, param_, GradStats{snode_[nid].stats}));
-    snode_[nid].root_gain = static_cast<float>(
-        evaluator.CalcGain(parentid, param_, GradStats{snode_[nid].stats}));
-  }
-  builder_monitor_.Stop("InitNewNode");
-}
-
-// Enumerate the split values of specific feature.
-// Returns the sum of gradients corresponding to the data points that contains a non-missing value
-// for the particular feature fid.
-template <typename GradientSumT>
-template <int d_step>
-GradStats QuantileHistMaker::Builder<GradientSumT>::EnumerateSplit(
-    const GHistIndexMatrix &gmat, const GHistRowT &hist, const NodeEntry &snode,
-    SplitEntry *p_best, bst_uint fid, bst_uint nodeID,
-    TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator) const {
-  CHECK(d_step == +1 || d_step == -1);
-
-  // aliases
-  const std::vector<uint32_t>& cut_ptr = gmat.cut.Ptrs();
-  const std::vector<bst_float>& cut_val = gmat.cut.Values();
-
-  // statistics on both sides of split
-  GradStats c;
-  GradStats e;
-  // best split so far
-  SplitEntry best;
-
-  // bin boundaries
-  CHECK_LE(cut_ptr[fid],
-           static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
-  CHECK_LE(cut_ptr[fid + 1],
-           static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
-  // imin: index (offset) of the minimum value for feature fid
-  // need this for backward enumeration
-  const auto imin = static_cast<int32_t>(cut_ptr[fid]);
-  // ibegin, iend: smallest/largest cut points for feature fid
-  // use int to allow for value -1
-  int32_t ibegin, iend;
-  if (d_step > 0) {
-    ibegin = static_cast<int32_t>(cut_ptr[fid]);
-    iend = static_cast<int32_t>(cut_ptr[fid + 1]);
-  } else {
-    ibegin = static_cast<int32_t>(cut_ptr[fid + 1]) - 1;
-    iend = static_cast<int32_t>(cut_ptr[fid]) - 1;
-  }
-
-  for (int32_t i = ibegin; i != iend; i += d_step) {
-    // start working
-    // try to find a split
-    e.Add(hist[i].GetGrad(), hist[i].GetHess());
-    if (e.GetHess() >= param_.min_child_weight) {
-      c.SetSubstract(snode.stats, e);
-      if (c.GetHess() >= param_.min_child_weight) {
-        bst_float loss_chg;
-        bst_float split_pt;
-        if (d_step > 0) {
-          // forward enumeration: split at right bound of each bin
-          loss_chg = static_cast<bst_float>(
-              evaluator.CalcSplitGain(param_, nodeID, fid, GradStats{e},
-                                      GradStats{c}) -
-              snode.root_gain);
-          split_pt = cut_val[i];
-          best.Update(loss_chg, fid, split_pt, d_step == -1, e, c);
-        } else {
-          // backward enumeration: split at left bound of each bin
-          loss_chg = static_cast<bst_float>(
-              evaluator.CalcSplitGain(param_, nodeID, fid, GradStats{c},
-                                      GradStats{e}) -
-              snode.root_gain);
-          if (i == imin) {
-            // for leftmost bin, left bound is the smallest feature value
-            split_pt = gmat.cut.MinValues()[fid];
-          } else {
-            split_pt = cut_val[i - 1];
-          }
-          best.Update(loss_chg, fid, split_pt, d_step == -1, c, e);
-        }
-      }
-    }
-  }
-  p_best->Update(best);
-
-  return e;
-}
-
 template struct QuantileHistMaker::Builder<float>;
 template struct QuantileHistMaker::Builder<double>;
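A side note on the root-statistics block that gets inlined into InitRoot above: for the dense layouts, the node's total gradient is recovered from a single feature's histogram slice rather than from the raw gradient vector, and fid_least_bins_ is chosen so that slice is as short as possible. A sketch of the invariant this relies on (variable names taken from the diff, for illustration only):

    // Dense data, no missing values: every row of the node falls into exactly
    // one bin of the chosen feature, so that feature's bins partition the rows
    // and summing its histogram slice [ibegin, iend) yields the node's totals.
    GradientPairT grad_stat;
    for (uint32_t i = ibegin; i < iend; ++i) {
      grad_stat.Add(hist[i].GetGrad(), hist[i].GetHess());
    }
    // grad_stat now equals the sum of gradient pairs over all rows in the node.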
src/tree/updater_quantile_hist.h

@@ -20,6 +20,8 @@
 
 #include "xgboost/data.h"
 #include "xgboost/json.h"
+
+#include "hist/evaluate_splits.h"
 #include "constraints.h"
 #include "./param.h"
 #include "./driver.h"
@@ -121,19 +123,23 @@ struct CPUExpandEntry {
   static const int kEmptyNid = -1;
   int nid;
   int depth;
-  bst_float loss_chg;
+  SplitEntry split;
 
+  CPUExpandEntry() = default;
   CPUExpandEntry(int nid, int depth, bst_float loss_chg)
-      : nid(nid), depth(depth), loss_chg(loss_chg) {}
+      : nid(nid), depth(depth) {
+    split.loss_chg = loss_chg;
+  }
 
   bool IsValid(TrainParam const &param, int32_t num_leaves) const {
-    bool ret = loss_chg <= kRtEps ||
+    bool invalid = split.loss_chg <= kRtEps ||
                (param.max_depth > 0 && this->depth == param.max_depth) ||
                (param.max_leaves > 0 && num_leaves == param.max_leaves);
-    return ret;
+    return !invalid;
   }
 
   bst_float GetLossChange() const {
-    return loss_chg;
+    return split.loss_chg;
   }
 
   int GetNodeId() const {
@@ -214,39 +220,17 @@ class QuantileHistMaker: public TreeUpdater {
   DMatrix const* p_last_dmat_ {nullptr};
   bool is_gmat_initialized_ {false};
 
-  // data structure
-  struct NodeEntry {
-    /*! \brief statics for node entry */
-    GradStats stats;
-    /*! \brief loss of this node, without split */
-    bst_float root_gain;
-    /*! \brief weight calculated related to current data */
-    float weight;
-    /*! \brief current best solution */
-    SplitEntry best;
-    // constructor
-    explicit NodeEntry(const TrainParam&)
-        : root_gain(0.0f), weight(0.0f) {}
-  };
   // actual builder that runs the algorithm
 
   template<typename GradientSumT>
   struct Builder {
   public:
    using GHistRowT = GHistRow<GradientSumT>;
    using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;
    // constructor
-   explicit Builder(const size_t n_trees,
-                    const TrainParam& param,
-                    std::unique_ptr<TreeUpdater> pruner,
-                    FeatureInteractionConstraintHost int_constraints_,
-                    DMatrix const* fmat)
-       : n_trees_(n_trees),
-         param_(param),
-         tree_evaluator_(param, fmat->Info().num_col_, GenericParameter::kCpuId),
-         pruner_(std::move(pruner)),
-         interaction_constraints_{std::move(int_constraints_)},
-         p_last_tree_(nullptr), p_last_fmat_(fmat) {
+   explicit Builder(const size_t n_trees, const TrainParam &param,
+                    std::unique_ptr<TreeUpdater> pruner, DMatrix const *fmat)
+       : n_trees_(n_trees), param_(param), pruner_(std::move(pruner)),
+         p_last_tree_(nullptr), p_last_fmat_(fmat) {
      builder_monitor_.Init("Quantile::Builder");
    }
    // update one tree, growing
@@ -290,11 +274,6 @@ class QuantileHistMaker: public TreeUpdater {
                           std::vector<GradientPair>* gpair,
                           std::vector<size_t>* row_indices);
 
-    void EvaluateSplits(const std::vector<CPUExpandEntry>& nodes_set,
-                        const GHistIndexMatrix& gmat,
-                        const HistCollection<GradientSumT>& hist,
-                        const RegTree& tree);
-
     template <bool any_missing>
     void ApplySplit(std::vector<CPUExpandEntry> nodes,
                     const GHistIndexMatrix& gmat,
@@ -308,26 +287,6 @@ class QuantileHistMaker: public TreeUpdater {
     void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
                              const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions);
 
-    void InitNewNode(int nid,
-                     const GHistIndexMatrix& gmat,
-                     const std::vector<GradientPair>& gpair,
-                     const DMatrix& fmat,
-                     const RegTree& tree);
-
-    // Enumerate the split values of specific feature
-    // Returns the sum of gradients corresponding to the data points that contains a non-missing
-    // value for the particular feature fid.
-    template <int d_step>
-    GradStats EnumerateSplit(const GHistIndexMatrix &gmat, const GHistRowT &hist,
-                             const NodeEntry &snode, SplitEntry *p_best, bst_uint fid,
-                             bst_uint nodeID,
-                             TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator) const;
-
-    // if sum of statistics for non-missing values in the node
-    // is equal to sum of statistics for all values:
-    // then - there are no missing values
-    // else - there are missing values
-    bool SplitContainsMissingValues(const GradStats e, const NodeEntry& snode);
-
     template <bool any_missing>
     void BuildLocalHistograms(const GHistIndexMatrix &gmat,
@@ -352,10 +311,6 @@ class QuantileHistMaker: public TreeUpdater {
                          int *num_leaves,
                          std::vector<CPUExpandEntry>* nodes_for_apply_split);
 
-    void BuildNodeStats(const GHistIndexMatrix &gmat,
-                        const DMatrix& fmat,
-                        const std::vector<GradientPair> &gpair_h,
-                        const std::vector<CPUExpandEntry>& nodes_for_apply_split, RegTree *p_tree);
     template <bool any_missing>
     void ExpandTree(const GHistIndexMatrix& gmat,
                     const ColumnMatrix& column_matrix,
@@ -368,31 +323,24 @@ class QuantileHistMaker: public TreeUpdater {
     const TrainParam& param_;
     // number of omp thread used during training
     int nthread_;
-    common::ColumnSampler column_sampler_;
+    std::shared_ptr<common::ColumnSampler> column_sampler_{
+        std::make_shared<common::ColumnSampler>()};
+
+    std::vector<size_t> unused_rows_;
     // the internal row sets
     RowSetCollection row_set_collection_;
-    // tree rows that were not used for current training
-    std::vector<size_t> unused_rows_;
-    // feature vectors for subsampled prediction
-    std::vector<RegTree::FVec> feat_vecs_;
-    // the temp space for split
-    std::vector<RowSetCollection::Split> row_split_tloc_;
-    std::vector<SplitEntry> best_split_tloc_;
-    /*! \brief TreeNode Data: statistics for each constructed node */
-    std::vector<NodeEntry> snode_;
     std::vector<GradientPair> gpair_local_;
     /*! \brief culmulative histogram of gradients. */
     HistCollection<GradientSumT> hist_;
     /*! \brief culmulative local parent histogram of gradients. */
     HistCollection<GradientSumT> hist_local_worker_;
-    TreeEvaluator tree_evaluator_;
     /*! \brief feature with least # of bins. to be used for dense specialization
               of InitNewNode() */
     uint32_t fid_least_bins_;
 
     GHistBuilder<GradientSumT> hist_builder_;
     std::unique_ptr<TreeUpdater> pruner_;
-    FeatureInteractionConstraintHost interaction_constraints_;
+    std::unique_ptr<HistEvaluator<GradientSumT, CPUExpandEntry>> evaluator_;
 
     static constexpr size_t kPartitionBlockSize = 2048;
     common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
@@ -402,10 +350,6 @@ class QuantileHistMaker: public TreeUpdater {
     DMatrix const* const p_last_fmat_;
     DMatrix* p_last_fmat_mutable_;
 
-    using ExpandQueue =
-       std::priority_queue<CPUExpandEntry, std::vector<CPUExpandEntry>,
-                           std::function<bool(CPUExpandEntry, CPUExpandEntry)>>;
-
     // key is the node id which should be calculated by Subtraction Trick, value is the node which
     // provides the evidence for subtraction
     std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
@@ -438,7 +382,6 @@ class QuantileHistMaker: public TreeUpdater {
   std::unique_ptr<Builder<double>> double_builder_;
 
   std::unique_ptr<TreeUpdater> pruner_;
-  FeatureInteractionConstraintHost int_constraint_;
 };
 
 template <typename GradientSumT>
tests/cpp/tree/hist/test_evaluate_splits.cc (new file, 112 lines)

@@ -0,0 +1,112 @@
#include <gtest/gtest.h>
#include <xgboost/base.h>
#include "../../../../src/tree/hist/evaluate_splits.h"
#include "../../../../src/tree/updater_quantile_hist.h"
#include "../../../../src/common/hist_util.h"
#include "../../helpers.h"

namespace xgboost {
namespace tree {

template <typename GradientSumT> void TestEvaluateSplits() {
  int static constexpr kRows = 8, kCols = 16;
  auto orig = omp_get_max_threads();
  int32_t n_threads = std::min(omp_get_max_threads(), 4);
  omp_set_num_threads(n_threads);
  auto sampler = std::make_shared<common::ColumnSampler>();

  TrainParam param;
  param.UpdateAllowUnknown(Args{{}});
  param.min_child_weight = 0;
  param.reg_lambda = 0;

  auto dmat = RandomDataGenerator(kRows, kCols, 0).Seed(3).GenerateDMatrix();

  auto evaluator =
      HistEvaluator<GradientSumT, CPUExpandEntry>{param, dmat->Info(), n_threads, sampler};
  common::HistCollection<GradientSumT> hist;
  std::vector<GradientPair> row_gpairs = {
      {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
      {0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f}};

  size_t constexpr kMaxBins = 4;
  // dense, no missing values
  GHistIndexMatrix gmat(dmat.get(), kMaxBins);
  common::RowSetCollection row_set_collection;
  std::vector<size_t> &row_indices = *row_set_collection.Data();
  row_indices.resize(kRows);
  std::iota(row_indices.begin(), row_indices.end(), 0);
  row_set_collection.Init();

  auto hist_builder = GHistBuilder<GradientSumT>(n_threads, gmat.cut.Ptrs().back());
  hist.Init(gmat.cut.Ptrs().back());
  hist.AddHistRow(0);
  hist.AllocateAllData();
  hist_builder.template BuildHist<false>(row_gpairs, row_set_collection[0],
                                         gmat, hist[0]);

  // Compute total gradient for all data points
  GradientPairPrecise total_gpair;
  for (const auto &e : row_gpairs) {
    total_gpair += GradientPairPrecise(e);
  }

  RegTree tree;
  std::vector<CPUExpandEntry> entries(1);
  entries.front().nid = 0;
  entries.front().depth = 0;

  evaluator.InitRoot(GradStats{total_gpair});
  evaluator.EvaluateSplits(hist, gmat, tree, &entries);

  auto best_loss_chg =
      evaluator.Evaluator().CalcSplitGain(
          param, 0, entries.front().split.SplitIndex(),
          entries.front().split.left_sum, entries.front().split.right_sum) -
      evaluator.Stats().front().root_gain;
  ASSERT_EQ(entries.front().split.loss_chg, best_loss_chg);
  ASSERT_GT(entries.front().split.loss_chg, 16.2f);

  // Assert that's the best split
  for (size_t i = 1; i < gmat.cut.Ptrs().size(); ++i) {
    GradStats left, right;
    for (size_t j = gmat.cut.Ptrs()[i - 1]; j < gmat.cut.Ptrs()[i]; ++j) {
      auto loss_chg =
          evaluator.Evaluator().CalcSplitGain(param, 0, i - 1, left, right) -
          evaluator.Stats().front().root_gain;
      ASSERT_GE(best_loss_chg, loss_chg);
      left.Add(hist[0][j].GetGrad(), hist[0][j].GetHess());
      right.SetSubstract(GradStats{total_gpair}, left);
    }
  }

  omp_set_num_threads(orig);
}

TEST(HistEvaluator, Evaluate) {
  TestEvaluateSplits<float>();
  TestEvaluateSplits<double>();
}

TEST(HistEvaluator, Apply) {
  RegTree tree;
  int static constexpr kNRows = 8, kNCols = 16;
  TrainParam param;
  param.UpdateAllowUnknown(Args{{}});
  auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
  auto sampler = std::make_shared<common::ColumnSampler>();
  auto evaluator_ =
      HistEvaluator<float, CPUExpandEntry>{param, dmat->Info(), 4, sampler};

  CPUExpandEntry entry{0, 0, 10.0f};
  entry.split.left_sum = GradStats{0.4, 0.6f};
  entry.split.right_sum = GradStats{0.5, 0.7f};

  evaluator_.ApplyTreeSplit(entry, &tree);
  ASSERT_EQ(tree.NumExtraNodes(), 2);
  ASSERT_EQ(tree.Stat(tree[0].LeftChild()).sum_hess, 0.6f);
  ASSERT_EQ(tree.Stat(tree[0].RightChild()).sum_hess, 0.7f);
}
}  // namespace tree
}  // namespace xgboost
tests/cpp/tree/test_quantile_hist.cc

@@ -26,12 +26,9 @@ class QuantileHistMock : public QuantileHistMaker {
     using RealImpl = QuantileHistMaker::Builder<GradientSumT>;
     using GHistRowT = typename RealImpl::GHistRowT;
 
-    BuilderMock(const TrainParam& param,
-                std::unique_ptr<TreeUpdater> pruner,
-                FeatureInteractionConstraintHost int_constraint,
-                DMatrix const* fmat)
-        : RealImpl(1, param, std::move(pruner),
-                   std::move(int_constraint), fmat) {}
+    BuilderMock(const TrainParam &param, std::unique_ptr<TreeUpdater> pruner,
+                DMatrix const *fmat)
+        : RealImpl(1, param, std::move(pruner), fmat) {}
 
    public:
     void TestInitData(const GHistIndexMatrix& gmat,
@@ -336,92 +333,6 @@ class QuantileHistMock : public QuantileHistMaker {
       }
     }
 
-    void TestEvaluateSplit(const RegTree& tree) {
-      std::vector<GradientPair> row_gpairs =
-          { {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
-            {0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f} };
-      size_t constexpr kMaxBins = 4;
-      auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
-      // dense, no missing values
-
-      GHistIndexMatrix gmat(dmat.get(), kMaxBins);
-
-      RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
-      this->hist_.AddHistRow(0);
-      this->hist_.AllocateAllData();
-      this->hist_builder_.template BuildHist<false>(row_gpairs, this->row_set_collection_[0],
-                                                    gmat, this->hist_[0]);
-
-      RealImpl::InitNewNode(0, gmat, row_gpairs, *dmat, tree);
-
-      /* Compute correct split (best_split) using the computed histogram */
-      const size_t num_row = dmat->Info().num_row_;
-      const size_t num_feature = dmat->Info().num_col_;
-      CHECK_EQ(num_row, row_gpairs.size());
-      // Compute total gradient for all data points
-      GradientPairPrecise total_gpair;
-      for (const auto& e : row_gpairs) {
-        total_gpair += GradientPairPrecise(e);
-      }
-      // Now enumerate all feature*threshold combination to get best split
-      // To simplify logic, we make some assumptions:
-      // 1) no missing values in data
-      // 2) no regularization, i.e. set min_child_weight, reg_lambda, reg_alpha,
-      //    and max_delta_step to 0.
-      bst_float best_split_gain = 0.0f;
-      size_t best_split_threshold = std::numeric_limits<size_t>::max();
-      size_t best_split_feature = std::numeric_limits<size_t>::max();
-      // Enumerate all features
-      for (size_t fid = 0; fid < num_feature; ++fid) {
-        const size_t bin_id_min = gmat.cut.Ptrs()[fid];
-        const size_t bin_id_max = gmat.cut.Ptrs()[fid + 1];
-        // Enumerate all bin ID in [bin_id_min, bin_id_max), i.e. every possible
-        // choice of thresholds for feature fid
-        for (size_t split_thresh = bin_id_min;
-             split_thresh < bin_id_max; ++split_thresh) {
-          // left_sum, right_sum: Gradient sums for data points whose feature
-          // value is left/right side of the split threshold
-          GradientPairPrecise left_sum, right_sum;
-          for (size_t rid = 0; rid < num_row; ++rid) {
-            for (size_t offset = gmat.row_ptr[rid];
-                 offset < gmat.row_ptr[rid + 1]; ++offset) {
-              const size_t bin_id = gmat.index[offset];
-              if (bin_id >= bin_id_min && bin_id < bin_id_max) {
-                if (bin_id <= split_thresh) {
-                  left_sum += GradientPairPrecise(row_gpairs[rid]);
-                } else {
-                  right_sum += GradientPairPrecise(row_gpairs[rid]);
-                }
-              }
-            }
-          }
-          // Now compute gain (change in loss)
-          auto evaluator = this->tree_evaluator_.GetEvaluator();
-          const auto split_gain = evaluator.CalcSplitGain(
-              this->param_, 0, fid, GradStats(left_sum), GradStats(right_sum));
-          if (split_gain > best_split_gain) {
-            best_split_gain = split_gain;
-            best_split_feature = fid;
-            best_split_threshold = split_thresh;
-          }
-        }
-      }
-
-      /* Now compare against result given by EvaluateSplit() */
-      CPUExpandEntry node(CPUExpandEntry::kRootNid,
-                          tree.GetDepth(0),
-                          this->snode_[0].best.loss_chg);
-      RealImpl::EvaluateSplits({node}, gmat, this->hist_, tree);
-      ASSERT_EQ(this->snode_[0].best.SplitIndex(), best_split_feature);
-      ASSERT_EQ(this->snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);
-    }
-
-    void TestEvaluateSplitParallel(const RegTree &tree) {
-      omp_set_num_threads(2);
-      TestEvaluateSplit(tree);
-      omp_set_num_threads(1);
-    }
-
     void TestApplySplit(const RegTree& tree) {
       std::vector<GradientPair> row_gpairs =
           { {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
@@ -441,7 +352,6 @@ class QuantileHistMock : public QuantileHistMaker {
       RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
       this->hist_.AddHistRow(0);
       this->hist_.AllocateAllData();
-      RealImpl::InitNewNode(0, gmat, row_gpairs, *dmat, tree);
 
       const size_t num_row = dmat->Info().num_row_;
       // split by feature 0
@@ -513,7 +423,6 @@ class QuantileHistMock : public QuantileHistMaker {
           new BuilderMock<float>(
               param_,
               std::move(pruner_),
-              int_constraint_,
               dmat_.get()));
       if (batch) {
         float_builder_->SetHistSynchronizer(new BatchHistSynchronizer<float>());
@@ -527,7 +436,6 @@ class QuantileHistMock : public QuantileHistMaker {
           new BuilderMock<double>(
               param_,
               std::move(pruner_),
-              int_constraint_,
              dmat_.get()));
       if (batch) {
         double_builder_->SetHistSynchronizer(new BatchHistSynchronizer<double>());
@@ -622,23 +530,13 @@ class QuantileHistMock : public QuantileHistMaker {
       }
     }
 
-  void TestEvaluateSplit() {
-    RegTree tree = RegTree();
-    tree.param.UpdateAllowUnknown(cfg_);
-    if (double_builder_) {
-      double_builder_->TestEvaluateSplit(tree);
-    } else {
-      float_builder_->TestEvaluateSplit(tree);
-    }
-  }
-
   void TestApplySplit() {
     RegTree tree = RegTree();
     tree.param.UpdateAllowUnknown(cfg_);
     if (double_builder_) {
       double_builder_->TestApplySplit(tree);
     } else {
-      float_builder_->TestEvaluateSplit(tree);
+      float_builder_->TestApplySplit(tree);
     }
   }
 };
@@ -716,19 +614,6 @@ TEST(QuantileHist, BuildHist) {
   maker_float.TestBuildHist();
 }
 
-TEST(QuantileHist, EvalSplits) {
-  std::vector<std::pair<std::string, std::string>> cfg
-      {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())},
-       {"split_evaluator", "elastic_net"},
-       {"reg_lambda", "0"}, {"reg_alpha", "0"}, {"max_delta_step", "0"},
-       {"min_child_weight", "0"}};
-  QuantileHistMock maker(cfg);
-  maker.TestEvaluateSplit();
-  const bool single_precision_histogram = true;
-  QuantileHistMock maker_float(cfg, single_precision_histogram);
-  maker_float.TestEvaluateSplit();
-}
-
 TEST(QuantileHist, ApplySplit) {
   std::vector<std::pair<std::string, std::string>> cfg
       {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())},