Merge lossguide and depthwise strategies for CPU hist (#7007)

* fix java/scala test: max depth is also a valid parameter for lossguide

Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
This commit is contained in:
ShvetsKS
2021-06-02 20:49:43 +03:00
committed by GitHub
parent ee4f51a631
commit 57c732655e
11 changed files with 415 additions and 484 deletions

View File

@@ -1,20 +1,21 @@
#include <gtest/gtest.h>
#include "../../../../src/tree/gpu_hist/driver.cuh"
#include "../../../../src/tree/driver.h"
#include "../../../../src/tree/gpu_hist/expand_entry.cuh"
namespace xgboost {
namespace tree {
TEST(GpuHist, DriverDepthWise) {
Driver driver(TrainParam::kDepthWise);
Driver<GPUExpandEntry> driver(TrainParam::kDepthWise);
EXPECT_TRUE(driver.Pop().empty());
DeviceSplitCandidate split;
split.loss_chg = 1.0f;
ExpandEntry root(0, 0, split, .0f, .0f, .0f);
GPUExpandEntry root(0, 0, split, .0f, .0f, .0f);
driver.Push({root});
EXPECT_EQ(driver.Pop().front().nid, 0);
driver.Push({ExpandEntry{1, 1, split, .0f, .0f, .0f}});
driver.Push({ExpandEntry{2, 1, split, .0f, .0f, .0f}});
driver.Push({ExpandEntry{3, 2, split, .0f, .0f, .0f}});
driver.Push({GPUExpandEntry{1, 1, split, .0f, .0f, .0f}});
driver.Push({GPUExpandEntry{2, 1, split, .0f, .0f, .0f}});
driver.Push({GPUExpandEntry{3, 2, split, .0f, .0f, .0f}});
// Should return entries from level 1
auto res = driver.Pop();
EXPECT_EQ(res.size(), 2);
@@ -32,14 +33,14 @@ TEST(GpuHist, DriverLossGuided) {
DeviceSplitCandidate low_gain;
low_gain.loss_chg = 1.0f;
Driver driver(TrainParam::kLossGuide);
Driver<GPUExpandEntry> driver(TrainParam::kLossGuide);
EXPECT_TRUE(driver.Pop().empty());
ExpandEntry root(0, 0, high_gain, .0f, .0f, .0f);
GPUExpandEntry root(0, 0, high_gain, .0f, .0f, .0f);
driver.Push({root});
EXPECT_EQ(driver.Pop().front().nid, 0);
// Select high gain first
driver.Push({ExpandEntry{1, 1, low_gain, .0f, .0f, .0f}});
driver.Push({ExpandEntry{2, 2, high_gain, .0f, .0f, .0f}});
driver.Push({GPUExpandEntry{1, 1, low_gain, .0f, .0f, .0f}});
driver.Push({GPUExpandEntry{2, 2, high_gain, .0f, .0f, .0f}});
auto res = driver.Pop();
EXPECT_EQ(res.size(), 1);
EXPECT_EQ(res[0].nid, 2);
@@ -48,8 +49,8 @@ TEST(GpuHist, DriverLossGuided) {
EXPECT_EQ(res[0].nid, 1);
// If equal gain, use nid
driver.Push({ExpandEntry{2, 1, low_gain, .0f, .0f, .0f}});
driver.Push({ExpandEntry{1, 1, low_gain, .0f, .0f, .0f}});
driver.Push({GPUExpandEntry{2, 1, low_gain, .0f, .0f, .0f}});
driver.Push({GPUExpandEntry{1, 1, low_gain, .0f, .0f, .0f}});
res = driver.Pop();
EXPECT_EQ(res[0].nid, 1);
res = driver.Pop();

View File

@@ -132,7 +132,7 @@ TEST(GpuHist, BuildHistSharedMem) {
TEST(GpuHist, ApplySplit) {
RegTree tree;
ExpandEntry candidate;
GPUExpandEntry candidate;
candidate.nid = 0;
candidate.left_weight = 1.0f;
candidate.right_weight = 2.0f;

View File

@@ -24,7 +24,6 @@ class QuantileHistMock : public QuantileHistMaker {
template <typename GradientSumT>
struct BuilderMock : public QuantileHistMaker::Builder<GradientSumT> {
using RealImpl = QuantileHistMaker::Builder<GradientSumT>;
using ExpandEntryT = typename RealImpl::ExpandEntry;
using GHistRowT = typename RealImpl::GHistRowT;
BuilderMock(const TrainParam& param,
@@ -169,19 +168,19 @@ class QuantileHistMock : public QuantileHistMaker {
tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
this->nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
this->nodes_for_explicit_hist_build_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
this->nodes_for_subtraction_trick_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
this->nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
this->nodes_for_explicit_hist_build_.emplace_back(3, tree->GetDepth(3), 0.0f);
this->nodes_for_explicit_hist_build_.emplace_back(4, tree->GetDepth(4), 0.0f);
this->nodes_for_subtraction_trick_.emplace_back(5, tree->GetDepth(5), 0.0f);
this->nodes_for_subtraction_trick_.emplace_back(6, tree->GetDepth(6), 0.0f);
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
ASSERT_EQ(sync_count, 2);
ASSERT_EQ(starting_index, 3);
for (const ExpandEntryT& node : this->nodes_for_explicit_hist_build_) {
for (const CPUExpandEntry& node : this->nodes_for_explicit_hist_build_) {
ASSERT_EQ(this->hist_.RowExists(node.nid), true);
}
for (const ExpandEntryT& node : this->nodes_for_subtraction_trick_) {
for (const CPUExpandEntry& node : this->nodes_for_subtraction_trick_) {
ASSERT_EQ(this->hist_.RowExists(node.nid), true);
}
}
@@ -199,7 +198,7 @@ class QuantileHistMock : public QuantileHistMaker {
this->nodes_for_explicit_hist_build_.clear();
this->nodes_for_subtraction_trick_.clear();
// level 0
this->nodes_for_explicit_hist_build_.emplace_back(0, -1, tree->GetDepth(0), 0.0f, 0);
this->nodes_for_explicit_hist_build_.emplace_back(0, tree->GetDepth(0), 0.0f);
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
@@ -207,11 +206,9 @@ class QuantileHistMock : public QuantileHistMaker {
this->nodes_for_subtraction_trick_.clear();
// level 1
this->nodes_for_explicit_hist_build_.emplace_back((*tree)[0].LeftChild(),
(*tree)[0].RightChild(),
tree->GetDepth(1), 0.0f, 0);
tree->GetDepth(1), 0.0f);
this->nodes_for_subtraction_trick_.emplace_back((*tree)[0].RightChild(),
(*tree)[0].LeftChild(),
tree->GetDepth(2), 0.0f, 0);
tree->GetDepth(2), 0.0f);
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
@@ -219,10 +216,10 @@ class QuantileHistMock : public QuantileHistMaker {
this->nodes_for_explicit_hist_build_.clear();
this->nodes_for_subtraction_trick_.clear();
// level 2
this->nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
this->nodes_for_subtraction_trick_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
this->nodes_for_explicit_hist_build_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
this->nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
this->nodes_for_explicit_hist_build_.emplace_back(3, tree->GetDepth(3), 0.0f);
this->nodes_for_subtraction_trick_.emplace_back(4, tree->GetDepth(4), 0.0f);
this->nodes_for_explicit_hist_build_.emplace_back(5, tree->GetDepth(5), 0.0f);
this->nodes_for_subtraction_trick_.emplace_back(6, tree->GetDepth(6), 0.0f);
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
const size_t n_nodes = this->nodes_for_explicit_hist_build_.size();
@@ -278,21 +275,27 @@ class QuantileHistMock : public QuantileHistMaker {
ASSERT_EQ(p_parent[i], p_left[i] + p_right[i]);
}
};
for (const ExpandEntryT& node : this->nodes_for_explicit_hist_build_) {
size_t node_id = 0;
for (const CPUExpandEntry& node : this->nodes_for_explicit_hist_build_) {
auto this_hist = this->hist_[node.nid];
const size_t parent_id = (*tree)[node.nid].Parent();
const size_t subtraction_node_id = this->nodes_for_subtraction_trick_[node_id].nid;
auto parent_hist = this->hist_[parent_id];
auto sibling_hist = this->hist_[node.sibling_nid];
auto sibling_hist = this->hist_[subtraction_node_id];
check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
++node_id;
}
for (const ExpandEntryT& node : this->nodes_for_subtraction_trick_) {
node_id = 0;
for (const CPUExpandEntry& node : this->nodes_for_subtraction_trick_) {
auto this_hist = this->hist_[node.nid];
const size_t parent_id = (*tree)[node.nid].Parent();
const size_t subtraction_node_id = this->nodes_for_explicit_hist_build_[node_id].nid;
auto parent_hist = this->hist_[parent_id];
auto sibling_hist = this->hist_[node.sibling_nid];
auto sibling_hist = this->hist_[subtraction_node_id];
check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
++node_id;
}
}
@@ -408,10 +411,9 @@ class QuantileHistMock : public QuantileHistMaker {
}
/* Now compare against result given by EvaluateSplit() */
typename RealImpl::ExpandEntry node(RealImpl::ExpandEntry::kRootNid,
RealImpl::ExpandEntry::kEmptyNid,
tree.GetDepth(0),
this->snode_[0].best.loss_chg, 0);
CPUExpandEntry node(CPUExpandEntry::kRootNid,
tree.GetDepth(0),
this->snode_[0].best.loss_chg);
RealImpl::EvaluateSplits({node}, gmat, this->hist_, tree);
ASSERT_EQ(this->snode_[0].best.SplitIndex(), best_split_feature);
ASSERT_EQ(this->snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);