Extract evaluate splits from CPU hist. (#7079)
Other than modularizing the split evaluation function, this PR also removes some more functions including `InitNewNodes` and `BuildNodeStats` among some other unused variables. Also, scattered code like setting leaf weights is grouped into the split evaluator and `NodeEntry` is simplified and made private. Another subtle difference with the original implementation is that the modified code doesn't call `tree[nidx].Parent()` to traversal upward.
This commit is contained in:
@@ -26,12 +26,9 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
using RealImpl = QuantileHistMaker::Builder<GradientSumT>;
|
||||
using GHistRowT = typename RealImpl::GHistRowT;
|
||||
|
||||
BuilderMock(const TrainParam& param,
|
||||
std::unique_ptr<TreeUpdater> pruner,
|
||||
FeatureInteractionConstraintHost int_constraint,
|
||||
DMatrix const* fmat)
|
||||
: RealImpl(1, param, std::move(pruner),
|
||||
std::move(int_constraint), fmat) {}
|
||||
BuilderMock(const TrainParam ¶m, std::unique_ptr<TreeUpdater> pruner,
|
||||
DMatrix const *fmat)
|
||||
: RealImpl(1, param, std::move(pruner), fmat) {}
|
||||
|
||||
public:
|
||||
void TestInitData(const GHistIndexMatrix& gmat,
|
||||
@@ -336,92 +333,6 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
}
|
||||
}
|
||||
|
||||
void TestEvaluateSplit(const RegTree& tree) {
|
||||
std::vector<GradientPair> row_gpairs =
|
||||
{ {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
|
||||
{0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f} };
|
||||
size_t constexpr kMaxBins = 4;
|
||||
auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
|
||||
// dense, no missing values
|
||||
|
||||
GHistIndexMatrix gmat(dmat.get(), kMaxBins);
|
||||
|
||||
RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
|
||||
this->hist_.AddHistRow(0);
|
||||
this->hist_.AllocateAllData();
|
||||
this->hist_builder_.template BuildHist<false>(row_gpairs, this->row_set_collection_[0],
|
||||
gmat, this->hist_[0]);
|
||||
|
||||
RealImpl::InitNewNode(0, gmat, row_gpairs, *dmat, tree);
|
||||
|
||||
/* Compute correct split (best_split) using the computed histogram */
|
||||
const size_t num_row = dmat->Info().num_row_;
|
||||
const size_t num_feature = dmat->Info().num_col_;
|
||||
CHECK_EQ(num_row, row_gpairs.size());
|
||||
// Compute total gradient for all data points
|
||||
GradientPairPrecise total_gpair;
|
||||
for (const auto& e : row_gpairs) {
|
||||
total_gpair += GradientPairPrecise(e);
|
||||
}
|
||||
// Now enumerate all feature*threshold combination to get best split
|
||||
// To simplify logic, we make some assumptions:
|
||||
// 1) no missing values in data
|
||||
// 2) no regularization, i.e. set min_child_weight, reg_lambda, reg_alpha,
|
||||
// and max_delta_step to 0.
|
||||
bst_float best_split_gain = 0.0f;
|
||||
size_t best_split_threshold = std::numeric_limits<size_t>::max();
|
||||
size_t best_split_feature = std::numeric_limits<size_t>::max();
|
||||
// Enumerate all features
|
||||
for (size_t fid = 0; fid < num_feature; ++fid) {
|
||||
const size_t bin_id_min = gmat.cut.Ptrs()[fid];
|
||||
const size_t bin_id_max = gmat.cut.Ptrs()[fid + 1];
|
||||
// Enumerate all bin ID in [bin_id_min, bin_id_max), i.e. every possible
|
||||
// choice of thresholds for feature fid
|
||||
for (size_t split_thresh = bin_id_min;
|
||||
split_thresh < bin_id_max; ++split_thresh) {
|
||||
// left_sum, right_sum: Gradient sums for data points whose feature
|
||||
// value is left/right side of the split threshold
|
||||
GradientPairPrecise left_sum, right_sum;
|
||||
for (size_t rid = 0; rid < num_row; ++rid) {
|
||||
for (size_t offset = gmat.row_ptr[rid];
|
||||
offset < gmat.row_ptr[rid + 1]; ++offset) {
|
||||
const size_t bin_id = gmat.index[offset];
|
||||
if (bin_id >= bin_id_min && bin_id < bin_id_max) {
|
||||
if (bin_id <= split_thresh) {
|
||||
left_sum += GradientPairPrecise(row_gpairs[rid]);
|
||||
} else {
|
||||
right_sum += GradientPairPrecise(row_gpairs[rid]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Now compute gain (change in loss)
|
||||
auto evaluator = this->tree_evaluator_.GetEvaluator();
|
||||
const auto split_gain = evaluator.CalcSplitGain(
|
||||
this->param_, 0, fid, GradStats(left_sum), GradStats(right_sum));
|
||||
if (split_gain > best_split_gain) {
|
||||
best_split_gain = split_gain;
|
||||
best_split_feature = fid;
|
||||
best_split_threshold = split_thresh;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Now compare against result given by EvaluateSplit() */
|
||||
CPUExpandEntry node(CPUExpandEntry::kRootNid,
|
||||
tree.GetDepth(0),
|
||||
this->snode_[0].best.loss_chg);
|
||||
RealImpl::EvaluateSplits({node}, gmat, this->hist_, tree);
|
||||
ASSERT_EQ(this->snode_[0].best.SplitIndex(), best_split_feature);
|
||||
ASSERT_EQ(this->snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);
|
||||
}
|
||||
|
||||
void TestEvaluateSplitParallel(const RegTree &tree) {
|
||||
omp_set_num_threads(2);
|
||||
TestEvaluateSplit(tree);
|
||||
omp_set_num_threads(1);
|
||||
}
|
||||
|
||||
void TestApplySplit(const RegTree& tree) {
|
||||
std::vector<GradientPair> row_gpairs =
|
||||
{ {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
|
||||
@@ -441,7 +352,6 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
|
||||
this->hist_.AddHistRow(0);
|
||||
this->hist_.AllocateAllData();
|
||||
RealImpl::InitNewNode(0, gmat, row_gpairs, *dmat, tree);
|
||||
|
||||
const size_t num_row = dmat->Info().num_row_;
|
||||
// split by feature 0
|
||||
@@ -513,7 +423,6 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
new BuilderMock<float>(
|
||||
param_,
|
||||
std::move(pruner_),
|
||||
int_constraint_,
|
||||
dmat_.get()));
|
||||
if (batch) {
|
||||
float_builder_->SetHistSynchronizer(new BatchHistSynchronizer<float>());
|
||||
@@ -527,7 +436,6 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
new BuilderMock<double>(
|
||||
param_,
|
||||
std::move(pruner_),
|
||||
int_constraint_,
|
||||
dmat_.get()));
|
||||
if (batch) {
|
||||
double_builder_->SetHistSynchronizer(new BatchHistSynchronizer<double>());
|
||||
@@ -622,23 +530,13 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
}
|
||||
}
|
||||
|
||||
void TestEvaluateSplit() {
|
||||
RegTree tree = RegTree();
|
||||
tree.param.UpdateAllowUnknown(cfg_);
|
||||
if (double_builder_) {
|
||||
double_builder_->TestEvaluateSplit(tree);
|
||||
} else {
|
||||
float_builder_->TestEvaluateSplit(tree);
|
||||
}
|
||||
}
|
||||
|
||||
void TestApplySplit() {
|
||||
RegTree tree = RegTree();
|
||||
tree.param.UpdateAllowUnknown(cfg_);
|
||||
if (double_builder_) {
|
||||
double_builder_->TestApplySplit(tree);
|
||||
} else {
|
||||
float_builder_->TestEvaluateSplit(tree);
|
||||
float_builder_->TestApplySplit(tree);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -716,19 +614,6 @@ TEST(QuantileHist, BuildHist) {
|
||||
maker_float.TestBuildHist();
|
||||
}
|
||||
|
||||
TEST(QuantileHist, EvalSplits) {
|
||||
std::vector<std::pair<std::string, std::string>> cfg
|
||||
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())},
|
||||
{"split_evaluator", "elastic_net"},
|
||||
{"reg_lambda", "0"}, {"reg_alpha", "0"}, {"max_delta_step", "0"},
|
||||
{"min_child_weight", "0"}};
|
||||
QuantileHistMock maker(cfg);
|
||||
maker.TestEvaluateSplit();
|
||||
const bool single_precision_histogram = true;
|
||||
QuantileHistMock maker_float(cfg, single_precision_histogram);
|
||||
maker_float.TestEvaluateSplit();
|
||||
}
|
||||
|
||||
TEST(QuantileHist, ApplySplit) {
|
||||
std::vector<std::pair<std::string, std::string>> cfg
|
||||
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())},
|
||||
|
||||
Reference in New Issue
Block a user