Unify evaluation functions. (#6037)

This commit is contained in:
Jiaming Yuan
2020-08-26 14:23:27 +08:00
committed by GitHub
parent 80c8547147
commit 2fcc4f2886
29 changed files with 570 additions and 734 deletions

View File

@@ -29,10 +29,9 @@ class QuantileHistMock : public QuantileHistMaker {
BuilderMock(const TrainParam& param,
std::unique_ptr<TreeUpdater> pruner,
std::unique_ptr<SplitEvaluator> spliteval,
FeatureInteractionConstraintHost int_constraint,
DMatrix const* fmat)
: RealImpl(param, std::move(pruner), std::move(spliteval),
: RealImpl(param, std::move(pruner),
std::move(int_constraint), fmat) {}
public:
@@ -195,7 +194,7 @@ class QuantileHistMock : public QuantileHistMaker {
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
const size_t n_nodes = this->nodes_for_explicit_hist_build_.size();
ASSERT_EQ(n_nodes, 2);
ASSERT_EQ(n_nodes, 2ul);
this->row_set_collection_.AddSplit(0, (*tree)[0].LeftChild(),
(*tree)[0].RightChild(), 4, 4);
this->row_set_collection_.AddSplit(1, (*tree)[1].LeftChild(),
@@ -331,10 +330,6 @@ class QuantileHistMock : public QuantileHistMaker {
for (const auto& e : row_gpairs) {
total_gpair += GradientPairPrecise(e);
}
// Initialize split evaluator
std::unique_ptr<SplitEvaluator> evaluator(SplitEvaluator::Create("elastic_net"));
evaluator->Init(&this->param_);
// Now enumerate all feature*threshold combination to get best split
// To simplify logic, we make some assumptions:
// 1) no missing values in data
@@ -368,9 +363,9 @@ class QuantileHistMock : public QuantileHistMaker {
}
}
// Now compute gain (change in loss)
const auto split_gain
= evaluator->ComputeSplitScore(0, fid, GradStats(left_sum),
GradStats(right_sum));
auto evaluator = this->tree_evaluator_.GetEvaluator();
const auto split_gain = evaluator.CalcSplitGain(
this->param_, 0, fid, GradStats(left_sum), GradStats(right_sum));
if (split_gain > best_split_gain) {
best_split_gain = split_gain;
best_split_feature = fid;
@@ -476,14 +471,12 @@ class QuantileHistMock : public QuantileHistMaker {
const bool single_precision_histogram = false, bool batch = true) :
cfg_{args} {
QuantileHistMaker::Configure(args);
spliteval_->Init(&param_);
dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
if (single_precision_histogram) {
float_builder_.reset(
new BuilderMock<float>(
param_,
std::move(pruner_),
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
int_constraint_,
dmat_.get()));
if (batch) {
@@ -498,7 +491,6 @@ class QuantileHistMock : public QuantileHistMaker {
new BuilderMock<double>(
param_,
std::move(pruner_),
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
int_constraint_,
dmat_.get()));
if (batch) {