Merge lossgude and depthwise strategies for CPU hist (#7007)
* fix java/scala test: max depth is also valid parameter for lossguide Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
This commit is contained in:
@@ -1,20 +1,21 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include "../../../../src/tree/gpu_hist/driver.cuh"
|
||||
#include "../../../../src/tree/driver.h"
|
||||
#include "../../../../src/tree/gpu_hist/expand_entry.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
TEST(GpuHist, DriverDepthWise) {
|
||||
Driver driver(TrainParam::kDepthWise);
|
||||
Driver<GPUExpandEntry> driver(TrainParam::kDepthWise);
|
||||
EXPECT_TRUE(driver.Pop().empty());
|
||||
DeviceSplitCandidate split;
|
||||
split.loss_chg = 1.0f;
|
||||
ExpandEntry root(0, 0, split, .0f, .0f, .0f);
|
||||
GPUExpandEntry root(0, 0, split, .0f, .0f, .0f);
|
||||
driver.Push({root});
|
||||
EXPECT_EQ(driver.Pop().front().nid, 0);
|
||||
driver.Push({ExpandEntry{1, 1, split, .0f, .0f, .0f}});
|
||||
driver.Push({ExpandEntry{2, 1, split, .0f, .0f, .0f}});
|
||||
driver.Push({ExpandEntry{3, 2, split, .0f, .0f, .0f}});
|
||||
driver.Push({GPUExpandEntry{1, 1, split, .0f, .0f, .0f}});
|
||||
driver.Push({GPUExpandEntry{2, 1, split, .0f, .0f, .0f}});
|
||||
driver.Push({GPUExpandEntry{3, 2, split, .0f, .0f, .0f}});
|
||||
// Should return entries from level 1
|
||||
auto res = driver.Pop();
|
||||
EXPECT_EQ(res.size(), 2);
|
||||
@@ -32,14 +33,14 @@ TEST(GpuHist, DriverLossGuided) {
|
||||
DeviceSplitCandidate low_gain;
|
||||
low_gain.loss_chg = 1.0f;
|
||||
|
||||
Driver driver(TrainParam::kLossGuide);
|
||||
Driver<GPUExpandEntry> driver(TrainParam::kLossGuide);
|
||||
EXPECT_TRUE(driver.Pop().empty());
|
||||
ExpandEntry root(0, 0, high_gain, .0f, .0f, .0f);
|
||||
GPUExpandEntry root(0, 0, high_gain, .0f, .0f, .0f);
|
||||
driver.Push({root});
|
||||
EXPECT_EQ(driver.Pop().front().nid, 0);
|
||||
// Select high gain first
|
||||
driver.Push({ExpandEntry{1, 1, low_gain, .0f, .0f, .0f}});
|
||||
driver.Push({ExpandEntry{2, 2, high_gain, .0f, .0f, .0f}});
|
||||
driver.Push({GPUExpandEntry{1, 1, low_gain, .0f, .0f, .0f}});
|
||||
driver.Push({GPUExpandEntry{2, 2, high_gain, .0f, .0f, .0f}});
|
||||
auto res = driver.Pop();
|
||||
EXPECT_EQ(res.size(), 1);
|
||||
EXPECT_EQ(res[0].nid, 2);
|
||||
@@ -48,8 +49,8 @@ TEST(GpuHist, DriverLossGuided) {
|
||||
EXPECT_EQ(res[0].nid, 1);
|
||||
|
||||
// If equal gain, use nid
|
||||
driver.Push({ExpandEntry{2, 1, low_gain, .0f, .0f, .0f}});
|
||||
driver.Push({ExpandEntry{1, 1, low_gain, .0f, .0f, .0f}});
|
||||
driver.Push({GPUExpandEntry{2, 1, low_gain, .0f, .0f, .0f}});
|
||||
driver.Push({GPUExpandEntry{1, 1, low_gain, .0f, .0f, .0f}});
|
||||
res = driver.Pop();
|
||||
EXPECT_EQ(res[0].nid, 1);
|
||||
res = driver.Pop();
|
||||
|
||||
@@ -132,7 +132,7 @@ TEST(GpuHist, BuildHistSharedMem) {
|
||||
|
||||
TEST(GpuHist, ApplySplit) {
|
||||
RegTree tree;
|
||||
ExpandEntry candidate;
|
||||
GPUExpandEntry candidate;
|
||||
candidate.nid = 0;
|
||||
candidate.left_weight = 1.0f;
|
||||
candidate.right_weight = 2.0f;
|
||||
|
||||
@@ -24,7 +24,6 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
template <typename GradientSumT>
|
||||
struct BuilderMock : public QuantileHistMaker::Builder<GradientSumT> {
|
||||
using RealImpl = QuantileHistMaker::Builder<GradientSumT>;
|
||||
using ExpandEntryT = typename RealImpl::ExpandEntry;
|
||||
using GHistRowT = typename RealImpl::GHistRowT;
|
||||
|
||||
BuilderMock(const TrainParam& param,
|
||||
@@ -169,19 +168,19 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||
tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||
tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
|
||||
this->nodes_for_subtraction_trick_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
|
||||
this->nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(3, tree->GetDepth(3), 0.0f);
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(4, tree->GetDepth(4), 0.0f);
|
||||
this->nodes_for_subtraction_trick_.emplace_back(5, tree->GetDepth(5), 0.0f);
|
||||
this->nodes_for_subtraction_trick_.emplace_back(6, tree->GetDepth(6), 0.0f);
|
||||
|
||||
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
|
||||
ASSERT_EQ(sync_count, 2);
|
||||
ASSERT_EQ(starting_index, 3);
|
||||
|
||||
for (const ExpandEntryT& node : this->nodes_for_explicit_hist_build_) {
|
||||
for (const CPUExpandEntry& node : this->nodes_for_explicit_hist_build_) {
|
||||
ASSERT_EQ(this->hist_.RowExists(node.nid), true);
|
||||
}
|
||||
for (const ExpandEntryT& node : this->nodes_for_subtraction_trick_) {
|
||||
for (const CPUExpandEntry& node : this->nodes_for_subtraction_trick_) {
|
||||
ASSERT_EQ(this->hist_.RowExists(node.nid), true);
|
||||
}
|
||||
}
|
||||
@@ -199,7 +198,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
this->nodes_for_explicit_hist_build_.clear();
|
||||
this->nodes_for_subtraction_trick_.clear();
|
||||
// level 0
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(0, -1, tree->GetDepth(0), 0.0f, 0);
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(0, tree->GetDepth(0), 0.0f);
|
||||
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
|
||||
tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||
|
||||
@@ -207,11 +206,9 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
this->nodes_for_subtraction_trick_.clear();
|
||||
// level 1
|
||||
this->nodes_for_explicit_hist_build_.emplace_back((*tree)[0].LeftChild(),
|
||||
(*tree)[0].RightChild(),
|
||||
tree->GetDepth(1), 0.0f, 0);
|
||||
tree->GetDepth(1), 0.0f);
|
||||
this->nodes_for_subtraction_trick_.emplace_back((*tree)[0].RightChild(),
|
||||
(*tree)[0].LeftChild(),
|
||||
tree->GetDepth(2), 0.0f, 0);
|
||||
tree->GetDepth(2), 0.0f);
|
||||
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
|
||||
tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||
tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||
@@ -219,10 +216,10 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
this->nodes_for_explicit_hist_build_.clear();
|
||||
this->nodes_for_subtraction_trick_.clear();
|
||||
// level 2
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
|
||||
this->nodes_for_subtraction_trick_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
|
||||
this->nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(3, tree->GetDepth(3), 0.0f);
|
||||
this->nodes_for_subtraction_trick_.emplace_back(4, tree->GetDepth(4), 0.0f);
|
||||
this->nodes_for_explicit_hist_build_.emplace_back(5, tree->GetDepth(5), 0.0f);
|
||||
this->nodes_for_subtraction_trick_.emplace_back(6, tree->GetDepth(6), 0.0f);
|
||||
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
|
||||
|
||||
const size_t n_nodes = this->nodes_for_explicit_hist_build_.size();
|
||||
@@ -278,21 +275,27 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
ASSERT_EQ(p_parent[i], p_left[i] + p_right[i]);
|
||||
}
|
||||
};
|
||||
for (const ExpandEntryT& node : this->nodes_for_explicit_hist_build_) {
|
||||
size_t node_id = 0;
|
||||
for (const CPUExpandEntry& node : this->nodes_for_explicit_hist_build_) {
|
||||
auto this_hist = this->hist_[node.nid];
|
||||
const size_t parent_id = (*tree)[node.nid].Parent();
|
||||
const size_t subtraction_node_id = this->nodes_for_subtraction_trick_[node_id].nid;
|
||||
auto parent_hist = this->hist_[parent_id];
|
||||
auto sibling_hist = this->hist_[node.sibling_nid];
|
||||
auto sibling_hist = this->hist_[subtraction_node_id];
|
||||
|
||||
check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
|
||||
++node_id;
|
||||
}
|
||||
for (const ExpandEntryT& node : this->nodes_for_subtraction_trick_) {
|
||||
node_id = 0;
|
||||
for (const CPUExpandEntry& node : this->nodes_for_subtraction_trick_) {
|
||||
auto this_hist = this->hist_[node.nid];
|
||||
const size_t parent_id = (*tree)[node.nid].Parent();
|
||||
const size_t subtraction_node_id = this->nodes_for_explicit_hist_build_[node_id].nid;
|
||||
auto parent_hist = this->hist_[parent_id];
|
||||
auto sibling_hist = this->hist_[node.sibling_nid];
|
||||
auto sibling_hist = this->hist_[subtraction_node_id];
|
||||
|
||||
check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
|
||||
++node_id;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -408,10 +411,9 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
}
|
||||
|
||||
/* Now compare against result given by EvaluateSplit() */
|
||||
typename RealImpl::ExpandEntry node(RealImpl::ExpandEntry::kRootNid,
|
||||
RealImpl::ExpandEntry::kEmptyNid,
|
||||
tree.GetDepth(0),
|
||||
this->snode_[0].best.loss_chg, 0);
|
||||
CPUExpandEntry node(CPUExpandEntry::kRootNid,
|
||||
tree.GetDepth(0),
|
||||
this->snode_[0].best.loss_chg);
|
||||
RealImpl::EvaluateSplits({node}, gmat, this->hist_, tree);
|
||||
ASSERT_EQ(this->snode_[0].best.SplitIndex(), best_split_feature);
|
||||
ASSERT_EQ(this->snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);
|
||||
|
||||
Reference in New Issue
Block a user