Merge lossgude and depthwise strategies for CPU hist (#7007)

* fix java/scala test: max depth is also valid parameter for lossguide

Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
This commit is contained in:
ShvetsKS
2021-06-02 20:49:43 +03:00
committed by GitHub
parent ee4f51a631
commit 57c732655e
11 changed files with 415 additions and 484 deletions

View File

@@ -24,7 +24,6 @@ class QuantileHistMock : public QuantileHistMaker {
template <typename GradientSumT>
struct BuilderMock : public QuantileHistMaker::Builder<GradientSumT> {
using RealImpl = QuantileHistMaker::Builder<GradientSumT>;
using ExpandEntryT = typename RealImpl::ExpandEntry;
using GHistRowT = typename RealImpl::GHistRowT;
BuilderMock(const TrainParam& param,
@@ -169,19 +168,19 @@ class QuantileHistMock : public QuantileHistMaker {
tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
this->nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
this->nodes_for_explicit_hist_build_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
this->nodes_for_subtraction_trick_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
this->nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
this->nodes_for_explicit_hist_build_.emplace_back(3, tree->GetDepth(3), 0.0f);
this->nodes_for_explicit_hist_build_.emplace_back(4, tree->GetDepth(4), 0.0f);
this->nodes_for_subtraction_trick_.emplace_back(5, tree->GetDepth(5), 0.0f);
this->nodes_for_subtraction_trick_.emplace_back(6, tree->GetDepth(6), 0.0f);
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
ASSERT_EQ(sync_count, 2);
ASSERT_EQ(starting_index, 3);
for (const ExpandEntryT& node : this->nodes_for_explicit_hist_build_) {
for (const CPUExpandEntry& node : this->nodes_for_explicit_hist_build_) {
ASSERT_EQ(this->hist_.RowExists(node.nid), true);
}
for (const ExpandEntryT& node : this->nodes_for_subtraction_trick_) {
for (const CPUExpandEntry& node : this->nodes_for_subtraction_trick_) {
ASSERT_EQ(this->hist_.RowExists(node.nid), true);
}
}
@@ -199,7 +198,7 @@ class QuantileHistMock : public QuantileHistMaker {
this->nodes_for_explicit_hist_build_.clear();
this->nodes_for_subtraction_trick_.clear();
// level 0
this->nodes_for_explicit_hist_build_.emplace_back(0, -1, tree->GetDepth(0), 0.0f, 0);
this->nodes_for_explicit_hist_build_.emplace_back(0, tree->GetDepth(0), 0.0f);
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
@@ -207,11 +206,9 @@ class QuantileHistMock : public QuantileHistMaker {
this->nodes_for_subtraction_trick_.clear();
// level 1
this->nodes_for_explicit_hist_build_.emplace_back((*tree)[0].LeftChild(),
(*tree)[0].RightChild(),
tree->GetDepth(1), 0.0f, 0);
tree->GetDepth(1), 0.0f);
this->nodes_for_subtraction_trick_.emplace_back((*tree)[0].RightChild(),
(*tree)[0].LeftChild(),
tree->GetDepth(2), 0.0f, 0);
tree->GetDepth(2), 0.0f);
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
@@ -219,10 +216,10 @@ class QuantileHistMock : public QuantileHistMaker {
this->nodes_for_explicit_hist_build_.clear();
this->nodes_for_subtraction_trick_.clear();
// level 2
this->nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
this->nodes_for_subtraction_trick_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
this->nodes_for_explicit_hist_build_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
this->nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
this->nodes_for_explicit_hist_build_.emplace_back(3, tree->GetDepth(3), 0.0f);
this->nodes_for_subtraction_trick_.emplace_back(4, tree->GetDepth(4), 0.0f);
this->nodes_for_explicit_hist_build_.emplace_back(5, tree->GetDepth(5), 0.0f);
this->nodes_for_subtraction_trick_.emplace_back(6, tree->GetDepth(6), 0.0f);
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
const size_t n_nodes = this->nodes_for_explicit_hist_build_.size();
@@ -278,21 +275,27 @@ class QuantileHistMock : public QuantileHistMaker {
ASSERT_EQ(p_parent[i], p_left[i] + p_right[i]);
}
};
for (const ExpandEntryT& node : this->nodes_for_explicit_hist_build_) {
size_t node_id = 0;
for (const CPUExpandEntry& node : this->nodes_for_explicit_hist_build_) {
auto this_hist = this->hist_[node.nid];
const size_t parent_id = (*tree)[node.nid].Parent();
const size_t subtraction_node_id = this->nodes_for_subtraction_trick_[node_id].nid;
auto parent_hist = this->hist_[parent_id];
auto sibling_hist = this->hist_[node.sibling_nid];
auto sibling_hist = this->hist_[subtraction_node_id];
check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
++node_id;
}
for (const ExpandEntryT& node : this->nodes_for_subtraction_trick_) {
node_id = 0;
for (const CPUExpandEntry& node : this->nodes_for_subtraction_trick_) {
auto this_hist = this->hist_[node.nid];
const size_t parent_id = (*tree)[node.nid].Parent();
const size_t subtraction_node_id = this->nodes_for_explicit_hist_build_[node_id].nid;
auto parent_hist = this->hist_[parent_id];
auto sibling_hist = this->hist_[node.sibling_nid];
auto sibling_hist = this->hist_[subtraction_node_id];
check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
++node_id;
}
}
@@ -408,10 +411,9 @@ class QuantileHistMock : public QuantileHistMaker {
}
/* Now compare against result given by EvaluateSplit() */
typename RealImpl::ExpandEntry node(RealImpl::ExpandEntry::kRootNid,
RealImpl::ExpandEntry::kEmptyNid,
tree.GetDepth(0),
this->snode_[0].best.loss_chg, 0);
CPUExpandEntry node(CPUExpandEntry::kRootNid,
tree.GetDepth(0),
this->snode_[0].best.loss_chg);
RealImpl::EvaluateSplits({node}, gmat, this->hist_, tree);
ASSERT_EQ(this->snode_[0].best.SplitIndex(), best_split_feature);
ASSERT_EQ(this->snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);