diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 31dc15093..69515e943 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -179,6 +179,10 @@ class RegTree { XGBOOST_DEVICE void MarkDelete() { this->sindex_ = std::numeric_limits::max(); } + /*! \brief Reuse this deleted node. */ + XGBOOST_DEVICE void Reuse() { + this->sindex_ = 0; + } // set parent XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child = true) { if (is_left_child) pidx |= (1U << 31); @@ -503,10 +507,11 @@ class RegTree { // !!!!!! NOTE: may cause BUG here, nodes.resize int AllocNode() { if (param.num_deleted != 0) { - int nd = deleted_nodes_.back(); + int nid = deleted_nodes_.back(); deleted_nodes_.pop_back(); + nodes_[nid].Reuse(); --param.num_deleted; - return nd; + return nid; } int nd = param.num_nodes++; CHECK_LT(param.num_nodes, std::numeric_limits::max()) diff --git a/tests/cpp/tree/test_tree_model.cc b/tests/cpp/tree/test_tree_model.cc index 21c2eeeda..da0e6bcb6 100644 --- a/tests/cpp/tree/test_tree_model.cc +++ b/tests/cpp/tree/test_tree_model.cc @@ -84,4 +84,21 @@ TEST(Tree, Load) { EXPECT_EQ(tree[1].LeafValue(), 0.1f); EXPECT_TRUE(tree[1].IsLeaf()); } + +TEST(Tree, AllocateNode) { + RegTree tree; + tree.ExpandNode( + 0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); + tree.CollapseToLeaf(0, 0); + ASSERT_EQ(tree.NumExtraNodes(), 0); + + tree.ExpandNode( + 0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); + ASSERT_EQ(tree.NumExtraNodes(), 2); + + auto& nodes = tree.GetNodes(); + ASSERT_FALSE(nodes.at(1).IsDeleted()); + ASSERT_TRUE(nodes.at(1).IsLeaf()); + ASSERT_TRUE(nodes.at(2).IsLeaf()); +} } // namespace xgboost