Unify evaluation functions. (#6037)
This commit is contained in:
@@ -29,10 +29,9 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
|
||||
BuilderMock(const TrainParam& param,
|
||||
std::unique_ptr<TreeUpdater> pruner,
|
||||
std::unique_ptr<SplitEvaluator> spliteval,
|
||||
FeatureInteractionConstraintHost int_constraint,
|
||||
DMatrix const* fmat)
|
||||
: RealImpl(param, std::move(pruner), std::move(spliteval),
|
||||
: RealImpl(param, std::move(pruner),
|
||||
std::move(int_constraint), fmat) {}
|
||||
|
||||
public:
|
||||
@@ -195,7 +194,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
|
||||
|
||||
const size_t n_nodes = this->nodes_for_explicit_hist_build_.size();
|
||||
ASSERT_EQ(n_nodes, 2);
|
||||
ASSERT_EQ(n_nodes, 2ul);
|
||||
this->row_set_collection_.AddSplit(0, (*tree)[0].LeftChild(),
|
||||
(*tree)[0].RightChild(), 4, 4);
|
||||
this->row_set_collection_.AddSplit(1, (*tree)[1].LeftChild(),
|
||||
@@ -331,10 +330,6 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
for (const auto& e : row_gpairs) {
|
||||
total_gpair += GradientPairPrecise(e);
|
||||
}
|
||||
// Initialize split evaluator
|
||||
std::unique_ptr<SplitEvaluator> evaluator(SplitEvaluator::Create("elastic_net"));
|
||||
evaluator->Init(&this->param_);
|
||||
|
||||
// Now enumerate all feature*threshold combination to get best split
|
||||
// To simplify logic, we make some assumptions:
|
||||
// 1) no missing values in data
|
||||
@@ -368,9 +363,9 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
}
|
||||
}
|
||||
// Now compute gain (change in loss)
|
||||
const auto split_gain
|
||||
= evaluator->ComputeSplitScore(0, fid, GradStats(left_sum),
|
||||
GradStats(right_sum));
|
||||
auto evaluator = this->tree_evaluator_.GetEvaluator();
|
||||
const auto split_gain = evaluator.CalcSplitGain(
|
||||
this->param_, 0, fid, GradStats(left_sum), GradStats(right_sum));
|
||||
if (split_gain > best_split_gain) {
|
||||
best_split_gain = split_gain;
|
||||
best_split_feature = fid;
|
||||
@@ -476,14 +471,12 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
const bool single_precision_histogram = false, bool batch = true) :
|
||||
cfg_{args} {
|
||||
QuantileHistMaker::Configure(args);
|
||||
spliteval_->Init(¶m_);
|
||||
dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
|
||||
if (single_precision_histogram) {
|
||||
float_builder_.reset(
|
||||
new BuilderMock<float>(
|
||||
param_,
|
||||
std::move(pruner_),
|
||||
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
|
||||
int_constraint_,
|
||||
dmat_.get()));
|
||||
if (batch) {
|
||||
@@ -498,7 +491,6 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
new BuilderMock<double>(
|
||||
param_,
|
||||
std::move(pruner_),
|
||||
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
|
||||
int_constraint_,
|
||||
dmat_.get()));
|
||||
if (batch) {
|
||||
|
||||
Reference in New Issue
Block a user