Require setting leaf stats when expanding a tree. (#5501)

* Fix GPU Hist feature importance.
Jiaming Yuan 2020-04-10 12:27:03 +08:00 committed by GitHub
parent dc2950fd90
commit 7d52c0b8c2
11 changed files with 179 additions and 50 deletions
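The core of the change is the extended `RegTree::ExpandNode` signature below: callers now pass the hessian sums of the two children (`left_sum`, `right_sum`), and the method fills in the statistics of the parent and of both new leaves in one place instead of leaving that to each updater. A minimal sketch of a call with the new signature; the helper name and the numeric values are made up for illustration and do not come from this patch:

#include <xgboost/tree_model.h>

namespace xgboost {
// Hypothetical helper; a real updater passes the split statistics it computed
// for the candidate node instead of these placeholder values.
inline RegTree MakeSingleSplitTree() {
  RegTree tree;
  float const left_sum = 3.0f;   // hessian sum routed to the left child
  float const right_sum = 2.0f;  // hessian sum routed to the right child
  tree.ExpandNode(/*nid=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/0.5f,
                  /*default_left=*/true, /*base_weight=*/0.0f,
                  /*left_leaf_weight=*/-0.1f, /*right_leaf_weight=*/0.1f,
                  /*loss_change=*/1.0f, /*sum_hess=*/left_sum + right_sum,
                  /*left_sum=*/left_sum, /*right_sum=*/right_sum);
  // After the call both leaves carry their own stats, e.g.
  // tree.Stat(tree[RegTree::kRoot].LeftChild()).sum_hess == left_sum.
  return tree;
}
}  // namespace xgboost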

View File

@@ -22,6 +22,7 @@
 #include <cstring>
 #include <algorithm>
 #include <tuple>
+#include <stack>

 namespace xgboost {
@@ -88,6 +89,10 @@ struct RTreeNodeStat {
   bst_float base_weight;
   /*! \brief number of child that is leaf node known up to now */
   int leaf_child_cnt {0};
+
+  RTreeNodeStat() = default;
+  RTreeNodeStat(float loss_chg, float sum_hess, float weight) :
+      loss_chg{loss_chg}, sum_hess{sum_hess}, base_weight{weight} {}
   bool operator==(const RTreeNodeStat& b) const {
     return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
            base_weight == b.base_weight && leaf_child_cnt == b.leaf_child_cnt;
@@ -101,8 +106,9 @@ struct RTreeNodeStat {
 class RegTree : public Model {
  public:
   using SplitCondT = bst_float;
-  static constexpr int32_t kInvalidNodeId {-1};
+  static constexpr bst_node_t kInvalidNodeId {-1};
   static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
+  static constexpr bst_node_t kRoot { 0 };

   /*! \brief tree node */
   class Node {
@@ -321,6 +327,31 @@ class RegTree : public Model {
     return nodes_ == b.nodes_ && stats_ == b.stats_ &&
            deleted_nodes_ == b.deleted_nodes_ && param == b.param;
   }
+  /* \brief Iterate through all nodes in this tree.
+   *
+   * \param Function that accepts a node index, and returns false when iteration should
+   *        stop, otherwise returns true.
+   */
+  template <typename Func> void WalkTree(Func func) const {
+    std::stack<bst_node_t> nodes;
+    nodes.push(kRoot);
+    auto &self = *this;
+    while (!nodes.empty()) {
+      auto nidx = nodes.top();
+      nodes.pop();
+      if (!func(nidx)) {
+        return;
+      }
+      auto left = self[nidx].LeftChild();
+      auto right = self[nidx].RightChild();
+      if (left != RegTree::kInvalidNodeId) {
+        nodes.push(left);
+      }
+      if (right != RegTree::kInvalidNodeId) {
+        nodes.push(right);
+      }
+    }
+  }
   /*!
    * \brief Compares whether 2 trees are equal from a user's perspective.  The equality
    *        compares only non-deleted nodes.
@@ -341,13 +372,16 @@ class RegTree : public Model {
    * \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
    * \param loss_change The loss change.
    * \param sum_hess The sum hess.
+   * \param left_sum The sum hess of left leaf.
+   * \param right_sum The sum hess of right leaf.
    * \param leaf_right_child The right child index of leaf, by default kInvalidNodeId,
    *                         some updaters use the right child index of leaf as a marker
    */
   void ExpandNode(int nid, unsigned split_index, bst_float split_value,
                   bool default_left, bst_float base_weight,
                   bst_float left_leaf_weight, bst_float right_leaf_weight,
-                  bst_float loss_change, float sum_hess,
+                  bst_float loss_change, float sum_hess, float left_sum,
+                  float right_sum,
                   bst_node_t leaf_right_child = kInvalidNodeId) {
     int pleft  = this->AllocNode();
     int pright = this->AllocNode();
@@ -363,9 +397,9 @@ class RegTree : public Model {
     nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child);
     nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child);
-    this->Stat(nid).loss_chg = loss_change;
-    this->Stat(nid).base_weight = base_weight;
-    this->Stat(nid).sum_hess = sum_hess;
+    this->Stat(nid)    = {loss_change, sum_hess, base_weight};
+    this->Stat(pleft)  = {0.0f, left_sum, left_leaf_weight};
+    this->Stat(pright) = {0.0f, right_sum, right_leaf_weight};
   }

   /*!
@@ -402,6 +436,10 @@ class RegTree : public Model {
     return param.num_nodes - 1 - param.num_deleted;
   }

+  /* \brief Count number of leaves in tree. */
+  bst_node_t GetNumLeaves() const;
+  bst_node_t GetNumSplitNodes() const;
+
   /*!
    * \brief dense feature vector that can be taken by RegTree
    *        and can be construct from sparse feature vector.
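The `WalkTree` helper added above gives a single pre-order traversal that the library itself reuses (see `RegTree::Equal`, `GetNumLeaves` and `GetNumSplitNodes` in the next file) and that external code can use as well. A small hypothetical example, not part of this commit, combining it with the leaf stats that `ExpandNode` now populates:

#include <xgboost/tree_model.h>

// Hypothetical helper built on the WalkTree/Stat API; not part of this patch.
inline double TotalLeafHessian(xgboost::RegTree& tree) {
  double total = 0.0;
  tree.WalkTree([&](xgboost::bst_node_t nidx) {
    if (tree[nidx].IsLeaf()) {
      total += tree.Stat(nidx).sum_hess;  // set by ExpandNode after this change
    }
    return true;  // keep walking; returning false would stop the traversal
  });
  return total;
}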

View File

@@ -607,6 +607,8 @@ XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot")
     return new GraphvizGenerator(fmap, attrs, with_stats);
   });

+constexpr bst_node_t RegTree::kRoot;
+
 std::string RegTree::DumpModel(const FeatureMap& fmap,
                                bool with_stats,
                                std::string format) const {
@@ -623,26 +625,40 @@ bool RegTree::Equal(const RegTree& b) const {
   if (NumExtraNodes() != b.NumExtraNodes()) {
     return false;
   }
-  std::stack<bst_node_t> nodes;
-  nodes.push(0);
-  auto& self = *this;
-  while (!nodes.empty()) {
-    auto nid = nodes.top();
-    nodes.pop();
-    if (!(self.nodes_.at(nid) == b.nodes_.at(nid))) {
-      return false;
-    }
-    auto left = self[nid].LeftChild();
-    auto right = self[nid].RightChild();
-    if (left != RegTree::kInvalidNodeId) {
-      nodes.push(left);
-    }
-    if (right != RegTree::kInvalidNodeId) {
-      nodes.push(right);
-    }
-  }
-  return true;
+  auto const& self = *this;
+  bool ret { true };
+  this->WalkTree([&self, &b, &ret](bst_node_t nidx) {
+    if (!(self.nodes_.at(nidx) == b.nodes_.at(nidx))) {
+      ret = false;
+      return false;
+    }
+    return true;
+  });
+  return ret;
+}
+
+bst_node_t RegTree::GetNumLeaves() const {
+  bst_node_t leaves { 0 };
+  auto const& self = *this;
+  this->WalkTree([&leaves, &self](bst_node_t nidx) {
+    if (self[nidx].IsLeaf()) {
+      leaves++;
+    }
+    return true;
+  });
+  return leaves;
+}
+
+bst_node_t RegTree::GetNumSplitNodes() const {
+  bst_node_t splits { 0 };
+  auto const& self = *this;
+  this->WalkTree([&splits, &self](bst_node_t nidx) {
+    if (!self[nidx].IsLeaf()) {
+      splits++;
+    }
+    return true;
+  });
+  return splits;
 }

 void RegTree::Load(dmlc::Stream* fi) {
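`GetNumLeaves` and `GetNumSplitNodes` are thin conveniences on top of `WalkTree`. Since the trees built through `ExpandNode` are strictly binary (every split has exactly two children), reachable leaves always outnumber reachable split nodes by one; a hypothetical sanity check, not part of the patch:

#include <xgboost/tree_model.h>

// Strictly binary tree: one more leaf than split nodes.
inline bool HasConsistentShape(xgboost::RegTree const& tree) {
  return tree.GetNumLeaves() == tree.GetNumSplitNodes() + 1;
}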

View File

@@ -499,7 +499,9 @@ class ColMaker: public TreeUpdater {
           p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
                              e.best.DefaultLeft(), e.weight, left_leaf_weight,
                              right_leaf_weight, e.best.loss_chg,
-                             e.stats.sum_hess, 0);
+                             e.stats.sum_hess,
+                             e.best.left_sum.GetHess(), e.best.right_sum.GetHess(),
+                             0);
         } else {
           (*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate);
         }

View File

@@ -814,7 +814,8 @@ struct GPUHistMakerDevice {
     tree.ExpandNode(candidate.nid, candidate.split.findex,
                     candidate.split.fvalue, candidate.split.dir == kLeftDir,
                     base_weight, left_weight, right_weight,
-                    candidate.split.loss_chg, parent_sum.sum_hess);
+                    candidate.split.loss_chg, parent_sum.sum_hess,
+                    left_stats.GetHess(), right_stats.GetHess());
     // Set up child constraints
     node_value_constraints.resize(tree.GetNodes().size());
     node_value_constraints[candidate.nid].SetChild(
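This hunk is where the feature-importance part of the title comes in: importance types such as gain and cover are aggregated from the per-node `loss_chg` and `sum_hess` stats, which is why GPU Hist, like the other updaters in this patch, now hands the child hessian sums through `ExpandNode`. A rough sketch of gain-style aggregation over those stats, purely to illustrate the dependency and not xgboost's actual importance code:

#include <xgboost/tree_model.h>
#include <map>

// Rough illustration of gain-style importance derived from node stats.
inline std::map<unsigned, float> SplitGainByFeature(xgboost::RegTree& tree) {
  std::map<unsigned, float> gain;  // split feature index -> accumulated loss_chg
  tree.WalkTree([&](xgboost::bst_node_t nidx) {
    if (!tree[nidx].IsLeaf()) {
      gain[tree[nidx].SplitIndex()] += tree.Stat(nidx).loss_chg;
    }
    return true;
  });
  return gain;
}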

View File

@@ -249,7 +249,8 @@ class HistMaker: public BaseMaker {
       p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
                          best.DefaultLeft(), base_weight, left_leaf_weight,
                          right_leaf_weight, best.loss_chg,
-                         node_sum.sum_hess);
+                         node_sum.sum_hess,
+                         best.left_sum.GetHess(), best.right_sum.GetHess());
       GradStats right_sum;
       right_sum.SetSubstract(node_sum, left_sum[wid]);
       auto left_child = (*p_tree)[nid].LeftChild();

View File

@@ -263,7 +263,8 @@ void QuantileHistMaker::Builder::AddSplitsToTree(
           spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate;
       p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
                          e.best.DefaultLeft(), e.weight, left_leaf_weight,
-                         right_leaf_weight, e.best.loss_chg, e.stats.sum_hess);
+                         right_leaf_weight, e.best.loss_chg, e.stats.sum_hess,
+                         e.best.left_sum.GetHess(), e.best.right_sum.GetHess());

       int left_id = (*p_tree)[nid].LeftChild();
       int right_id = (*p_tree)[nid].RightChild();
@@ -410,7 +411,8 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
           spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate;
       p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
                          e.best.DefaultLeft(), e.weight, left_leaf_weight,
-                         right_leaf_weight, e.best.loss_chg, e.stats.sum_hess);
+                         right_leaf_weight, e.best.loss_chg, e.stats.sum_hess,
+                         e.best.left_sum.GetHess(), e.best.right_sum.GetHess());

       this->ApplySplit({candidate}, gmat, column_matrix, hist_, p_tree);

View File

@@ -289,7 +289,8 @@ class SketchMaker: public BaseMaker {
       p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
                          best.DefaultLeft(), base_weight, left_leaf_weight,
                          right_leaf_weight, best.loss_chg,
-                         node_stats_[nid].sum_hess);
+                         node_stats_[nid].sum_hess,
+                         best.left_sum.GetHess(), best.right_sum.GetHess());
     } else {
       (*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate);
     }

View File

@@ -42,13 +42,15 @@ TEST(Updater, Prune) {
   pruner->Configure(cfg);

   // loss_chg < min_split_loss;
-  tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 0.0f, 0.0f);
+  tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 0.0f, 0.0f,
+                  /*left_sum=*/0.0f, /*right_sum=*/0.0f);
   pruner->Update(&gpair, p_dmat.get(), trees);
   ASSERT_EQ(tree.NumExtraNodes(), 0);

   // loss_chg > min_split_loss;
-  tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 11.0f, 0.0f);
+  tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 11.0f, 0.0f,
+                  /*left_sum=*/0.0f, /*right_sum=*/0.0f);
   pruner->Update(&gpair, p_dmat.get(), trees);
   ASSERT_EQ(tree.NumExtraNodes(), 2);
@@ -63,10 +65,12 @@ TEST(Updater, Prune) {
   // loss_chg > min_split_loss
   tree.ExpandNode(tree[0].LeftChild(),
                   0, 0.5f, true, 0.3, 0.4, 0.5,
-                  /*loss_chg=*/18.0f, 0.0f);
+                  /*loss_chg=*/18.0f, 0.0f,
+                  /*left_sum=*/0.0f, /*right_sum=*/0.0f);
   tree.ExpandNode(tree[0].RightChild(),
                   0, 0.5f, true, 0.3, 0.4, 0.5,
-                  /*loss_chg=*/19.0f, 0.0f);
+                  /*loss_chg=*/19.0f, 0.0f,
+                  /*left_sum=*/0.0f, /*right_sum=*/0.0f);
   cfg.emplace_back(std::make_pair("max_depth", "1"));
   pruner->Configure(cfg);
   pruner->Update(&gpair, p_dmat.get(), trees);
@@ -75,7 +79,8 @@ TEST(Updater, Prune) {
   tree.ExpandNode(tree[0].LeftChild(),
                   0, 0.5f, true, 0.3, 0.4, 0.5,
-                  /*loss_chg=*/18.0f, 0.0f);
+                  /*loss_chg=*/18.0f, 0.0f,
+                  /*left_sum=*/0.0f, /*right_sum=*/0.0f);
   cfg.emplace_back(std::make_pair("min_split_loss", "0"));
   pruner->Configure(cfg);
   pruner->Update(&gpair, p_dmat.get(), trees);

View File

@@ -34,7 +34,8 @@ TEST(Updater, Refresh) {
   std::vector<RegTree*> trees {&tree};
   std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh", &lparam));

-  tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f);
+  tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f,
+                  /*left_sum=*/0.0f, /*right_sum=*/0.0f);
   int cleft = tree[0].LeftChild();
   int cright = tree[0].RightChild();

View File

@@ -88,13 +88,13 @@ TEST(Tree, Load) {
 TEST(Tree, AllocateNode) {
   RegTree tree;
-  tree.ExpandNode(
-      0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+  tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
+                  /*left_sum=*/0.0f, /*right_sum=*/0.0f);
   tree.CollapseToLeaf(0, 0);
   ASSERT_EQ(tree.NumExtraNodes(), 0);

-  tree.ExpandNode(
-      0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+  tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
+                  /*left_sum=*/0.0f, /*right_sum=*/0.0f);
   ASSERT_EQ(tree.NumExtraNodes(), 2);

   auto& nodes = tree.GetNodes();
@@ -107,18 +107,18 @@ RegTree ConstructTree() {
   RegTree tree;
   tree.ExpandNode(
       /*nid=*/0, /*split_index=*/0, /*split_value=*/0.0f,
-      /*default_left=*/true,
-      0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+      /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
+      /*right_sum=*/0.0f);
   auto left = tree[0].LeftChild();
   auto right = tree[0].RightChild();
   tree.ExpandNode(
       /*nid=*/left, /*split_index=*/1, /*split_value=*/1.0f,
-      /*default_left=*/false,
-      0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+      /*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
+      /*right_sum=*/0.0f);
   tree.ExpandNode(
       /*nid=*/right, /*split_index=*/2, /*split_value=*/2.0f,
-      /*default_left=*/false,
-      0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+      /*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
+      /*right_sum=*/0.0f);
   return tree;
 }
@@ -222,7 +222,8 @@ TEST(Tree, DumpDot) {
 TEST(Tree, JsonIO) {
   RegTree tree;
-  tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+  tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
+                  /*left_sum=*/0.0f, /*right_sum=*/0.0f);

   Json j_tree{Object()};
   tree.SaveModel(&j_tree);
@@ -246,8 +247,10 @@ TEST(Tree, JsonIO) {
   auto left = tree[0].LeftChild();
   auto right = tree[0].RightChild();

-  tree.ExpandNode(left, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
-  tree.ExpandNode(right, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+  tree.ExpandNode(left, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
+                  /*left_sum=*/0.0f, /*right_sum=*/0.0f);
+  tree.ExpandNode(right, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
+                  /*left_sum=*/0.0f, /*right_sum=*/0.0f);

   tree.SaveModel(&j_tree);
   tree.ChangeToLeaf(1, 1.0f);

View File

@@ -0,0 +1,59 @@
+#include <xgboost/tree_updater.h>
+#include <xgboost/tree_model.h>
+#include <gtest/gtest.h>
+
+#include "../helpers.h"
+
+namespace xgboost {
+class UpdaterTreeStatTest : public ::testing::Test {
+ protected:
+  std::shared_ptr<DMatrix> p_dmat_;
+  HostDeviceVector<GradientPair> gpairs_;
+  size_t constexpr static kRows = 10;
+  size_t constexpr static kCols = 10;
+
+ protected:
+  void SetUp() override {
+    p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatix(true);
+    auto g = GenerateRandomGradients(kRows);
+    gpairs_.Resize(kRows);
+    gpairs_.Copy(g);
+  }
+
+  void RunTest(std::string updater) {
+    auto tparam = CreateEmptyGenericParam(0);
+    auto up = std::unique_ptr<TreeUpdater>{
+      TreeUpdater::Create(updater, &tparam)};
+    up->Configure(Args{});
+    RegTree tree;
+    tree.param.num_feature = kCols;
+    up->Update(&gpairs_, p_dmat_.get(), {&tree});
+
+    tree.WalkTree([&tree](bst_node_t nidx) {
+      if (tree[nidx].IsLeaf()) {
+        // 1.0 is the default `min_child_weight`.
+        CHECK_GE(tree.Stat(nidx).sum_hess, 1.0);
+      }
+      return true;
+    });
+  }
+};
+
+#if defined(XGBOOST_USE_CUDA)
+TEST_F(UpdaterTreeStatTest, GPUHist) {
+  this->RunTest("grow_gpu_hist");
+}
+#endif  // defined(XGBOOST_USE_CUDA)
+
+TEST_F(UpdaterTreeStatTest, Hist) {
+  this->RunTest("grow_quantile_histmaker");
+}
+
+TEST_F(UpdaterTreeStatTest, Exact) {
+  this->RunTest("grow_colmaker");
+}
+
+TEST_F(UpdaterTreeStatTest, Approx) {
+  this->RunTest("grow_histmaker");
+}
+}  // namespace xgboost