Fix pruner. (#5335)

* Honor the tree depth.
* Prevent pruning pruned node.
This commit is contained in:
Jiaming Yuan
2020-02-25 08:32:46 +08:00
committed by GitHub
parent b0ed3f0a66
commit e0509b3307
5 changed files with 99 additions and 34 deletions

View File

@@ -99,9 +99,10 @@ struct RTreeNodeStat {
*/
class RegTree : public Model {
public:
/*! \brief auxiliary statistics of node to help tree building */
using SplitCondT = bst_float;
static constexpr int32_t kInvalidNodeId {-1};
static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
/*! \brief tree node */
class Node {
public:
@@ -158,7 +159,7 @@ class RegTree : public Model {
}
/*! \brief whether this node is deleted */
XGBOOST_DEVICE bool IsDeleted() const {
return sindex_ == std::numeric_limits<uint32_t>::max();
return sindex_ == kDeletedNodeMarker;
}
/*! \brief whether current node is root */
XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
@@ -201,7 +202,7 @@ class RegTree : public Model {
}
/*! \brief mark that this node is deleted */
XGBOOST_DEVICE void MarkDelete() {
this->sindex_ = std::numeric_limits<unsigned>::max();
this->sindex_ = kDeletedNodeMarker;
}
/*! \brief Reuse this deleted node. */
XGBOOST_DEVICE void Reuse() {
@@ -534,6 +535,13 @@ class RegTree : public Model {
// delete a tree node, keep the parent field to allow trace back
void DeleteNode(int nid) {
CHECK_GE(nid, 1);
auto pid = (*this)[nid].Parent();
if (nid == (*this)[pid].LeftChild()) {
(*this)[pid].SetLeftChild(kInvalidNodeId);
} else {
(*this)[pid].SetRightChild(kInvalidNodeId);
}
deleted_nodes_.push_back(nid);
nodes_[nid].MarkDelete();
++param.num_deleted;
@@ -548,16 +556,20 @@ inline void RegTree::FVec::Init(size_t size) {
}
inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
for (bst_uint i = 0; i < inst.size(); ++i) {
if (inst[i].index >= data_.size()) continue;
data_[inst[i].index].fvalue = inst[i].fvalue;
for (auto const& entry : inst) {
if (entry.index >= data_.size()) {
continue;
}
data_[entry.index].fvalue = entry.fvalue;
}
}
inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
for (bst_uint i = 0; i < inst.size(); ++i) {
if (inst[i].index >= data_.size()) continue;
data_[inst[i].index].flag = -1;
for (auto const& entry : inst) {
if (entry.index >= data_.size()) {
continue;
}
data_[entry.index].flag = -1;
}
}