From 64af1ecf863be0fdfd52229d610852c99f9845ac Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 5 Dec 2019 21:58:43 +0800 Subject: [PATCH] [Breaking] Remove num roots. (#5059) --- include/xgboost/data.h | 13 ---------- include/xgboost/gbm.h | 6 ++--- include/xgboost/predictor.h | 5 +--- include/xgboost/tree_model.h | 43 ++++++++++++------------------- src/c_api/c_api.cc | 7 +---- src/data/data.cc | 9 +------ src/gbm/gblinear.cc | 3 +-- src/gbm/gbtree.cc | 21 ++++++--------- src/gbm/gbtree.h | 7 +++-- src/gbm/gbtree_model.h | 6 ++--- src/predictor/cpu_predictor.cc | 26 ++++++++----------- src/predictor/gpu_predictor.cu | 3 +-- src/tree/tree_model.cc | 42 ++++++++++++++---------------- src/tree/updater_basemaker-inl.h | 17 ++---------- src/tree/updater_colmaker.cc | 16 ++---------- src/tree/updater_histmaker.cc | 5 +--- src/tree/updater_prune.cc | 2 +- src/tree/updater_quantile_hist.cc | 20 +++++--------- src/tree/updater_refresh.cc | 6 ++--- tests/cpp/data/test_metainfo.cc | 6 +---- tests/python/test_basic.py | 1 - tests/python/test_dmatrix.py | 10 +++---- tests/python/test_ranking.py | 2 -- 23 files changed, 87 insertions(+), 189 deletions(-) diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 23318a685..ba4a73f70 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -48,11 +48,6 @@ class MetaInfo { uint64_t num_nonzero_{0}; /*! \brief label of each instance */ HostDeviceVector labels_; - /*! - * \brief specified root index of each instance, - * can be used for multi task setting - */ - std::vector root_index_; /*! * \brief the index of begin and end of a group * needed when the learning task is ranking. @@ -76,14 +71,6 @@ class MetaInfo { inline bst_float GetWeight(size_t i) const { return weights_.Size() != 0 ? weights_.HostVector()[i] : 1.0f; } - /*! - * \brief Get the root index of i-th instance. - * \param i Instance index. - * \return The pre-defined root index of i-th instance. - */ - inline unsigned GetRoot(size_t i) const { - return !root_index_.empty() ? root_index_[i] : 0U; - } /*! \brief get sorted indexes (argsort) of labels by absolute value (used by cox loss) */ inline const std::vector& LabelAbsSort() const { if (label_order_cache_.size() == labels_.Size()) { diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h index 5d5b8eb74..b7ba637cf 100644 --- a/include/xgboost/gbm.h +++ b/include/xgboost/gbm.h @@ -88,13 +88,11 @@ class GradientBooster { * \param inst the instance you want to predict * \param out_preds output vector to hold the predictions * \param ntree_limit limit the number of trees used in prediction - * \param root_index the root index * \sa Predict */ virtual void PredictInstance(const SparsePage::Inst& inst, - std::vector* out_preds, - unsigned ntree_limit = 0, - unsigned root_index = 0) = 0; + std::vector* out_preds, + unsigned ntree_limit = 0) = 0; /*! * \brief predict the leaf index of each tree, the output will be nsample * ntree vector * this is only valid in gbtree predictor diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index 85738a27f..8960a1acf 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -98,7 +98,6 @@ class Predictor { /** * \fn virtual void Predictor::PredictInstance( const SparsePage::Inst& * inst, std::vector* out_preds, const gbm::GBTreeModel& model, - * unsigned ntree_limit = 0, unsigned root_index = 0) = 0; * * \brief online prediction function, predict score for one instance at a time * NOTE: use the batch prediction interface if possible, batch prediction is @@ -109,14 +108,12 @@ class Predictor { * \param [in,out] out_preds The output preds. * \param model The model to predict from * \param ntree_limit (Optional) The ntree limit. - * \param root_index (Optional) Zero-based index of the root. */ virtual void PredictInstance(const SparsePage::Inst& inst, std::vector* out_preds, const gbm::GBTreeModel& model, - unsigned ntree_limit = 0, - unsigned root_index = 0) = 0; + unsigned ntree_limit = 0) = 0; /** * \fn virtual void Predictor::PredictLeaf(DMatrix* dmat, diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 35ab56079..256b40ccf 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -29,8 +29,8 @@ struct PathElement; // forward declaration /*! \brief meta parameters of the tree */ struct TreeParam : public dmlc::Parameter { - /*! \brief number of start root */ - int num_roots; + /*! \brief (Deprecated) number of start root */ + int deprecated_num_roots; /*! \brief total number of nodes */ int num_nodes; /*!\brief number of deleted nodes */ @@ -52,14 +52,12 @@ struct TreeParam : public dmlc::Parameter { static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int), "TreeParam: 64 bit align"); std::memset(this, 0, sizeof(TreeParam)); - num_nodes = num_roots = 1; + num_nodes = 1; } // declare the parameters DMLC_DECLARE_PARAMETER(TreeParam) { // only declare the parameters that can be set by the user. // other arguments are set by the algorithm. - DMLC_DECLARE_FIELD(num_roots).set_lower_bound(1).set_default(1) - .describe("Number of start root of trees."); DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1); DMLC_DECLARE_FIELD(num_feature) .describe("Number of features used in tree construction."); @@ -68,7 +66,7 @@ struct TreeParam : public dmlc::Parameter { } bool operator==(const TreeParam& b) const { - return num_roots == b.num_roots && num_nodes == b.num_nodes && + return num_nodes == b.num_nodes && num_deleted == b.num_deleted && max_depth == b.max_depth && num_feature == b.num_feature && size_leaf_vector == b.size_leaf_vector; @@ -269,7 +267,6 @@ class RegTree : public Model { /*! \brief constructor */ RegTree() { param.num_nodes = 1; - param.num_roots = 1; param.num_deleted = 0; nodes_.resize(param.num_nodes); stats_.resize(param.num_nodes); @@ -380,16 +377,12 @@ class RegTree : public Model { * \brief get maximum depth */ int MaxDepth() { - int maxd = 0; - for (int i = 0; i < param.num_roots; ++i) { - maxd = std::max(maxd, MaxDepth(i)); - } - return maxd; + return MaxDepth(0); } /*! \brief number of extra nodes besides the root */ int NumExtraNodes() const { - return param.num_nodes - param.num_roots - param.num_deleted; + return param.num_nodes - 1 - param.num_deleted; } /*! @@ -444,19 +437,17 @@ class RegTree : public Model { /*! * \brief get the leaf index * \param feat dense feature vector, if the feature is missing the field is set to NaN - * \param root_id starting root index of the instance * \return the leaf index of the given feature */ - int GetLeafIndex(const FVec& feat, unsigned root_id = 0) const; + int GetLeafIndex(const FVec& feat) const; /*! * \brief calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree * \param feat dense feature vector, if the feature is missing the field is set to NaN - * \param root_id starting root index of the instance * \param out_contribs output vector to hold the contributions * \param condition fix one feature to either off (-1) on (1) or not fixed (0 default) * \param condition_feature the index of the feature to fix */ - void CalculateContributions(const RegTree::FVec& feat, unsigned root_id, + void CalculateContributions(const RegTree::FVec& feat, bst_float* out_contribs, int condition = 0, unsigned condition_feature = 0) const; /*! @@ -482,10 +473,9 @@ class RegTree : public Model { /*! * \brief calculate the approximate feature contributions for the given root * \param feat dense feature vector, if the feature is missing the field is set to NaN - * \param root_id starting root index of the instance * \param out_contribs output vector to hold the contributions */ - void CalculateContributionsApprox(const RegTree::FVec& feat, unsigned root_id, + void CalculateContributionsApprox(const RegTree::FVec& feat, bst_float* out_contribs) const; /*! * \brief get next position of the tree given current pid @@ -536,7 +526,7 @@ class RegTree : public Model { } // delete a tree node, keep the parent field to allow trace back void DeleteNode(int nid) { - CHECK_GE(nid, param.num_roots); + CHECK_GE(nid, 1); deleted_nodes_.push_back(nid); nodes_[nid].MarkDelete(); ++param.num_deleted; @@ -576,14 +566,13 @@ inline bool RegTree::FVec::IsMissing(size_t i) const { return data_[i].flag == -1; } -inline int RegTree::GetLeafIndex(const RegTree::FVec& feat, - unsigned root_id) const { - auto pid = static_cast(root_id); - while (!(*this)[pid].IsLeaf()) { - unsigned split_index = (*this)[pid].SplitIndex(); - pid = this->GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index)); +inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const { + bst_node_t nid = 0; + while (!(*this)[nid].IsLeaf()) { + unsigned split_index = (*this)[nid].SplitIndex(); + nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index)); } - return pid; + return nid; } /*! \brief get next position of the tree given current pid */ diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 51da22bee..66280c6b8 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -329,9 +329,6 @@ XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle, if (src_base_margin.size() != 0) { ret_base_margin.push_back(src_base_margin[ridx]); } - if (src.info.root_index_.size() != 0) { - ret.info.root_index_.push_back(src.info.root_index_[ridx]); - } } *out = new std::shared_ptr(DMatrix::Create(std::move(source))); API_END(); @@ -426,9 +423,7 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle, CHECK_HANDLE(); const MetaInfo& info = static_cast*>(handle)->get()->Info(); const std::vector* vec = nullptr; - if (!std::strcmp(field, "root_index")) { - vec = &info.root_index_; - } else if (!std::strcmp(field, "group_ptr")) { + if (!std::strcmp(field, "group_ptr")) { vec = &info.group_ptr_; } else { LOG(FATAL) << "Unknown comp uint field name " << field diff --git a/src/data/data.cc b/src/data/data.cc index 03d3887c8..88e593cc3 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -34,7 +34,6 @@ namespace xgboost { void MetaInfo::Clear() { num_row_ = num_col_ = num_nonzero_ = 0; labels_.HostVector().clear(); - root_index_.clear(); group_ptr_.clear(); weights_.HostVector().clear(); base_margin_.HostVector().clear(); @@ -48,7 +47,6 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const { fo->Write(labels_.HostVector()); fo->Write(group_ptr_); fo->Write(weights_.HostVector()); - fo->Write(root_index_); fo->Write(base_margin_.HostVector()); } @@ -69,7 +67,6 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) { CHECK(fi->Read(&group_ptr_)) << "MetaInfo: invalid format"; CHECK(fi->Read(&weights_.HostVector())) << "MetaInfo: invalid format"; - CHECK(fi->Read(&root_index_)) << "MetaInfo: invalid format"; CHECK(fi->Read(&base_margin_.HostVector())) << "MetaInfo: invalid format"; } @@ -121,11 +118,7 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname, } \ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) { - if (!std::strcmp(key, "root_index")) { - root_index_.resize(num); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, root_index_.begin())); - } else if (!std::strcmp(key, "label")) { + if (!std::strcmp(key, "label")) { auto& labels = labels_.HostVector(); labels.resize(num); DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index a4e751a28..f1d2b3d94 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -117,8 +117,7 @@ class GBLinear : public GradientBooster { // add base margin void PredictInstance(const SparsePage::Inst &inst, std::vector *out_preds, - unsigned ntree_limit, - unsigned root_index) override { + unsigned ntree_limit) override { const int ngroup = model_.param.num_output_group; for (int gid = 0; gid < ngroup; ++gid) { this->Pred(inst, dmlc::BeginPtr(*out_preds), gid, base_margin_); diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 5c78b32e3..694d8cc2f 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -322,10 +322,8 @@ class Dart : public GBTree { PredLoopInternal(p_fmat, &out_preds->HostVector(), 0, ntree_limit, true); } - void PredictInstance(const SparsePage::Inst& inst, - std::vector* out_preds, - unsigned ntree_limit, - unsigned root_index) override { + void PredictInstance(const SparsePage::Inst &inst, + std::vector *out_preds, unsigned ntree_limit) override { DropTrees(1); if (thread_temp_.size() == 0) { thread_temp_.resize(1, RegTree::FVec()); @@ -338,9 +336,9 @@ class Dart : public GBTree { } // loop over output groups for (int gid = 0; gid < model_.param.num_output_group; ++gid) { - (*out_preds)[gid] - = PredValue(inst, gid, root_index, - &thread_temp_[0], 0, ntree_limit) + model_.base_margin; + (*out_preds)[gid] = + PredValue(inst, gid, &thread_temp_[0], 0, ntree_limit) + + model_.base_margin; } } @@ -411,7 +409,6 @@ class Dart : public GBTree { int num_group, unsigned tree_begin, unsigned tree_end) { - const MetaInfo& info = p_fmat->Info(); const int nthread = omp_get_max_threads(); CHECK_EQ(num_group, model_.param.num_output_group); InitThreadTemp(nthread); @@ -442,8 +439,7 @@ class Dart : public GBTree { for (int gid = 0; gid < num_group; ++gid) { const size_t offset = ridx[k] * num_group + gid; preds[offset] += - self->PredValue(inst[k], gid, info.GetRoot(ridx[k]), - &feats, tree_begin, tree_end); + self->PredValue(inst[k], gid, &feats, tree_begin, tree_end); } } } @@ -455,7 +451,7 @@ class Dart : public GBTree { for (int gid = 0; gid < num_group; ++gid) { const size_t offset = ridx * num_group + gid; preds[offset] += - self->PredValue(inst, gid, info.GetRoot(ridx), + self->PredValue(inst, gid, &feats, tree_begin, tree_end); } } @@ -478,7 +474,6 @@ class Dart : public GBTree { // predict the leaf scores without dropped trees inline bst_float PredValue(const SparsePage::Inst &inst, int bst_group, - unsigned root_index, RegTree::FVec *p_feats, unsigned tree_begin, unsigned tree_end) { @@ -488,7 +483,7 @@ class Dart : public GBTree { if (model_.tree_info[i] == bst_group) { bool drop = (std::binary_search(idx_drop_.begin(), idx_drop_.end(), i)); if (!drop) { - int tid = model_.trees[i]->GetLeafIndex(*p_feats, root_index); + int tid = model_.trees[i]->GetLeafIndex(*p_feats); psum += weight_drop_[i] * (*model_.trees[i])[tid].LeafValue(); } } diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 37eee01c1..5c9373cfd 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -196,12 +196,11 @@ class GBTree : public GradientBooster { } void PredictInstance(const SparsePage::Inst& inst, - std::vector* out_preds, - unsigned ntree_limit, - unsigned root_index) override { + std::vector* out_preds, + unsigned ntree_limit) override { CHECK(configured_); cpu_predictor_->PredictInstance(inst, out_preds, model_, - ntree_limit, root_index); + ntree_limit); } void PredictLeaf(DMatrix* p_fmat, diff --git a/src/gbm/gbtree_model.h b/src/gbm/gbtree_model.h index 26ac6cffa..1e34a756f 100644 --- a/src/gbm/gbtree_model.h +++ b/src/gbm/gbtree_model.h @@ -18,8 +18,8 @@ namespace gbm { struct GBTreeModelParam : public dmlc::Parameter { /*! \brief number of trees */ int num_trees; - /*! \brief number of roots */ - int num_roots; + /*! \brief (Deprecated) number of roots */ + int deprecated_num_roots; /*! \brief number of features to be used by trees */ int num_feature; /*! \brief pad this space, for backward compatibility reason.*/ @@ -50,8 +50,6 @@ struct GBTreeModelParam : public dmlc::Parameter { .describe( "Number of output groups to be predicted," " used for multi-class classification."); - DMLC_DECLARE_FIELD(num_roots).set_lower_bound(1).set_default(1).describe( - "Tree updater sequence."); DMLC_DECLARE_FIELD(num_feature) .set_lower_bound(0) .describe("Number of features used for training and prediction."); diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 7226e0145..4d04b1f4a 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -19,13 +19,13 @@ class CPUPredictor : public Predictor { static bst_float PredValue(const SparsePage::Inst& inst, const std::vector>& trees, const std::vector& tree_info, int bst_group, - unsigned root_index, RegTree::FVec* p_feats, + RegTree::FVec* p_feats, unsigned tree_begin, unsigned tree_end) { bst_float psum = 0.0f; p_feats->Fill(inst); for (size_t i = tree_begin; i < tree_end; ++i) { if (tree_info[i] == bst_group) { - int tid = trees[i]->GetLeafIndex(*p_feats, root_index); + int tid = trees[i]->GetLeafIndex(*p_feats); psum += (*trees[i])[tid].LeafValue(); } } @@ -47,7 +47,6 @@ class CPUPredictor : public Predictor { std::vector* out_preds, const gbm::GBTreeModel& model, int num_group, unsigned tree_begin, unsigned tree_end) { - const MetaInfo& info = p_fmat->Info(); const int nthread = omp_get_max_threads(); InitThreadTemp(nthread, model.param.num_feature); std::vector& preds = *out_preds; @@ -81,7 +80,7 @@ class CPUPredictor : public Predictor { const size_t offset = ridx[k] * num_group + gid; preds[offset] += this->PredValue( inst[k], model.trees, model.tree_info, gid, - info.GetRoot(ridx[k]), &feats, tree_begin, tree_end); + &feats, tree_begin, tree_end); } } } @@ -94,7 +93,7 @@ class CPUPredictor : public Predictor { const size_t offset = ridx * num_group + gid; preds[offset] += this->PredValue(inst, model.trees, model.tree_info, gid, - info.GetRoot(ridx), &feats, tree_begin, tree_end); + &feats, tree_begin, tree_end); } } } @@ -204,8 +203,7 @@ class CPUPredictor : public Predictor { void PredictInstance(const SparsePage::Inst& inst, std::vector* out_preds, - const gbm::GBTreeModel& model, unsigned ntree_limit, - unsigned root_index) override { + const gbm::GBTreeModel& model, unsigned ntree_limit) override { if (thread_temp.size() == 0) { thread_temp.resize(1, RegTree::FVec()); thread_temp[0].Init(model.param.num_feature); @@ -219,7 +217,7 @@ class CPUPredictor : public Predictor { // loop over output groups for (int gid = 0; gid < model.param.num_output_group; ++gid) { (*out_preds)[gid] = - PredValue(inst, model.trees, model.tree_info, gid, root_index, + PredValue(inst, model.trees, model.tree_info, gid, &thread_temp[0], 0, ntree_limit) + model.base_margin; } @@ -247,7 +245,7 @@ class CPUPredictor : public Predictor { RegTree::FVec& feats = thread_temp[tid]; feats.Fill(batch[i]); for (unsigned j = 0; j < ntree_limit; ++j) { - int tid = model.trees[j]->GetLeafIndex(feats, info.GetRoot(ridx)); + int tid = model.trees[j]->GetLeafIndex(feats); preds[ridx * ntree_limit + j] = static_cast(tid); } feats.Drop(batch[i]); @@ -270,7 +268,7 @@ class CPUPredictor : public Predictor { ntree_limit = static_cast(model.trees.size()); } const int ngroup = model.param.num_output_group; - size_t ncolumns = model.param.num_feature + 1; + size_t const ncolumns = model.param.num_feature + 1; // allocate space for (number of features + bias) times the number of rows std::vector& contribs = *out_contribs; contribs.resize(info.num_row_ * ncolumns * model.param.num_output_group); @@ -290,15 +288,13 @@ class CPUPredictor : public Predictor { #pragma omp parallel for schedule(static) for (bst_omp_uint i = 0; i < nsize; ++i) { auto row_idx = static_cast(batch.base_rowid + i); - unsigned root_id = info.GetRoot(row_idx); + std::vector this_tree_contribs(ncolumns); RegTree::FVec& feats = thread_temp[omp_get_thread_num()]; // loop over all classes for (int gid = 0; gid < ngroup; ++gid) { bst_float* p_contribs = &contribs[(row_idx * ngroup + gid) * ncolumns]; feats.Fill(batch[i]); - std::vector this_tree_contribs; - this_tree_contribs.resize(ncolumns); // calculate contributions for (unsigned j = 0; j < ntree_limit; ++j) { std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); @@ -306,10 +302,10 @@ class CPUPredictor : public Predictor { continue; } if (!approximate) { - model.trees[j]->CalculateContributions(feats, root_id, &this_tree_contribs[0], + model.trees[j]->CalculateContributions(feats, &this_tree_contribs[0], condition, condition_feature); } else { - model.trees[j]->CalculateContributionsApprox(feats, root_id, &this_tree_contribs[0]); + model.trees[j]->CalculateContributionsApprox(feats, &this_tree_contribs[0]); } for (int ci = 0 ; ci < ncolumns ; ++ci) { p_contribs[ci] += this_tree_contribs[ci] * diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index caf0f6904..e3371bfce 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -381,8 +381,7 @@ class GPUPredictor : public xgboost::Predictor { void PredictInstance(const SparsePage::Inst& inst, std::vector* out_preds, - const gbm::GBTreeModel& model, unsigned ntree_limit, - unsigned root_index) override { + const gbm::GBTreeModel& model, unsigned ntree_limit) override { LOG(FATAL) << "Internal error: " << __func__ << " is not implemented in GPU Predictor."; } diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 2b5e7226a..810906a2b 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -609,9 +609,7 @@ std::string RegTree::DumpModel(const FeatureMap& fmap, std::unique_ptr builder { TreeGenerator::Create(format, fmap, with_stats) }; - for (int32_t i = 0; i < param.num_roots; ++i) { - builder->BuildTree(*this); - } + builder->BuildTree(*this); std::string result = builder->Str(); return result; @@ -628,8 +626,10 @@ void RegTree::LoadModel(dmlc::Stream* fi) { sizeof(RTreeNodeStat) * stats_.size()); // chg deleted nodes deleted_nodes_.resize(0); - for (int i = param.num_roots; i < param.num_nodes; ++i) { - if (nodes_[i].IsDeleted()) deleted_nodes_.push_back(i); + for (int i = 1; i < param.num_nodes; ++i) { + if (nodes_[i].IsDeleted()) { + deleted_nodes_.push_back(i); + } } CHECK_EQ(static_cast(deleted_nodes_.size()), param.num_deleted); } @@ -652,9 +652,7 @@ void RegTree::FillNodeMeanValues() { return; } this->node_mean_values_.resize(num_nodes); - for (int root_id = 0; root_id < param.num_roots; ++root_id) { - this->FillNodeMeanValue(root_id); - } + this->FillNodeMeanValue(0); } bst_float RegTree::FillNodeMeanValue(int nid) { @@ -672,28 +670,27 @@ bst_float RegTree::FillNodeMeanValue(int nid) { } void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat, - unsigned root_id, bst_float *out_contribs) const { CHECK_GT(this->node_mean_values_.size(), 0U); // this follows the idea of http://blog.datadive.net/interpreting-random-forests/ unsigned split_index = 0; - auto pid = static_cast(root_id); // update bias value - bst_float node_value = this->node_mean_values_[pid]; + bst_float node_value = this->node_mean_values_[0]; out_contribs[feat.Size()] += node_value; - if ((*this)[pid].IsLeaf()) { + if ((*this)[0].IsLeaf()) { // nothing to do anymore return; } - while (!(*this)[pid].IsLeaf()) { - split_index = (*this)[pid].SplitIndex(); - pid = this->GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index)); - bst_float new_value = this->node_mean_values_[pid]; + bst_node_t nid = 0; + while (!(*this)[nid].IsLeaf()) { + split_index = (*this)[nid].SplitIndex(); + nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index)); + bst_float new_value = this->node_mean_values_[nid]; // update feature weight out_contribs[split_index] += new_value - node_value; node_value = new_value; } - bst_float leaf_value = (*this)[pid].LeafValue(); + bst_float leaf_value = (*this)[nid].LeafValue(); // update leaf feature weight out_contribs[split_index] += leaf_value - node_value; } @@ -868,21 +865,20 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi, } void RegTree::CalculateContributions(const RegTree::FVec &feat, - unsigned root_id, bst_float *out_contribs, + bst_float *out_contribs, int condition, unsigned condition_feature) const { // find the expected value of the tree's predictions if (condition == 0) { - bst_float node_value = this->node_mean_values_[static_cast(root_id)]; + bst_float node_value = this->node_mean_values_[0]; out_contribs[feat.Size()] += node_value; } // Preallocate space for the unique path data - const int maxd = this->MaxDepth(root_id) + 2; - auto *unique_path_data = new PathElement[(maxd * (maxd + 1)) / 2]; + const int maxd = this->MaxDepth(0) + 2; + std::vector unique_path_data((maxd * (maxd + 1)) / 2); - TreeShap(feat, out_contribs, root_id, 0, unique_path_data, + TreeShap(feat, out_contribs, 0, 0, unique_path_data.data(), 1, 1, -1, condition, condition_feature, 1); - delete[] unique_path_data; } } // namespace xgboost diff --git a/src/tree/updater_basemaker-inl.h b/src/tree/updater_basemaker-inl.h index 700f6e07a..0c7426242 100644 --- a/src/tree/updater_basemaker-inl.h +++ b/src/tree/updater_basemaker-inl.h @@ -128,21 +128,10 @@ class BaseMaker: public TreeUpdater { inline void InitData(const std::vector &gpair, const DMatrix &fmat, const RegTree &tree) { - CHECK_EQ(tree.param.num_nodes, tree.param.num_roots) - << "TreeMaker: can only grow new tree"; - const std::vector &root_index = fmat.Info().root_index_; { // setup position position_.resize(gpair.size()); - if (root_index.size() == 0) { - std::fill(position_.begin(), position_.end(), 0); - } else { - for (size_t i = 0; i < position_.size(); ++i) { - position_[i] = root_index[i]; - CHECK_LT(root_index[i], (unsigned)tree.param.num_roots) - << "root index exceed setting"; - } - } + std::fill(position_.begin(), position_.end(), 0); // mark delete for the deleted datas for (size_t i = 0; i < position_.size(); ++i) { if (gpair[i].GetHess() < 0.0f) position_[i] = ~position_[i]; @@ -160,9 +149,7 @@ class BaseMaker: public TreeUpdater { { // expand query qexpand_.reserve(256); qexpand_.clear(); - for (int i = 0; i < tree.param.num_roots; ++i) { - qexpand_.push_back(i); - } + qexpand_.push_back(0); this->UpdateNode2WorkIndex(tree); } this->interaction_constraints_.Configure(param_, fmat.Info().num_col_); diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index dbd6cc6e9..5f38d5ed6 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -146,21 +146,11 @@ class ColMaker: public TreeUpdater { inline void InitData(const std::vector& gpair, const DMatrix& fmat, const RegTree& tree) { - CHECK_EQ(tree.param.num_nodes, tree.param.num_roots) - << "ColMaker: can only grow new tree"; - const std::vector& root_index = fmat.Info().root_index_; { // setup position position_.resize(gpair.size()); CHECK_EQ(fmat.Info().num_row_, position_.size()); - if (root_index.size() == 0) { - std::fill(position_.begin(), position_.end(), 0); - } else { - for (size_t ridx = 0; ridx < position_.size(); ++ridx) { - position_[ridx] = root_index[ridx]; - CHECK_LT(root_index[ridx], (unsigned)tree.param.num_roots); - } - } + std::fill(position_.begin(), position_.end(), 0); // mark delete for the deleted datas for (size_t ridx = 0; ridx < position_.size(); ++ridx) { if (gpair[ridx].GetHess() < 0.0f) position_[ridx] = ~position_[ridx]; @@ -192,9 +182,7 @@ class ColMaker: public TreeUpdater { { // expand query qexpand_.reserve(256); qexpand_.clear(); - for (int i = 0; i < tree.param.num_roots; ++i) { - qexpand_.push_back(i); - } + qexpand_.push_back(0); } } /*! diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc index 246aeabe9..725634b9e 100644 --- a/src/tree/updater_histmaker.cc +++ b/src/tree/updater_histmaker.cc @@ -119,10 +119,7 @@ class HistMaker: public BaseMaker { this->InitData(gpair, *p_fmat, *p_tree); this->InitWorkSet(p_fmat, *p_tree, &selected_features_); // mark root node as fresh. - for (int i = 0; i < p_tree->param.num_roots; ++i) { - (*p_tree)[i].SetLeaf(0.0f, 0); - } - CHECK_EQ(p_tree->param.num_roots, 1) << "Support for num roots is removed."; + (*p_tree)[0].SetLeaf(0.0f, 0); for (int depth = 0; depth < param_.max_depth; ++depth) { // reset and propose candidate split diff --git a/src/tree/updater_prune.cc b/src/tree/updater_prune.cc index 38ca5bf96..386031e89 100644 --- a/src/tree/updater_prune.cc +++ b/src/tree/updater_prune.cc @@ -75,7 +75,7 @@ class TreePruner: public TreeUpdater { npruned = this->TryPruneLeaf(tree, nid, tree.GetDepth(nid), npruned); } } - LOG(INFO) << "tree pruning end, " << tree.param.num_roots << " roots, " + LOG(INFO) << "tree pruning end, " << tree.NumExtraNodes() << " extra nodes, " << npruned << " pruned nodes, max_depth=" << tree.MaxDepth(); } diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 1a457a2a0..4a7e4b75d 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -255,18 +255,15 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide( unsigned timestamp = 0; int num_leaves = 0; - for (int nid = 0; nid < p_tree->param.num_roots; ++nid) { - hist_.AddHistRow(nid); - BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], true); + hist_.AddHistRow(0); + BuildHist(gpair_h, row_set_collection_[0], gmat, gmatb, hist_[0], true); - this->InitNewNode(nid, gmat, gpair_h, *p_fmat, *p_tree); + this->InitNewNode(0, gmat, gpair_h, *p_fmat, *p_tree); - this->EvaluateSplit(nid, gmat, hist_, *p_fmat, *p_tree); - qexpand_loss_guided_->push(ExpandEntry(nid, p_tree->GetDepth(nid), - snode_[nid].best.loss_chg, - timestamp++)); - ++num_leaves; - } + this->EvaluateSplit(0, gmat, hist_, *p_fmat, *p_tree); + qexpand_loss_guided_->push(ExpandEntry(0, p_tree->GetDepth(0), + snode_[0].best.loss_chg, timestamp++)); + ++num_leaves; while (!qexpand_loss_guided_->empty()) { const ExpandEntry candidate = qexpand_loss_guided_->top(); @@ -397,8 +394,6 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, const std::vector& gpair, const DMatrix& fmat, const RegTree& tree) { - CHECK_EQ(tree.param.num_nodes, tree.param.num_roots) - << "ColMakerHist: can only grow new tree"; CHECK((param_.max_depth > 0 || param_.max_leaves > 0)) << "max_depth or max_leaves cannot be both 0 (unlimited); " << "at least one should be a positive quantity."; @@ -425,7 +420,6 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, } hist_builder_.Init(this->nthread_, nbins); - CHECK_EQ(info.root_index_.size(), 0U); std::vector& row_indices = row_set_collection_.row_indices_; row_indices.resize(info.num_row_); auto* p_row_indices = row_indices.data(); diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index 98ae902e4..2a0f773b8 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -90,9 +90,7 @@ class TreeRefresher: public TreeUpdater { param_.learning_rate = lr / trees.size(); int offset = 0; for (auto tree : trees) { - for (int rid = 0; rid < tree->param.num_roots; ++rid) { - this->Refresh(dmlc::BeginPtr(stemp[0]) + offset, rid, tree); - } + this->Refresh(dmlc::BeginPtr(stemp[0]) + offset, 0, tree); offset += tree->param.num_nodes; } // set learning rate back @@ -107,7 +105,7 @@ class TreeRefresher: public TreeUpdater { const bst_uint ridx, GradStats *gstats) { // start from groups that belongs to current data - auto pid = static_cast(info.GetRoot(ridx)); + auto pid = 0; gstats[pid].Add(gpair[ridx]); // tranverse tree while (!tree[pid].IsLeaf()) { diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index 9c32a0386..700dd87bc 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -12,10 +12,6 @@ TEST(MetaInfo, GetSet) { xgboost::MetaInfo info; double double2[2] = {1.0, 2.0}; - EXPECT_EQ(info.GetRoot(1), 0) - << "When no root_index is given, was expecting default value 0"; - info.SetInfo("root_index", double2, xgboost::kDouble, 2); - EXPECT_EQ(info.GetRoot(1), 2.0f); EXPECT_EQ(info.labels_.Size(), 0); info.SetInfo("label", double2, xgboost::kFloat32, 2); @@ -58,7 +54,7 @@ TEST(MetaInfo, SaveLoadBinary) { info.SaveBinary(fs.get()); } - ASSERT_EQ(GetFileSize(tmp_file), 92) + ASSERT_EQ(GetFileSize(tmp_file), 84) << "Expected saved binary file size to be same as object size"; std::unique_ptr fs { diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index 6dbd12799..d27c50d71 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -231,7 +231,6 @@ class TestBasic(unittest.TestCase): assert output == solution - class TestBasicPathLike(unittest.TestCase): """Unit tests using the os_fspath and pathlib.Path for file interaction.""" diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py index 5877e04fd..2becb674f 100644 --- a/tests/python/test_dmatrix.py +++ b/tests/python/test_dmatrix.py @@ -63,11 +63,11 @@ class TestDMatrix(unittest.TestCase): # Sliced UInt array z = np.array([12, 34, 56], np.uint32)[::2] dmat = xgb.DMatrix(np.array([[]])) - dmat.set_uint_info('root_index', z) - from_view = dmat.get_uint_info('root_index') + dmat.set_uint_info('group', z) + from_view = dmat.get_uint_info('group_ptr') dmat = xgb.DMatrix(np.array([[]])) - dmat.set_uint_info('root_index', z + 0) - from_array = dmat.get_uint_info('root_index') + dmat.set_uint_info('group', z + 0) + from_array = dmat.get_uint_info('group_ptr') assert (from_view.shape == from_array.shape) assert (from_view == from_array).all() @@ -142,7 +142,7 @@ class TestDMatrix(unittest.TestCase): dtrain.get_float_info('label') dtrain.get_float_info('weight') dtrain.get_float_info('base_margin') - dtrain.get_uint_info('root_index') + dtrain.get_uint_info('group_ptr') def test_sparse_dmatrix_csr(self): nrow = 100 diff --git a/tests/python/test_ranking.py b/tests/python/test_ranking.py index dc257cf23..51e5e18a9 100644 --- a/tests/python/test_ranking.py +++ b/tests/python/test_ranking.py @@ -161,8 +161,6 @@ class TestRanking(unittest.TestCase): """ Retrieve the group number from the dmatrix """ - # control that should work - self.dtrain.get_uint_info('root_index') # test the new getter self.dtrain.get_uint_info('group_ptr')