Merge lossgude and depthwise strategies for CPU hist (#7007)
* fix java/scala test: max depth is also valid parameter for lossguide Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
This commit is contained in:
@@ -24,7 +24,6 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
template <typename GradientSumT>
|
||||
struct BuilderMock : public QuantileHistMaker::Builder<GradientSumT> {
|
||||
using RealImpl = QuantileHistMaker::Builder<GradientSumT>;
|
||||
using ExpandEntryT = typename RealImpl::ExpandEntry;
|
||||
using GHistRowT = typename RealImpl::GHistRowT;
|
||||
|
||||
BuilderMock(const TrainParam& param,
|
||||
@@ -169,19 +168,19 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||
tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||
tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
|
||||
this->nodes_for_subtraction_trick_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
|
||||
this->nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(3, tree->GetDepth(3), 0.0f);
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(4, tree->GetDepth(4), 0.0f);
|
||||
this->nodes_for_subtraction_trick_.emplace_back(5, tree->GetDepth(5), 0.0f);
|
||||
this->nodes_for_subtraction_trick_.emplace_back(6, tree->GetDepth(6), 0.0f);
|
||||
|
||||
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
|
||||
ASSERT_EQ(sync_count, 2);
|
||||
ASSERT_EQ(starting_index, 3);
|
||||
|
||||
for (const ExpandEntryT& node : this->nodes_for_explicit_hist_build_) {
|
||||
for (const CPUExpandEntry& node : this->nodes_for_explicit_hist_build_) {
|
||||
ASSERT_EQ(this->hist_.RowExists(node.nid), true);
|
||||
}
|
||||
for (const ExpandEntryT& node : this->nodes_for_subtraction_trick_) {
|
||||
for (const CPUExpandEntry& node : this->nodes_for_subtraction_trick_) {
|
||||
ASSERT_EQ(this->hist_.RowExists(node.nid), true);
|
||||
}
|
||||
}
|
||||
@@ -199,7 +198,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
this->nodes_for_explicit_hist_build_.clear();
|
||||
this->nodes_for_subtraction_trick_.clear();
|
||||
// level 0
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(0, -1, tree->GetDepth(0), 0.0f, 0);
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(0, tree->GetDepth(0), 0.0f);
|
||||
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
|
||||
tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||
|
||||
@@ -207,11 +206,9 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
this->nodes_for_subtraction_trick_.clear();
|
||||
// level 1
|
||||
this->nodes_for_explicit_hist_build_.emplace_back((*tree)[0].LeftChild(),
|
||||
(*tree)[0].RightChild(),
|
||||
tree->GetDepth(1), 0.0f, 0);
|
||||
tree->GetDepth(1), 0.0f);
|
||||
this->nodes_for_subtraction_trick_.emplace_back((*tree)[0].RightChild(),
|
||||
(*tree)[0].LeftChild(),
|
||||
tree->GetDepth(2), 0.0f, 0);
|
||||
tree->GetDepth(2), 0.0f);
|
||||
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
|
||||
tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||
tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||
@@ -219,10 +216,10 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
this->nodes_for_explicit_hist_build_.clear();
|
||||
this->nodes_for_subtraction_trick_.clear();
|
||||
// level 2
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
|
||||
this->nodes_for_subtraction_trick_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
|
||||
this->nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(3, tree->GetDepth(3), 0.0f);
|
||||
this->nodes_for_subtraction_trick_.emplace_back(4, tree->GetDepth(4), 0.0f);
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(5, tree->GetDepth(5), 0.0f);
|
||||
this->nodes_for_subtraction_trick_.emplace_back(6, tree->GetDepth(6), 0.0f);
|
||||
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
|
||||
|
||||
const size_t n_nodes = this->nodes_for_explicit_hist_build_.size();
|
||||
@@ -278,21 +275,27 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
ASSERT_EQ(p_parent[i], p_left[i] + p_right[i]);
|
||||
}
|
||||
};
|
||||
for (const ExpandEntryT& node : this->nodes_for_explicit_hist_build_) {
|
||||
size_t node_id = 0;
|
||||
for (const CPUExpandEntry& node : this->nodes_for_explicit_hist_build_) {
|
||||
auto this_hist = this->hist_[node.nid];
|
||||
const size_t parent_id = (*tree)[node.nid].Parent();
|
||||
const size_t subtraction_node_id = this->nodes_for_subtraction_trick_[node_id].nid;
|
||||
auto parent_hist = this->hist_[parent_id];
|
||||
auto sibling_hist = this->hist_[node.sibling_nid];
|
||||
auto sibling_hist = this->hist_[subtraction_node_id];
|
||||
|
||||
check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
|
||||
++node_id;
|
||||
}
|
||||
for (const ExpandEntryT& node : this->nodes_for_subtraction_trick_) {
|
||||
node_id = 0;
|
||||
for (const CPUExpandEntry& node : this->nodes_for_subtraction_trick_) {
|
||||
auto this_hist = this->hist_[node.nid];
|
||||
const size_t parent_id = (*tree)[node.nid].Parent();
|
||||
const size_t subtraction_node_id = this->nodes_for_explicit_hist_build_[node_id].nid;
|
||||
auto parent_hist = this->hist_[parent_id];
|
||||
auto sibling_hist = this->hist_[node.sibling_nid];
|
||||
auto sibling_hist = this->hist_[subtraction_node_id];
|
||||
|
||||
check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
|
||||
++node_id;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -408,10 +411,9 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
}
|
||||
|
||||
/* Now compare against result given by EvaluateSplit() */
|
||||
typename RealImpl::ExpandEntry node(RealImpl::ExpandEntry::kRootNid,
|
||||
RealImpl::ExpandEntry::kEmptyNid,
|
||||
tree.GetDepth(0),
|
||||
this->snode_[0].best.loss_chg, 0);
|
||||
CPUExpandEntry node(CPUExpandEntry::kRootNid,
|
||||
tree.GetDepth(0),
|
||||
this->snode_[0].best.loss_chg);
|
||||
RealImpl::EvaluateSplits({node}, gmat, this->hist_, tree);
|
||||
ASSERT_EQ(this->snode_[0].best.SplitIndex(), best_split_feature);
|
||||
ASSERT_EQ(this->snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);
|
||||
|
||||
Reference in New Issue
Block a user