Requires setting leaf stat when expanding tree. (#5501)
* Fix GPU Hist feature importance.
This commit is contained in:
@@ -22,6 +22,7 @@
|
||||
#include <cstring>
|
||||
#include <algorithm>
|
||||
#include <tuple>
|
||||
#include <stack>
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@@ -88,6 +89,10 @@ struct RTreeNodeStat {
|
||||
bst_float base_weight;
|
||||
/*! \brief number of child that is leaf node known up to now */
|
||||
int leaf_child_cnt {0};
|
||||
|
||||
RTreeNodeStat() = default;
|
||||
RTreeNodeStat(float loss_chg, float sum_hess, float weight) :
|
||||
loss_chg{loss_chg}, sum_hess{sum_hess}, base_weight{weight} {}
|
||||
bool operator==(const RTreeNodeStat& b) const {
|
||||
return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
|
||||
base_weight == b.base_weight && leaf_child_cnt == b.leaf_child_cnt;
|
||||
@@ -101,8 +106,9 @@ struct RTreeNodeStat {
|
||||
class RegTree : public Model {
|
||||
public:
|
||||
using SplitCondT = bst_float;
|
||||
static constexpr int32_t kInvalidNodeId {-1};
|
||||
static constexpr bst_node_t kInvalidNodeId {-1};
|
||||
static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
|
||||
static constexpr bst_node_t kRoot { 0 };
|
||||
|
||||
/*! \brief tree node */
|
||||
class Node {
|
||||
@@ -321,6 +327,31 @@ class RegTree : public Model {
|
||||
return nodes_ == b.nodes_ && stats_ == b.stats_ &&
|
||||
deleted_nodes_ == b.deleted_nodes_ && param == b.param;
|
||||
}
|
||||
/* \brief Iterate through all nodes in this tree.
|
||||
*
|
||||
* \param Function that accepts a node index, and returns false when iteration should
|
||||
* stop, otherwise returns true.
|
||||
*/
|
||||
template <typename Func> void WalkTree(Func func) const {
|
||||
std::stack<bst_node_t> nodes;
|
||||
nodes.push(kRoot);
|
||||
auto &self = *this;
|
||||
while (!nodes.empty()) {
|
||||
auto nidx = nodes.top();
|
||||
nodes.pop();
|
||||
if (!func(nidx)) {
|
||||
return;
|
||||
}
|
||||
auto left = self[nidx].LeftChild();
|
||||
auto right = self[nidx].RightChild();
|
||||
if (left != RegTree::kInvalidNodeId) {
|
||||
nodes.push(left);
|
||||
}
|
||||
if (right != RegTree::kInvalidNodeId) {
|
||||
nodes.push(right);
|
||||
}
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief Compares whether 2 trees are equal from a user's perspective. The equality
|
||||
* compares only non-deleted nodes.
|
||||
@@ -341,13 +372,16 @@ class RegTree : public Model {
|
||||
* \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
|
||||
* \param loss_change The loss change.
|
||||
* \param sum_hess The sum hess.
|
||||
* \param leaf_right_child The right child index of leaf, by default kInvalidNodeId,
|
||||
* some updaters use the right child index of leaf as a marker
|
||||
* \param left_sum The sum hess of left leaf.
|
||||
* \param right_sum The sum hess of right leaf.
|
||||
* \param leaf_right_child The right child index of leaf, by default kInvalidNodeId,
|
||||
* some updaters use the right child index of leaf as a marker
|
||||
*/
|
||||
void ExpandNode(int nid, unsigned split_index, bst_float split_value,
|
||||
bool default_left, bst_float base_weight,
|
||||
bst_float left_leaf_weight, bst_float right_leaf_weight,
|
||||
bst_float loss_change, float sum_hess,
|
||||
bst_float loss_change, float sum_hess, float left_sum,
|
||||
float right_sum,
|
||||
bst_node_t leaf_right_child = kInvalidNodeId) {
|
||||
int pleft = this->AllocNode();
|
||||
int pright = this->AllocNode();
|
||||
@@ -363,9 +397,9 @@ class RegTree : public Model {
|
||||
nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child);
|
||||
nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child);
|
||||
|
||||
this->Stat(nid).loss_chg = loss_change;
|
||||
this->Stat(nid).base_weight = base_weight;
|
||||
this->Stat(nid).sum_hess = sum_hess;
|
||||
this->Stat(nid) = {loss_change, sum_hess, base_weight};
|
||||
this->Stat(pleft) = {0.0f, left_sum, left_leaf_weight};
|
||||
this->Stat(pright) = {0.0f, right_sum, right_leaf_weight};
|
||||
}
|
||||
|
||||
/*!
|
||||
@@ -402,6 +436,10 @@ class RegTree : public Model {
|
||||
return param.num_nodes - 1 - param.num_deleted;
|
||||
}
|
||||
|
||||
/* \brief Count number of leaves in tree. */
|
||||
bst_node_t GetNumLeaves() const;
|
||||
bst_node_t GetNumSplitNodes() const;
|
||||
|
||||
/*!
|
||||
* \brief dense feature vector that can be taken by RegTree
|
||||
* and can be construct from sparse feature vector.
|
||||
|
||||
Reference in New Issue
Block a user