Fix pruner. (#5335)
* Honor the tree depth. * Prevent pruning pruned node.
This commit is contained in:
@@ -99,9 +99,10 @@ struct RTreeNodeStat {
|
||||
*/
|
||||
class RegTree : public Model {
|
||||
public:
|
||||
/*! \brief auxiliary statistics of node to help tree building */
|
||||
using SplitCondT = bst_float;
|
||||
static constexpr int32_t kInvalidNodeId {-1};
|
||||
static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
|
||||
|
||||
/*! \brief tree node */
|
||||
class Node {
|
||||
public:
|
||||
@@ -158,7 +159,7 @@ class RegTree : public Model {
|
||||
}
|
||||
/*! \brief whether this node is deleted */
|
||||
XGBOOST_DEVICE bool IsDeleted() const {
|
||||
return sindex_ == std::numeric_limits<uint32_t>::max();
|
||||
return sindex_ == kDeletedNodeMarker;
|
||||
}
|
||||
/*! \brief whether current node is root */
|
||||
XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
|
||||
@@ -201,7 +202,7 @@ class RegTree : public Model {
|
||||
}
|
||||
/*! \brief mark that this node is deleted */
|
||||
XGBOOST_DEVICE void MarkDelete() {
|
||||
this->sindex_ = std::numeric_limits<unsigned>::max();
|
||||
this->sindex_ = kDeletedNodeMarker;
|
||||
}
|
||||
/*! \brief Reuse this deleted node. */
|
||||
XGBOOST_DEVICE void Reuse() {
|
||||
@@ -534,6 +535,13 @@ class RegTree : public Model {
|
||||
// delete a tree node, keep the parent field to allow trace back
|
||||
void DeleteNode(int nid) {
|
||||
CHECK_GE(nid, 1);
|
||||
auto pid = (*this)[nid].Parent();
|
||||
if (nid == (*this)[pid].LeftChild()) {
|
||||
(*this)[pid].SetLeftChild(kInvalidNodeId);
|
||||
} else {
|
||||
(*this)[pid].SetRightChild(kInvalidNodeId);
|
||||
}
|
||||
|
||||
deleted_nodes_.push_back(nid);
|
||||
nodes_[nid].MarkDelete();
|
||||
++param.num_deleted;
|
||||
@@ -548,16 +556,20 @@ inline void RegTree::FVec::Init(size_t size) {
|
||||
}
|
||||
|
||||
inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
|
||||
for (bst_uint i = 0; i < inst.size(); ++i) {
|
||||
if (inst[i].index >= data_.size()) continue;
|
||||
data_[inst[i].index].fvalue = inst[i].fvalue;
|
||||
for (auto const& entry : inst) {
|
||||
if (entry.index >= data_.size()) {
|
||||
continue;
|
||||
}
|
||||
data_[entry.index].fvalue = entry.fvalue;
|
||||
}
|
||||
}
|
||||
|
||||
inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
|
||||
for (bst_uint i = 0; i < inst.size(); ++i) {
|
||||
if (inst[i].index >= data_.size()) continue;
|
||||
data_[inst[i].index].flag = -1;
|
||||
for (auto const& entry : inst) {
|
||||
if (entry.index >= data_.size()) {
|
||||
continue;
|
||||
}
|
||||
data_[entry.index].flag = -1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user