Requires setting leaf stat when expanding tree. (#5501)
* Fix GPU Hist feature importance.
This commit is contained in:
parent
dc2950fd90
commit
7d52c0b8c2
@ -22,6 +22,7 @@
|
|||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
#include <stack>
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
@ -88,6 +89,10 @@ struct RTreeNodeStat {
|
|||||||
bst_float base_weight;
|
bst_float base_weight;
|
||||||
/*! \brief number of child that is leaf node known up to now */
|
/*! \brief number of child that is leaf node known up to now */
|
||||||
int leaf_child_cnt {0};
|
int leaf_child_cnt {0};
|
||||||
|
|
||||||
|
RTreeNodeStat() = default;
|
||||||
|
RTreeNodeStat(float loss_chg, float sum_hess, float weight) :
|
||||||
|
loss_chg{loss_chg}, sum_hess{sum_hess}, base_weight{weight} {}
|
||||||
bool operator==(const RTreeNodeStat& b) const {
|
bool operator==(const RTreeNodeStat& b) const {
|
||||||
return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
|
return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
|
||||||
base_weight == b.base_weight && leaf_child_cnt == b.leaf_child_cnt;
|
base_weight == b.base_weight && leaf_child_cnt == b.leaf_child_cnt;
|
||||||
@ -101,8 +106,9 @@ struct RTreeNodeStat {
|
|||||||
class RegTree : public Model {
|
class RegTree : public Model {
|
||||||
public:
|
public:
|
||||||
using SplitCondT = bst_float;
|
using SplitCondT = bst_float;
|
||||||
static constexpr int32_t kInvalidNodeId {-1};
|
static constexpr bst_node_t kInvalidNodeId {-1};
|
||||||
static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
|
static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
|
||||||
|
static constexpr bst_node_t kRoot { 0 };
|
||||||
|
|
||||||
/*! \brief tree node */
|
/*! \brief tree node */
|
||||||
class Node {
|
class Node {
|
||||||
@ -321,6 +327,31 @@ class RegTree : public Model {
|
|||||||
return nodes_ == b.nodes_ && stats_ == b.stats_ &&
|
return nodes_ == b.nodes_ && stats_ == b.stats_ &&
|
||||||
deleted_nodes_ == b.deleted_nodes_ && param == b.param;
|
deleted_nodes_ == b.deleted_nodes_ && param == b.param;
|
||||||
}
|
}
|
||||||
|
/* \brief Iterate through all nodes in this tree.
|
||||||
|
*
|
||||||
|
* \param Function that accepts a node index, and returns false when iteration should
|
||||||
|
* stop, otherwise returns true.
|
||||||
|
*/
|
||||||
|
template <typename Func> void WalkTree(Func func) const {
|
||||||
|
std::stack<bst_node_t> nodes;
|
||||||
|
nodes.push(kRoot);
|
||||||
|
auto &self = *this;
|
||||||
|
while (!nodes.empty()) {
|
||||||
|
auto nidx = nodes.top();
|
||||||
|
nodes.pop();
|
||||||
|
if (!func(nidx)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto left = self[nidx].LeftChild();
|
||||||
|
auto right = self[nidx].RightChild();
|
||||||
|
if (left != RegTree::kInvalidNodeId) {
|
||||||
|
nodes.push(left);
|
||||||
|
}
|
||||||
|
if (right != RegTree::kInvalidNodeId) {
|
||||||
|
nodes.push(right);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief Compares whether 2 trees are equal from a user's perspective. The equality
|
* \brief Compares whether 2 trees are equal from a user's perspective. The equality
|
||||||
* compares only non-deleted nodes.
|
* compares only non-deleted nodes.
|
||||||
@ -341,13 +372,16 @@ class RegTree : public Model {
|
|||||||
* \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
|
* \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
|
||||||
* \param loss_change The loss change.
|
* \param loss_change The loss change.
|
||||||
* \param sum_hess The sum hess.
|
* \param sum_hess The sum hess.
|
||||||
* \param leaf_right_child The right child index of leaf, by default kInvalidNodeId,
|
* \param left_sum The sum hess of left leaf.
|
||||||
* some updaters use the right child index of leaf as a marker
|
* \param right_sum The sum hess of right leaf.
|
||||||
|
* \param leaf_right_child The right child index of leaf, by default kInvalidNodeId,
|
||||||
|
* some updaters use the right child index of leaf as a marker
|
||||||
*/
|
*/
|
||||||
void ExpandNode(int nid, unsigned split_index, bst_float split_value,
|
void ExpandNode(int nid, unsigned split_index, bst_float split_value,
|
||||||
bool default_left, bst_float base_weight,
|
bool default_left, bst_float base_weight,
|
||||||
bst_float left_leaf_weight, bst_float right_leaf_weight,
|
bst_float left_leaf_weight, bst_float right_leaf_weight,
|
||||||
bst_float loss_change, float sum_hess,
|
bst_float loss_change, float sum_hess, float left_sum,
|
||||||
|
float right_sum,
|
||||||
bst_node_t leaf_right_child = kInvalidNodeId) {
|
bst_node_t leaf_right_child = kInvalidNodeId) {
|
||||||
int pleft = this->AllocNode();
|
int pleft = this->AllocNode();
|
||||||
int pright = this->AllocNode();
|
int pright = this->AllocNode();
|
||||||
@ -363,9 +397,9 @@ class RegTree : public Model {
|
|||||||
nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child);
|
nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child);
|
||||||
nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child);
|
nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child);
|
||||||
|
|
||||||
this->Stat(nid).loss_chg = loss_change;
|
this->Stat(nid) = {loss_change, sum_hess, base_weight};
|
||||||
this->Stat(nid).base_weight = base_weight;
|
this->Stat(pleft) = {0.0f, left_sum, left_leaf_weight};
|
||||||
this->Stat(nid).sum_hess = sum_hess;
|
this->Stat(pright) = {0.0f, right_sum, right_leaf_weight};
|
||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
@ -402,6 +436,10 @@ class RegTree : public Model {
|
|||||||
return param.num_nodes - 1 - param.num_deleted;
|
return param.num_nodes - 1 - param.num_deleted;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* \brief Count number of leaves in tree. */
|
||||||
|
bst_node_t GetNumLeaves() const;
|
||||||
|
bst_node_t GetNumSplitNodes() const;
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief dense feature vector that can be taken by RegTree
|
* \brief dense feature vector that can be taken by RegTree
|
||||||
* and can be construct from sparse feature vector.
|
* and can be construct from sparse feature vector.
|
||||||
|
|||||||
@ -607,6 +607,8 @@ XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot")
|
|||||||
return new GraphvizGenerator(fmap, attrs, with_stats);
|
return new GraphvizGenerator(fmap, attrs, with_stats);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
constexpr bst_node_t RegTree::kRoot;
|
||||||
|
|
||||||
std::string RegTree::DumpModel(const FeatureMap& fmap,
|
std::string RegTree::DumpModel(const FeatureMap& fmap,
|
||||||
bool with_stats,
|
bool with_stats,
|
||||||
std::string format) const {
|
std::string format) const {
|
||||||
@ -623,26 +625,40 @@ bool RegTree::Equal(const RegTree& b) const {
|
|||||||
if (NumExtraNodes() != b.NumExtraNodes()) {
|
if (NumExtraNodes() != b.NumExtraNodes()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
auto const& self = *this;
|
||||||
std::stack<bst_node_t> nodes;
|
bool ret { true };
|
||||||
nodes.push(0);
|
this->WalkTree([&self, &b, &ret](bst_node_t nidx) {
|
||||||
auto& self = *this;
|
if (!(self.nodes_.at(nidx) == b.nodes_.at(nidx))) {
|
||||||
while (!nodes.empty()) {
|
ret = false;
|
||||||
auto nid = nodes.top();
|
|
||||||
nodes.pop();
|
|
||||||
if (!(self.nodes_.at(nid) == b.nodes_.at(nid))) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto left = self[nid].LeftChild();
|
return true;
|
||||||
auto right = self[nid].RightChild();
|
});
|
||||||
if (left != RegTree::kInvalidNodeId) {
|
return ret;
|
||||||
nodes.push(left);
|
}
|
||||||
}
|
|
||||||
if (right != RegTree::kInvalidNodeId) {
|
bst_node_t RegTree::GetNumLeaves() const {
|
||||||
nodes.push(right);
|
bst_node_t leaves { 0 };
|
||||||
}
|
auto const& self = *this;
|
||||||
}
|
this->WalkTree([&leaves, &self](bst_node_t nidx) {
|
||||||
return true;
|
if (self[nidx].IsLeaf()) {
|
||||||
|
leaves++;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
return leaves;
|
||||||
|
}
|
||||||
|
|
||||||
|
bst_node_t RegTree::GetNumSplitNodes() const {
|
||||||
|
bst_node_t splits { 0 };
|
||||||
|
auto const& self = *this;
|
||||||
|
this->WalkTree([&splits, &self](bst_node_t nidx) {
|
||||||
|
if (!self[nidx].IsLeaf()) {
|
||||||
|
splits++;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
return splits;
|
||||||
}
|
}
|
||||||
|
|
||||||
void RegTree::Load(dmlc::Stream* fi) {
|
void RegTree::Load(dmlc::Stream* fi) {
|
||||||
|
|||||||
@ -499,7 +499,9 @@ class ColMaker: public TreeUpdater {
|
|||||||
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
|
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
|
||||||
e.best.DefaultLeft(), e.weight, left_leaf_weight,
|
e.best.DefaultLeft(), e.weight, left_leaf_weight,
|
||||||
right_leaf_weight, e.best.loss_chg,
|
right_leaf_weight, e.best.loss_chg,
|
||||||
e.stats.sum_hess, 0);
|
e.stats.sum_hess,
|
||||||
|
e.best.left_sum.GetHess(), e.best.right_sum.GetHess(),
|
||||||
|
0);
|
||||||
} else {
|
} else {
|
||||||
(*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate);
|
(*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -814,7 +814,8 @@ struct GPUHistMakerDevice {
|
|||||||
tree.ExpandNode(candidate.nid, candidate.split.findex,
|
tree.ExpandNode(candidate.nid, candidate.split.findex,
|
||||||
candidate.split.fvalue, candidate.split.dir == kLeftDir,
|
candidate.split.fvalue, candidate.split.dir == kLeftDir,
|
||||||
base_weight, left_weight, right_weight,
|
base_weight, left_weight, right_weight,
|
||||||
candidate.split.loss_chg, parent_sum.sum_hess);
|
candidate.split.loss_chg, parent_sum.sum_hess,
|
||||||
|
left_stats.GetHess(), right_stats.GetHess());
|
||||||
// Set up child constraints
|
// Set up child constraints
|
||||||
node_value_constraints.resize(tree.GetNodes().size());
|
node_value_constraints.resize(tree.GetNodes().size());
|
||||||
node_value_constraints[candidate.nid].SetChild(
|
node_value_constraints[candidate.nid].SetChild(
|
||||||
|
|||||||
@ -249,7 +249,8 @@ class HistMaker: public BaseMaker {
|
|||||||
p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
|
p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
|
||||||
best.DefaultLeft(), base_weight, left_leaf_weight,
|
best.DefaultLeft(), base_weight, left_leaf_weight,
|
||||||
right_leaf_weight, best.loss_chg,
|
right_leaf_weight, best.loss_chg,
|
||||||
node_sum.sum_hess);
|
node_sum.sum_hess,
|
||||||
|
best.left_sum.GetHess(), best.right_sum.GetHess());
|
||||||
GradStats right_sum;
|
GradStats right_sum;
|
||||||
right_sum.SetSubstract(node_sum, left_sum[wid]);
|
right_sum.SetSubstract(node_sum, left_sum[wid]);
|
||||||
auto left_child = (*p_tree)[nid].LeftChild();
|
auto left_child = (*p_tree)[nid].LeftChild();
|
||||||
|
|||||||
@ -263,7 +263,8 @@ void QuantileHistMaker::Builder::AddSplitsToTree(
|
|||||||
spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate;
|
spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate;
|
||||||
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
|
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
|
||||||
e.best.DefaultLeft(), e.weight, left_leaf_weight,
|
e.best.DefaultLeft(), e.weight, left_leaf_weight,
|
||||||
right_leaf_weight, e.best.loss_chg, e.stats.sum_hess);
|
right_leaf_weight, e.best.loss_chg, e.stats.sum_hess,
|
||||||
|
e.best.left_sum.GetHess(), e.best.right_sum.GetHess());
|
||||||
|
|
||||||
int left_id = (*p_tree)[nid].LeftChild();
|
int left_id = (*p_tree)[nid].LeftChild();
|
||||||
int right_id = (*p_tree)[nid].RightChild();
|
int right_id = (*p_tree)[nid].RightChild();
|
||||||
@ -410,7 +411,8 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
|
|||||||
spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate;
|
spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate;
|
||||||
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
|
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
|
||||||
e.best.DefaultLeft(), e.weight, left_leaf_weight,
|
e.best.DefaultLeft(), e.weight, left_leaf_weight,
|
||||||
right_leaf_weight, e.best.loss_chg, e.stats.sum_hess);
|
right_leaf_weight, e.best.loss_chg, e.stats.sum_hess,
|
||||||
|
e.best.left_sum.GetHess(), e.best.right_sum.GetHess());
|
||||||
|
|
||||||
this->ApplySplit({candidate}, gmat, column_matrix, hist_, p_tree);
|
this->ApplySplit({candidate}, gmat, column_matrix, hist_, p_tree);
|
||||||
|
|
||||||
|
|||||||
@ -289,7 +289,8 @@ class SketchMaker: public BaseMaker {
|
|||||||
p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
|
p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
|
||||||
best.DefaultLeft(), base_weight, left_leaf_weight,
|
best.DefaultLeft(), base_weight, left_leaf_weight,
|
||||||
right_leaf_weight, best.loss_chg,
|
right_leaf_weight, best.loss_chg,
|
||||||
node_stats_[nid].sum_hess);
|
node_stats_[nid].sum_hess,
|
||||||
|
best.left_sum.GetHess(), best.right_sum.GetHess());
|
||||||
} else {
|
} else {
|
||||||
(*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate);
|
(*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -42,13 +42,15 @@ TEST(Updater, Prune) {
|
|||||||
pruner->Configure(cfg);
|
pruner->Configure(cfg);
|
||||||
|
|
||||||
// loss_chg < min_split_loss;
|
// loss_chg < min_split_loss;
|
||||||
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 0.0f, 0.0f);
|
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 0.0f, 0.0f,
|
||||||
|
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||||
pruner->Update(&gpair, p_dmat.get(), trees);
|
pruner->Update(&gpair, p_dmat.get(), trees);
|
||||||
|
|
||||||
ASSERT_EQ(tree.NumExtraNodes(), 0);
|
ASSERT_EQ(tree.NumExtraNodes(), 0);
|
||||||
|
|
||||||
// loss_chg > min_split_loss;
|
// loss_chg > min_split_loss;
|
||||||
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 11.0f, 0.0f);
|
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 11.0f, 0.0f,
|
||||||
|
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||||
pruner->Update(&gpair, p_dmat.get(), trees);
|
pruner->Update(&gpair, p_dmat.get(), trees);
|
||||||
|
|
||||||
ASSERT_EQ(tree.NumExtraNodes(), 2);
|
ASSERT_EQ(tree.NumExtraNodes(), 2);
|
||||||
@ -63,10 +65,12 @@ TEST(Updater, Prune) {
|
|||||||
// loss_chg > min_split_loss
|
// loss_chg > min_split_loss
|
||||||
tree.ExpandNode(tree[0].LeftChild(),
|
tree.ExpandNode(tree[0].LeftChild(),
|
||||||
0, 0.5f, true, 0.3, 0.4, 0.5,
|
0, 0.5f, true, 0.3, 0.4, 0.5,
|
||||||
/*loss_chg=*/18.0f, 0.0f);
|
/*loss_chg=*/18.0f, 0.0f,
|
||||||
|
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||||
tree.ExpandNode(tree[0].RightChild(),
|
tree.ExpandNode(tree[0].RightChild(),
|
||||||
0, 0.5f, true, 0.3, 0.4, 0.5,
|
0, 0.5f, true, 0.3, 0.4, 0.5,
|
||||||
/*loss_chg=*/19.0f, 0.0f);
|
/*loss_chg=*/19.0f, 0.0f,
|
||||||
|
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||||
cfg.emplace_back(std::make_pair("max_depth", "1"));
|
cfg.emplace_back(std::make_pair("max_depth", "1"));
|
||||||
pruner->Configure(cfg);
|
pruner->Configure(cfg);
|
||||||
pruner->Update(&gpair, p_dmat.get(), trees);
|
pruner->Update(&gpair, p_dmat.get(), trees);
|
||||||
@ -75,7 +79,8 @@ TEST(Updater, Prune) {
|
|||||||
|
|
||||||
tree.ExpandNode(tree[0].LeftChild(),
|
tree.ExpandNode(tree[0].LeftChild(),
|
||||||
0, 0.5f, true, 0.3, 0.4, 0.5,
|
0, 0.5f, true, 0.3, 0.4, 0.5,
|
||||||
/*loss_chg=*/18.0f, 0.0f);
|
/*loss_chg=*/18.0f, 0.0f,
|
||||||
|
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||||
cfg.emplace_back(std::make_pair("min_split_loss", "0"));
|
cfg.emplace_back(std::make_pair("min_split_loss", "0"));
|
||||||
pruner->Configure(cfg);
|
pruner->Configure(cfg);
|
||||||
pruner->Update(&gpair, p_dmat.get(), trees);
|
pruner->Update(&gpair, p_dmat.get(), trees);
|
||||||
|
|||||||
@ -34,7 +34,8 @@ TEST(Updater, Refresh) {
|
|||||||
std::vector<RegTree*> trees {&tree};
|
std::vector<RegTree*> trees {&tree};
|
||||||
std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh", &lparam));
|
std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh", &lparam));
|
||||||
|
|
||||||
tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f);
|
tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f,
|
||||||
|
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||||
int cleft = tree[0].LeftChild();
|
int cleft = tree[0].LeftChild();
|
||||||
int cright = tree[0].RightChild();
|
int cright = tree[0].RightChild();
|
||||||
|
|
||||||
|
|||||||
@ -88,13 +88,13 @@ TEST(Tree, Load) {
|
|||||||
|
|
||||||
TEST(Tree, AllocateNode) {
|
TEST(Tree, AllocateNode) {
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
tree.ExpandNode(
|
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||||
0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||||
tree.CollapseToLeaf(0, 0);
|
tree.CollapseToLeaf(0, 0);
|
||||||
ASSERT_EQ(tree.NumExtraNodes(), 0);
|
ASSERT_EQ(tree.NumExtraNodes(), 0);
|
||||||
|
|
||||||
tree.ExpandNode(
|
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||||
0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||||
ASSERT_EQ(tree.NumExtraNodes(), 2);
|
ASSERT_EQ(tree.NumExtraNodes(), 2);
|
||||||
|
|
||||||
auto& nodes = tree.GetNodes();
|
auto& nodes = tree.GetNodes();
|
||||||
@ -107,18 +107,18 @@ RegTree ConstructTree() {
|
|||||||
RegTree tree;
|
RegTree tree;
|
||||||
tree.ExpandNode(
|
tree.ExpandNode(
|
||||||
/*nid=*/0, /*split_index=*/0, /*split_value=*/0.0f,
|
/*nid=*/0, /*split_index=*/0, /*split_value=*/0.0f,
|
||||||
/*default_left=*/true,
|
/*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
|
||||||
0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
/*right_sum=*/0.0f);
|
||||||
auto left = tree[0].LeftChild();
|
auto left = tree[0].LeftChild();
|
||||||
auto right = tree[0].RightChild();
|
auto right = tree[0].RightChild();
|
||||||
tree.ExpandNode(
|
tree.ExpandNode(
|
||||||
/*nid=*/left, /*split_index=*/1, /*split_value=*/1.0f,
|
/*nid=*/left, /*split_index=*/1, /*split_value=*/1.0f,
|
||||||
/*default_left=*/false,
|
/*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
|
||||||
0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
/*right_sum=*/0.0f);
|
||||||
tree.ExpandNode(
|
tree.ExpandNode(
|
||||||
/*nid=*/right, /*split_index=*/2, /*split_value=*/2.0f,
|
/*nid=*/right, /*split_index=*/2, /*split_value=*/2.0f,
|
||||||
/*default_left=*/false,
|
/*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
|
||||||
0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
/*right_sum=*/0.0f);
|
||||||
return tree;
|
return tree;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -222,7 +222,8 @@ TEST(Tree, DumpDot) {
|
|||||||
|
|
||||||
TEST(Tree, JsonIO) {
|
TEST(Tree, JsonIO) {
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||||
|
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||||
Json j_tree{Object()};
|
Json j_tree{Object()};
|
||||||
tree.SaveModel(&j_tree);
|
tree.SaveModel(&j_tree);
|
||||||
|
|
||||||
@ -246,8 +247,10 @@ TEST(Tree, JsonIO) {
|
|||||||
|
|
||||||
auto left = tree[0].LeftChild();
|
auto left = tree[0].LeftChild();
|
||||||
auto right = tree[0].RightChild();
|
auto right = tree[0].RightChild();
|
||||||
tree.ExpandNode(left, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
tree.ExpandNode(left, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||||
tree.ExpandNode(right, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||||
|
tree.ExpandNode(right, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||||
|
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||||
tree.SaveModel(&j_tree);
|
tree.SaveModel(&j_tree);
|
||||||
|
|
||||||
tree.ChangeToLeaf(1, 1.0f);
|
tree.ChangeToLeaf(1, 1.0f);
|
||||||
|
|||||||
59
tests/cpp/tree/test_tree_stat.cc
Normal file
59
tests/cpp/tree/test_tree_stat.cc
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
#include <xgboost/tree_updater.h>
|
||||||
|
#include <xgboost/tree_model.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "../helpers.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
class UpdaterTreeStatTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
std::shared_ptr<DMatrix> p_dmat_;
|
||||||
|
HostDeviceVector<GradientPair> gpairs_;
|
||||||
|
size_t constexpr static kRows = 10;
|
||||||
|
size_t constexpr static kCols = 10;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void SetUp() override {
|
||||||
|
p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatix(true);
|
||||||
|
auto g = GenerateRandomGradients(kRows);
|
||||||
|
gpairs_.Resize(kRows);
|
||||||
|
gpairs_.Copy(g);
|
||||||
|
}
|
||||||
|
|
||||||
|
void RunTest(std::string updater) {
|
||||||
|
auto tparam = CreateEmptyGenericParam(0);
|
||||||
|
auto up = std::unique_ptr<TreeUpdater>{
|
||||||
|
TreeUpdater::Create(updater, &tparam)};
|
||||||
|
up->Configure(Args{});
|
||||||
|
RegTree tree;
|
||||||
|
tree.param.num_feature = kCols;
|
||||||
|
up->Update(&gpairs_, p_dmat_.get(), {&tree});
|
||||||
|
|
||||||
|
tree.WalkTree([&tree](bst_node_t nidx) {
|
||||||
|
if (tree[nidx].IsLeaf()) {
|
||||||
|
// 1.0 is the default `min_child_weight`.
|
||||||
|
CHECK_GE(tree.Stat(nidx).sum_hess, 1.0);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
|
TEST_F(UpdaterTreeStatTest, GPUHist) {
|
||||||
|
this->RunTest("grow_gpu_hist");
|
||||||
|
}
|
||||||
|
#endif // defined(XGBOOST_USE_CUDA)
|
||||||
|
|
||||||
|
TEST_F(UpdaterTreeStatTest, Hist) {
|
||||||
|
this->RunTest("grow_quantile_histmaker");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(UpdaterTreeStatTest, Exact) {
|
||||||
|
this->RunTest("grow_colmaker");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(UpdaterTreeStatTest, Approx) {
|
||||||
|
this->RunTest("grow_histmaker");
|
||||||
|
}
|
||||||
|
} // namespace xgboost
|
||||||
Loading…
x
Reference in New Issue
Block a user