[Breaking] Remove num roots. (#5059)

Author: Jiaming Yuan
Date:   2019-12-05 21:58:43 +08:00 (committed by GitHub)
Commit: 64af1ecf86 (parent f3d8536702)
23 changed files with 87 additions and 189 deletions

======== changed file ========

@@ -48,11 +48,6 @@ class MetaInfo {
   uint64_t num_nonzero_{0};
   /*! \brief label of each instance */
   HostDeviceVector<bst_float> labels_;
-  /*!
-   * \brief specified root index of each instance,
-   * can be used for multi task setting
-   */
-  std::vector<bst_uint> root_index_;
   /*!
    * \brief the index of begin and end of a group
    * needed when the learning task is ranking.
@@ -76,14 +71,6 @@ class MetaInfo {
   inline bst_float GetWeight(size_t i) const {
     return weights_.Size() != 0 ? weights_.HostVector()[i] : 1.0f;
   }
-  /*!
-   * \brief Get the root index of i-th instance.
-   * \param i Instance index.
-   * \return The pre-defined root index of i-th instance.
-   */
-  inline unsigned GetRoot(size_t i) const {
-    return !root_index_.empty() ? root_index_[i] : 0U;
-  }
   /*! \brief get sorted indexes (argsort) of labels by absolute value (used by cox loss) */
   inline const std::vector<size_t>& LabelAbsSort() const {
     if (label_order_cache_.size() == labels_.Size()) {

======== changed file ========

@@ -88,13 +88,11 @@ class GradientBooster {
   * \param inst the instance you want to predict
   * \param out_preds output vector to hold the predictions
   * \param ntree_limit limit the number of trees used in prediction
-  * \param root_index the root index
   * \sa Predict
   */
  virtual void PredictInstance(const SparsePage::Inst& inst,
                               std::vector<bst_float>* out_preds,
-                              unsigned ntree_limit = 0,
-                              unsigned root_index = 0) = 0;
+                              unsigned ntree_limit = 0) = 0;
  /*!
   * \brief predict the leaf index of each tree, the output will be nsample * ntree vector
   * this is only valid in gbtree predictor
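
Migration note: call sites simply drop the trailing root argument. A minimal sketch against the new signature (`booster`, `page`, and `num_output_group` are hypothetical stand-ins for a configured GradientBooster, a SparsePage fetched from a DMatrix, and the model's output-group count — they are not part of this diff):

    // Hypothetical call site after this change.
    std::vector<bst_float> preds(num_output_group);  // one slot per output group
    booster->PredictInstance(page[0], &preds, /*ntree_limit=*/0);  // root_index removed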

======== changed file ========

@@ -98,7 +98,6 @@ class Predictor {
  /**
   * \fn virtual void Predictor::PredictInstance( const SparsePage::Inst&
   * inst, std::vector<bst_float>* out_preds, const gbm::GBTreeModel& model,
-  * unsigned ntree_limit = 0, unsigned root_index = 0) = 0;
   *
   * \brief online prediction function, predict score for one instance at a time
   * NOTE: use the batch prediction interface if possible, batch prediction is
@@ -109,14 +108,12 @@ class Predictor {
   * \param [in,out] out_preds The output preds.
   * \param model The model to predict from
   * \param ntree_limit (Optional) The ntree limit.
-  * \param root_index (Optional) Zero-based index of the root.
   */
  virtual void PredictInstance(const SparsePage::Inst& inst,
                               std::vector<bst_float>* out_preds,
                               const gbm::GBTreeModel& model,
-                              unsigned ntree_limit = 0,
-                              unsigned root_index = 0) = 0;
+                              unsigned ntree_limit = 0) = 0;
  /**
   * \fn virtual void Predictor::PredictLeaf(DMatrix* dmat,

======== changed file ========

@@ -29,8 +29,8 @@ struct PathElement;  // forward declaration
 /*! \brief meta parameters of the tree */
 struct TreeParam : public dmlc::Parameter<TreeParam> {
-  /*! \brief number of start root */
-  int num_roots;
+  /*! \brief (Deprecated) number of start root */
+  int deprecated_num_roots;
   /*! \brief total number of nodes */
   int num_nodes;
   /*!\brief number of deleted nodes */
@@ -52,14 +52,12 @@ struct TreeParam : public dmlc::Parameter<TreeParam> {
     static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int),
                   "TreeParam: 64 bit align");
     std::memset(this, 0, sizeof(TreeParam));
-    num_nodes = num_roots = 1;
+    num_nodes = 1;
   }
   // declare the parameters
   DMLC_DECLARE_PARAMETER(TreeParam) {
     // only declare the parameters that can be set by the user.
     // other arguments are set by the algorithm.
-    DMLC_DECLARE_FIELD(num_roots).set_lower_bound(1).set_default(1)
-        .describe("Number of start root of trees.");
     DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1);
     DMLC_DECLARE_FIELD(num_feature)
         .describe("Number of features used in tree construction.");
@@ -68,7 +66,7 @@ struct TreeParam : public dmlc::Parameter<TreeParam> {
   }
   bool operator==(const TreeParam& b) const {
-    return num_roots == b.num_roots && num_nodes == b.num_nodes &&
+    return num_nodes == b.num_nodes &&
           num_deleted == b.num_deleted && max_depth == b.max_depth &&
           num_feature == b.num_feature &&
           size_leaf_vector == b.size_leaf_vector;
@@ -269,7 +267,6 @@ class RegTree : public Model {
   /*! \brief constructor */
   RegTree() {
     param.num_nodes = 1;
-    param.num_roots = 1;
     param.num_deleted = 0;
     nodes_.resize(param.num_nodes);
     stats_.resize(param.num_nodes);
@@ -380,16 +377,12 @@ class RegTree : public Model {
    * \brief get maximum depth
    */
   int MaxDepth() {
-    int maxd = 0;
-    for (int i = 0; i < param.num_roots; ++i) {
-      maxd = std::max(maxd, MaxDepth(i));
-    }
-    return maxd;
+    return MaxDepth(0);
   }
   /*! \brief number of extra nodes besides the root */
   int NumExtraNodes() const {
-    return param.num_nodes - param.num_roots - param.num_deleted;
+    return param.num_nodes - 1 - param.num_deleted;
   }
   /*!
@@ -444,19 +437,17 @@ class RegTree : public Model {
   /*!
    * \brief get the leaf index
    * \param feat dense feature vector, if the feature is missing the field is set to NaN
-   * \param root_id starting root index of the instance
   * \return the leaf index of the given feature
   */
-  int GetLeafIndex(const FVec& feat, unsigned root_id = 0) const;
+  int GetLeafIndex(const FVec& feat) const;
   /*!
    * \brief calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree
    * \param feat dense feature vector, if the feature is missing the field is set to NaN
-   * \param root_id starting root index of the instance
   * \param out_contribs output vector to hold the contributions
   * \param condition fix one feature to either off (-1) on (1) or not fixed (0 default)
   * \param condition_feature the index of the feature to fix
   */
-  void CalculateContributions(const RegTree::FVec& feat, unsigned root_id,
+  void CalculateContributions(const RegTree::FVec& feat,
                               bst_float* out_contribs, int condition = 0,
                               unsigned condition_feature = 0) const;
  /*!
@@ -482,10 +473,9 @@ class RegTree : public Model {
   /*!
    * \brief calculate the approximate feature contributions for the given root
    * \param feat dense feature vector, if the feature is missing the field is set to NaN
-   * \param root_id starting root index of the instance
   * \param out_contribs output vector to hold the contributions
   */
-  void CalculateContributionsApprox(const RegTree::FVec& feat, unsigned root_id,
+  void CalculateContributionsApprox(const RegTree::FVec& feat,
                                     bst_float* out_contribs) const;
  /*!
   * \brief get next position of the tree given current pid
@@ -536,7 +526,7 @@ class RegTree : public Model {
   }
   // delete a tree node, keep the parent field to allow trace back
   void DeleteNode(int nid) {
-    CHECK_GE(nid, param.num_roots);
+    CHECK_GE(nid, 1);
     deleted_nodes_.push_back(nid);
     nodes_[nid].MarkDelete();
     ++param.num_deleted;
@@ -576,14 +566,13 @@ inline bool RegTree::FVec::IsMissing(size_t i) const {
   return data_[i].flag == -1;
 }
-inline int RegTree::GetLeafIndex(const RegTree::FVec& feat,
-                                 unsigned root_id) const {
-  auto pid = static_cast<int>(root_id);
-  while (!(*this)[pid].IsLeaf()) {
-    unsigned split_index = (*this)[pid].SplitIndex();
-    pid = this->GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index));
+inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
+  bst_node_t nid = 0;
+  while (!(*this)[nid].IsLeaf()) {
+    unsigned split_index = (*this)[nid].SplitIndex();
+    nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
   }
-  return pid;
+  return nid;
 }
 /*! \brief get next position of the tree given current pid */
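
The single-instance RegTree entry points follow the same pattern; a usage sketch under the new signatures (assuming `tree` is a trained RegTree and `feat` a RegTree::FVec already filled with the instance — both set up elsewhere):

    // Sketch only: `tree` and `feat` are assumptions, not part of this diff.
    std::vector<bst_float> contribs(feat.Size() + 1, 0.0f);  // one slot per feature plus bias
    tree.CalculateContributions(feat, contribs.data());      // implicitly starts at node 0
    int leaf = tree.GetLeafIndex(feat);                      // traversal begins at the lone root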

======== changed file ========

@@ -329,9 +329,6 @@ XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle,
     if (src_base_margin.size() != 0) {
       ret_base_margin.push_back(src_base_margin[ridx]);
     }
-    if (src.info.root_index_.size() != 0) {
-      ret.info.root_index_.push_back(src.info.root_index_[ridx]);
-    }
   }
   *out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
   API_END();
@@ -426,9 +423,7 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
   CHECK_HANDLE();
   const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info();
   const std::vector<unsigned>* vec = nullptr;
-  if (!std::strcmp(field, "root_index")) {
-    vec = &info.root_index_;
-  } else if (!std::strcmp(field, "group_ptr")) {
+  if (!std::strcmp(field, "group_ptr")) {
     vec = &info.group_ptr_;
   } else {
     LOG(FATAL) << "Unknown comp uint field name " << field

======== changed file ========

@@ -34,7 +34,6 @@ namespace xgboost {
 void MetaInfo::Clear() {
   num_row_ = num_col_ = num_nonzero_ = 0;
   labels_.HostVector().clear();
-  root_index_.clear();
   group_ptr_.clear();
   weights_.HostVector().clear();
   base_margin_.HostVector().clear();
@@ -48,7 +47,6 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
   fo->Write(labels_.HostVector());
   fo->Write(group_ptr_);
   fo->Write(weights_.HostVector());
-  fo->Write(root_index_);
   fo->Write(base_margin_.HostVector());
 }
@@ -69,7 +67,6 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
   CHECK(fi->Read(&group_ptr_)) << "MetaInfo: invalid format";
   CHECK(fi->Read(&weights_.HostVector())) << "MetaInfo: invalid format";
-  CHECK(fi->Read(&root_index_)) << "MetaInfo: invalid format";
   CHECK(fi->Read(&base_margin_.HostVector())) << "MetaInfo: invalid format";
 }
@@ -121,11 +118,7 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname,
 } \
 void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
-  if (!std::strcmp(key, "root_index")) {
-    root_index_.resize(num);
-    DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
-                       std::copy(cast_dptr, cast_dptr + num, root_index_.begin()));
-  } else if (!std::strcmp(key, "label")) {
+  if (!std::strcmp(key, "label")) {
     auto& labels = labels_.HostVector();
     labels.resize(num);
     DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
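
Note that the serialized MetaInfo layout changes here: SaveBinary no longer writes the root_index_ vector. For a DMatrix without root indices that vector was empty, so a saved file shrinks only by the vector's length prefix, presumably a single uint64 count (8 bytes) in the dmlc stream format; 92 - 84 = 8, which matches the updated size assertion in the C++ test near the end of this diff.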

======== changed file ========

@@ -117,8 +117,7 @@ class GBLinear : public GradientBooster {
   // add base margin
   void PredictInstance(const SparsePage::Inst &inst,
                        std::vector<bst_float> *out_preds,
-                       unsigned ntree_limit,
-                       unsigned root_index) override {
+                       unsigned ntree_limit) override {
     const int ngroup = model_.param.num_output_group;
     for (int gid = 0; gid < ngroup; ++gid) {
       this->Pred(inst, dmlc::BeginPtr(*out_preds), gid, base_margin_);

======== changed file ========

@@ -322,10 +322,8 @@ class Dart : public GBTree {
     PredLoopInternal<Dart>(p_fmat, &out_preds->HostVector(), 0, ntree_limit, true);
   }
-  void PredictInstance(const SparsePage::Inst& inst,
-                       std::vector<bst_float>* out_preds,
-                       unsigned ntree_limit,
-                       unsigned root_index) override {
+  void PredictInstance(const SparsePage::Inst &inst,
+                       std::vector<bst_float> *out_preds, unsigned ntree_limit) override {
     DropTrees(1);
     if (thread_temp_.size() == 0) {
       thread_temp_.resize(1, RegTree::FVec());
@@ -338,9 +336,9 @@ class Dart : public GBTree {
     }
     // loop over output groups
     for (int gid = 0; gid < model_.param.num_output_group; ++gid) {
-      (*out_preds)[gid]
-          = PredValue(inst, gid, root_index,
-                      &thread_temp_[0], 0, ntree_limit) + model_.base_margin;
+      (*out_preds)[gid] =
+          PredValue(inst, gid, &thread_temp_[0], 0, ntree_limit) +
+          model_.base_margin;
     }
   }
@@ -411,7 +409,6 @@ class Dart : public GBTree {
                        int num_group,
                        unsigned tree_begin,
                        unsigned tree_end) {
-    const MetaInfo& info = p_fmat->Info();
     const int nthread = omp_get_max_threads();
     CHECK_EQ(num_group, model_.param.num_output_group);
     InitThreadTemp(nthread);
@@ -442,8 +439,7 @@ class Dart : public GBTree {
         for (int gid = 0; gid < num_group; ++gid) {
           const size_t offset = ridx[k] * num_group + gid;
           preds[offset] +=
-              self->PredValue(inst[k], gid, info.GetRoot(ridx[k]),
-                              &feats, tree_begin, tree_end);
+              self->PredValue(inst[k], gid, &feats, tree_begin, tree_end);
         }
       }
     }
@@ -455,7 +451,7 @@ class Dart : public GBTree {
       for (int gid = 0; gid < num_group; ++gid) {
        const size_t offset = ridx * num_group + gid;
        preds[offset] +=
-            self->PredValue(inst, gid, info.GetRoot(ridx),
+            self->PredValue(inst, gid,
                             &feats, tree_begin, tree_end);
      }
    }
@@ -478,7 +474,6 @@ class Dart : public GBTree {
   // predict the leaf scores without dropped trees
   inline bst_float PredValue(const SparsePage::Inst &inst,
                              int bst_group,
-                             unsigned root_index,
                              RegTree::FVec *p_feats,
                              unsigned tree_begin,
                              unsigned tree_end) {
@@ -488,7 +483,7 @@ class Dart : public GBTree {
       if (model_.tree_info[i] == bst_group) {
         bool drop = (std::binary_search(idx_drop_.begin(), idx_drop_.end(), i));
         if (!drop) {
-          int tid = model_.trees[i]->GetLeafIndex(*p_feats, root_index);
+          int tid = model_.trees[i]->GetLeafIndex(*p_feats);
           psum += weight_drop_[i] * (*model_.trees[i])[tid].LeafValue();
         }
       }

======== changed file ========

@@ -196,12 +196,11 @@ class GBTree : public GradientBooster {
   }
   void PredictInstance(const SparsePage::Inst& inst,
                        std::vector<bst_float>* out_preds,
-                       unsigned ntree_limit,
-                       unsigned root_index) override {
+                       unsigned ntree_limit) override {
     CHECK(configured_);
     cpu_predictor_->PredictInstance(inst, out_preds, model_,
-                                    ntree_limit, root_index);
+                                    ntree_limit);
   }
   void PredictLeaf(DMatrix* p_fmat,

======== changed file ========

@@ -18,8 +18,8 @@ namespace gbm {
 struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
   /*! \brief number of trees */
   int num_trees;
-  /*! \brief number of roots */
-  int num_roots;
+  /*! \brief (Deprecated) number of roots */
+  int deprecated_num_roots;
   /*! \brief number of features to be used by trees */
   int num_feature;
   /*! \brief pad this space, for backward compatibility reason.*/
@@ -50,8 +50,6 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
         .describe(
             "Number of output groups to be predicted,"
             " used for multi-class classification.");
-    DMLC_DECLARE_FIELD(num_roots).set_lower_bound(1).set_default(1).describe(
-        "Tree updater sequence.");
     DMLC_DECLARE_FIELD(num_feature)
         .set_lower_bound(0)
         .describe("Number of features used for training and prediction.");

======== changed file ========

@@ -19,13 +19,13 @@ class CPUPredictor : public Predictor {
   static bst_float PredValue(const SparsePage::Inst& inst,
                              const std::vector<std::unique_ptr<RegTree>>& trees,
                              const std::vector<int>& tree_info, int bst_group,
-                             unsigned root_index, RegTree::FVec* p_feats,
+                             RegTree::FVec* p_feats,
                              unsigned tree_begin, unsigned tree_end) {
     bst_float psum = 0.0f;
     p_feats->Fill(inst);
     for (size_t i = tree_begin; i < tree_end; ++i) {
       if (tree_info[i] == bst_group) {
-        int tid = trees[i]->GetLeafIndex(*p_feats, root_index);
+        int tid = trees[i]->GetLeafIndex(*p_feats);
         psum += (*trees[i])[tid].LeafValue();
       }
     }
@@ -47,7 +47,6 @@ class CPUPredictor : public Predictor {
                       std::vector<bst_float>* out_preds,
                       const gbm::GBTreeModel& model, int num_group,
                       unsigned tree_begin, unsigned tree_end) {
-    const MetaInfo& info = p_fmat->Info();
     const int nthread = omp_get_max_threads();
     InitThreadTemp(nthread, model.param.num_feature);
     std::vector<bst_float>& preds = *out_preds;
@@ -81,7 +80,7 @@ class CPUPredictor : public Predictor {
           const size_t offset = ridx[k] * num_group + gid;
           preds[offset] += this->PredValue(
               inst[k], model.trees, model.tree_info, gid,
-              info.GetRoot(ridx[k]), &feats, tree_begin, tree_end);
+              &feats, tree_begin, tree_end);
         }
       }
     }
@@ -94,7 +93,7 @@ class CPUPredictor : public Predictor {
         const size_t offset = ridx * num_group + gid;
         preds[offset] +=
             this->PredValue(inst, model.trees, model.tree_info, gid,
-                            info.GetRoot(ridx), &feats, tree_begin, tree_end);
+                            &feats, tree_begin, tree_end);
       }
     }
   }
@@ -204,8 +203,7 @@ class CPUPredictor : public Predictor {
   void PredictInstance(const SparsePage::Inst& inst,
                        std::vector<bst_float>* out_preds,
-                       const gbm::GBTreeModel& model, unsigned ntree_limit,
-                       unsigned root_index) override {
+                       const gbm::GBTreeModel& model, unsigned ntree_limit) override {
     if (thread_temp.size() == 0) {
       thread_temp.resize(1, RegTree::FVec());
       thread_temp[0].Init(model.param.num_feature);
@@ -219,7 +217,7 @@ class CPUPredictor : public Predictor {
     // loop over output groups
     for (int gid = 0; gid < model.param.num_output_group; ++gid) {
       (*out_preds)[gid] =
-          PredValue(inst, model.trees, model.tree_info, gid, root_index,
+          PredValue(inst, model.trees, model.tree_info, gid,
                     &thread_temp[0], 0, ntree_limit) +
           model.base_margin;
     }
@@ -247,7 +245,7 @@ class CPUPredictor : public Predictor {
       RegTree::FVec& feats = thread_temp[tid];
       feats.Fill(batch[i]);
       for (unsigned j = 0; j < ntree_limit; ++j) {
-        int tid = model.trees[j]->GetLeafIndex(feats, info.GetRoot(ridx));
+        int tid = model.trees[j]->GetLeafIndex(feats);
         preds[ridx * ntree_limit + j] = static_cast<bst_float>(tid);
       }
       feats.Drop(batch[i]);
@@ -270,7 +268,7 @@ class CPUPredictor : public Predictor {
       ntree_limit = static_cast<unsigned>(model.trees.size());
     }
     const int ngroup = model.param.num_output_group;
-    size_t ncolumns = model.param.num_feature + 1;
+    size_t const ncolumns = model.param.num_feature + 1;
     // allocate space for (number of features + bias) times the number of rows
     std::vector<bst_float>& contribs = *out_contribs;
     contribs.resize(info.num_row_ * ncolumns * model.param.num_output_group);
@@ -290,15 +288,13 @@ class CPUPredictor : public Predictor {
 #pragma omp parallel for schedule(static)
     for (bst_omp_uint i = 0; i < nsize; ++i) {
       auto row_idx = static_cast<size_t>(batch.base_rowid + i);
-      unsigned root_id = info.GetRoot(row_idx);
+      std::vector<bst_float> this_tree_contribs(ncolumns);
       RegTree::FVec& feats = thread_temp[omp_get_thread_num()];
       // loop over all classes
       for (int gid = 0; gid < ngroup; ++gid) {
         bst_float* p_contribs =
             &contribs[(row_idx * ngroup + gid) * ncolumns];
         feats.Fill(batch[i]);
-        std::vector<bst_float> this_tree_contribs;
-        this_tree_contribs.resize(ncolumns);
         // calculate contributions
         for (unsigned j = 0; j < ntree_limit; ++j) {
           std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0);
@@ -306,10 +302,10 @@ class CPUPredictor : public Predictor {
             continue;
           }
           if (!approximate) {
-            model.trees[j]->CalculateContributions(feats, root_id, &this_tree_contribs[0],
-                                                   condition, condition_feature);
+            model.trees[j]->CalculateContributions(feats, &this_tree_contribs[0],
+                                                   condition, condition_feature);
           } else {
-            model.trees[j]->CalculateContributionsApprox(feats, root_id, &this_tree_contribs[0]);
+            model.trees[j]->CalculateContributionsApprox(feats, &this_tree_contribs[0]);
           }
           for (int ci = 0 ; ci < ncolumns ; ++ci) {
             p_contribs[ci] += this_tree_contribs[ci] *
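
A small side effect of dropping root_id above: the per-tree scratch vector is now constructed once per row, before the output-group loop, rather than reallocated for every group. The pattern in isolation (a sketch; `ncolumns`, `ngroup`, and `ntree_limit` as in the surrounding code):

    // Construct the scratch buffer once, then reset it with std::fill in the hot loop.
    std::vector<bst_float> this_tree_contribs(ncolumns);
    for (int gid = 0; gid < ngroup; ++gid) {
      for (unsigned j = 0; j < ntree_limit; ++j) {
        std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0.0f);
        // ... accumulate the j-th tree's contributions for group gid ...
      }
    }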

======== changed file ========

@@ -381,8 +381,7 @@ class GPUPredictor : public xgboost::Predictor {
   void PredictInstance(const SparsePage::Inst& inst,
                        std::vector<bst_float>* out_preds,
-                       const gbm::GBTreeModel& model, unsigned ntree_limit,
-                       unsigned root_index) override {
+                       const gbm::GBTreeModel& model, unsigned ntree_limit) override {
     LOG(FATAL) << "Internal error: " << __func__
                << " is not implemented in GPU Predictor.";
   }

======== changed file ========

@@ -609,9 +609,7 @@ std::string RegTree::DumpModel(const FeatureMap& fmap,
   std::unique_ptr<TreeGenerator> builder {
     TreeGenerator::Create(format, fmap, with_stats)
   };
-  for (int32_t i = 0; i < param.num_roots; ++i) {
-    builder->BuildTree(*this);
-  }
+  builder->BuildTree(*this);
   std::string result = builder->Str();
   return result;
@@ -628,8 +626,10 @@ void RegTree::LoadModel(dmlc::Stream* fi) {
            sizeof(RTreeNodeStat) * stats_.size());
   // chg deleted nodes
   deleted_nodes_.resize(0);
-  for (int i = param.num_roots; i < param.num_nodes; ++i) {
-    if (nodes_[i].IsDeleted()) deleted_nodes_.push_back(i);
+  for (int i = 1; i < param.num_nodes; ++i) {
+    if (nodes_[i].IsDeleted()) {
+      deleted_nodes_.push_back(i);
+    }
   }
   CHECK_EQ(static_cast<int>(deleted_nodes_.size()), param.num_deleted);
@@ -652,9 +652,7 @@ void RegTree::FillNodeMeanValues() {
     return;
   }
   this->node_mean_values_.resize(num_nodes);
-  for (int root_id = 0; root_id < param.num_roots; ++root_id) {
-    this->FillNodeMeanValue(root_id);
-  }
+  this->FillNodeMeanValue(0);
 }
 bst_float RegTree::FillNodeMeanValue(int nid) {
@@ -672,28 +670,27 @@ bst_float RegTree::FillNodeMeanValue(int nid) {
 }
 void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
-                                           unsigned root_id,
                                            bst_float *out_contribs) const {
   CHECK_GT(this->node_mean_values_.size(), 0U);
   // this follows the idea of http://blog.datadive.net/interpreting-random-forests/
   unsigned split_index = 0;
-  auto pid = static_cast<int>(root_id);
   // update bias value
-  bst_float node_value = this->node_mean_values_[pid];
+  bst_float node_value = this->node_mean_values_[0];
   out_contribs[feat.Size()] += node_value;
-  if ((*this)[pid].IsLeaf()) {
+  if ((*this)[0].IsLeaf()) {
     // nothing to do anymore
     return;
   }
-  while (!(*this)[pid].IsLeaf()) {
-    split_index = (*this)[pid].SplitIndex();
-    pid = this->GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index));
-    bst_float new_value = this->node_mean_values_[pid];
+  bst_node_t nid = 0;
+  while (!(*this)[nid].IsLeaf()) {
+    split_index = (*this)[nid].SplitIndex();
+    nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
+    bst_float new_value = this->node_mean_values_[nid];
     // update feature weight
     out_contribs[split_index] += new_value - node_value;
     node_value = new_value;
   }
-  bst_float leaf_value = (*this)[pid].LeafValue();
+  bst_float leaf_value = (*this)[nid].LeafValue();
   // update leaf feature weight
   out_contribs[split_index] += leaf_value - node_value;
 }
@@ -868,21 +865,20 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
 }
 void RegTree::CalculateContributions(const RegTree::FVec &feat,
-                                     unsigned root_id, bst_float *out_contribs,
+                                     bst_float *out_contribs,
                                      int condition,
                                      unsigned condition_feature) const {
   // find the expected value of the tree's predictions
   if (condition == 0) {
-    bst_float node_value = this->node_mean_values_[static_cast<int>(root_id)];
+    bst_float node_value = this->node_mean_values_[0];
     out_contribs[feat.Size()] += node_value;
   }
   // Preallocate space for the unique path data
-  const int maxd = this->MaxDepth(root_id) + 2;
-  auto *unique_path_data = new PathElement[(maxd * (maxd + 1)) / 2];
-  TreeShap(feat, out_contribs, root_id, 0, unique_path_data,
+  const int maxd = this->MaxDepth(0) + 2;
+  std::vector<PathElement> unique_path_data((maxd * (maxd + 1)) / 2);
+  TreeShap(feat, out_contribs, 0, 0, unique_path_data.data(),
            1, 1, -1, condition, condition_feature, 1);
-  delete[] unique_path_data;
 }
 }  // namespace xgboost
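
Besides dropping root_id, CalculateContributions swaps the manually managed new PathElement[...] / delete[] pair for a std::vector<PathElement>, so the scratch path storage is released even if TreeShap throws; unique_path_data.data() hands TreeShap the same raw pointer the old code received.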

======== changed file ========

@@ -128,21 +128,10 @@ class BaseMaker: public TreeUpdater {
   inline void InitData(const std::vector<GradientPair> &gpair,
                        const DMatrix &fmat,
                        const RegTree &tree) {
-    CHECK_EQ(tree.param.num_nodes, tree.param.num_roots)
-        << "TreeMaker: can only grow new tree";
-    const std::vector<unsigned> &root_index = fmat.Info().root_index_;
     {
       // setup position
       position_.resize(gpair.size());
-      if (root_index.size() == 0) {
-        std::fill(position_.begin(), position_.end(), 0);
-      } else {
-        for (size_t i = 0; i < position_.size(); ++i) {
-          position_[i] = root_index[i];
-          CHECK_LT(root_index[i], (unsigned)tree.param.num_roots)
-              << "root index exceed setting";
-        }
-      }
+      std::fill(position_.begin(), position_.end(), 0);
       // mark delete for the deleted datas
       for (size_t i = 0; i < position_.size(); ++i) {
         if (gpair[i].GetHess() < 0.0f) position_[i] = ~position_[i];
@@ -160,9 +149,7 @@ class BaseMaker: public TreeUpdater {
     {
       // expand query
       qexpand_.reserve(256); qexpand_.clear();
-      for (int i = 0; i < tree.param.num_roots; ++i) {
-        qexpand_.push_back(i);
-      }
+      qexpand_.push_back(0);
       this->UpdateNode2WorkIndex(tree);
     }
     this->interaction_constraints_.Configure(param_, fmat.Info().num_col_);

======== changed file ========

@@ -146,21 +146,11 @@ class ColMaker: public TreeUpdater {
   inline void InitData(const std::vector<GradientPair>& gpair,
                        const DMatrix& fmat,
                        const RegTree& tree) {
-    CHECK_EQ(tree.param.num_nodes, tree.param.num_roots)
-        << "ColMaker: can only grow new tree";
-    const std::vector<unsigned>& root_index = fmat.Info().root_index_;
     {
       // setup position
       position_.resize(gpair.size());
       CHECK_EQ(fmat.Info().num_row_, position_.size());
-      if (root_index.size() == 0) {
-        std::fill(position_.begin(), position_.end(), 0);
-      } else {
-        for (size_t ridx = 0; ridx < position_.size(); ++ridx) {
-          position_[ridx] = root_index[ridx];
-          CHECK_LT(root_index[ridx], (unsigned)tree.param.num_roots);
-        }
-      }
+      std::fill(position_.begin(), position_.end(), 0);
      // mark delete for the deleted datas
      for (size_t ridx = 0; ridx < position_.size(); ++ridx) {
        if (gpair[ridx].GetHess() < 0.0f) position_[ridx] = ~position_[ridx];
@@ -192,9 +182,7 @@ class ColMaker: public TreeUpdater {
     {
       // expand query
       qexpand_.reserve(256); qexpand_.clear();
-      for (int i = 0; i < tree.param.num_roots; ++i) {
-        qexpand_.push_back(i);
-      }
+      qexpand_.push_back(0);
     }
   }
   /*!

======== changed file ========

@@ -119,10 +119,7 @@ class HistMaker: public BaseMaker {
     this->InitData(gpair, *p_fmat, *p_tree);
     this->InitWorkSet(p_fmat, *p_tree, &selected_features_);
     // mark root node as fresh.
-    for (int i = 0; i < p_tree->param.num_roots; ++i) {
-      (*p_tree)[i].SetLeaf(0.0f, 0);
-    }
-    CHECK_EQ(p_tree->param.num_roots, 1) << "Support for num roots is removed.";
+    (*p_tree)[0].SetLeaf(0.0f, 0);
     for (int depth = 0; depth < param_.max_depth; ++depth) {
       // reset and propose candidate split

======== changed file ========

@@ -75,7 +75,7 @@ class TreePruner: public TreeUpdater {
         npruned = this->TryPruneLeaf(tree, nid, tree.GetDepth(nid), npruned);
       }
     }
-    LOG(INFO) << "tree pruning end, " << tree.param.num_roots << " roots, "
+    LOG(INFO) << "tree pruning end, "
              << tree.NumExtraNodes() << " extra nodes, " << npruned
              << " pruned nodes, max_depth=" << tree.MaxDepth();
  }

======== changed file ========

@@ -255,18 +255,15 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
   unsigned timestamp = 0;
   int num_leaves = 0;
-  for (int nid = 0; nid < p_tree->param.num_roots; ++nid) {
-    hist_.AddHistRow(nid);
-    BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], true);
-    this->InitNewNode(nid, gmat, gpair_h, *p_fmat, *p_tree);
-    this->EvaluateSplit(nid, gmat, hist_, *p_fmat, *p_tree);
-    qexpand_loss_guided_->push(ExpandEntry(nid, p_tree->GetDepth(nid),
-                                           snode_[nid].best.loss_chg,
-                                           timestamp++));
-    ++num_leaves;
-  }
+  hist_.AddHistRow(0);
+  BuildHist(gpair_h, row_set_collection_[0], gmat, gmatb, hist_[0], true);
+  this->InitNewNode(0, gmat, gpair_h, *p_fmat, *p_tree);
+  this->EvaluateSplit(0, gmat, hist_, *p_fmat, *p_tree);
+  qexpand_loss_guided_->push(ExpandEntry(0, p_tree->GetDepth(0),
+                                         snode_[0].best.loss_chg, timestamp++));
+  ++num_leaves;
   while (!qexpand_loss_guided_->empty()) {
     const ExpandEntry candidate = qexpand_loss_guided_->top();
@@ -397,8 +394,6 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
                                           const std::vector<GradientPair>& gpair,
                                           const DMatrix& fmat,
                                           const RegTree& tree) {
-  CHECK_EQ(tree.param.num_nodes, tree.param.num_roots)
-      << "ColMakerHist: can only grow new tree";
   CHECK((param_.max_depth > 0 || param_.max_leaves > 0))
      << "max_depth or max_leaves cannot be both 0 (unlimited); "
      << "at least one should be a positive quantity.";
@@ -425,7 +420,6 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
   }
   hist_builder_.Init(this->nthread_, nbins);
-  CHECK_EQ(info.root_index_.size(), 0U);
   std::vector<size_t>& row_indices = row_set_collection_.row_indices_;
   row_indices.resize(info.num_row_);
   auto* p_row_indices = row_indices.data();

======== changed file ========

@@ -90,9 +90,7 @@ class TreeRefresher: public TreeUpdater {
     param_.learning_rate = lr / trees.size();
     int offset = 0;
     for (auto tree : trees) {
-      for (int rid = 0; rid < tree->param.num_roots; ++rid) {
-        this->Refresh(dmlc::BeginPtr(stemp[0]) + offset, rid, tree);
-      }
+      this->Refresh(dmlc::BeginPtr(stemp[0]) + offset, 0, tree);
       offset += tree->param.num_nodes;
     }
     // set learning rate back
@@ -107,7 +105,7 @@ class TreeRefresher: public TreeUpdater {
                      const bst_uint ridx,
                      GradStats *gstats) {
     // start from groups that belongs to current data
-    auto pid = static_cast<int>(info.GetRoot(ridx));
+    auto pid = 0;
     gstats[pid].Add(gpair[ridx]);
     // tranverse tree
     while (!tree[pid].IsLeaf()) {

======== changed file ========

@@ -12,10 +12,6 @@ TEST(MetaInfo, GetSet) {
   xgboost::MetaInfo info;
   double double2[2] = {1.0, 2.0};
-  EXPECT_EQ(info.GetRoot(1), 0)
-      << "When no root_index is given, was expecting default value 0";
-  info.SetInfo("root_index", double2, xgboost::kDouble, 2);
-  EXPECT_EQ(info.GetRoot(1), 2.0f);
   EXPECT_EQ(info.labels_.Size(), 0);
   info.SetInfo("label", double2, xgboost::kFloat32, 2);
@@ -58,7 +54,7 @@ TEST(MetaInfo, SaveLoadBinary) {
     info.SaveBinary(fs.get());
   }
-  ASSERT_EQ(GetFileSize(tmp_file), 92)
+  ASSERT_EQ(GetFileSize(tmp_file), 84)
      << "Expected saved binary file size to be same as object size";
  std::unique_ptr<dmlc::Stream> fs {

======== changed file ========

@@ -231,7 +231,6 @@ class TestBasic(unittest.TestCase):
         assert output == solution
-
 class TestBasicPathLike(unittest.TestCase):
     """Unit tests using the os_fspath and pathlib.Path for file interaction."""

======== changed file ========

@@ -63,11 +63,11 @@ class TestDMatrix(unittest.TestCase):
         # Sliced UInt array
         z = np.array([12, 34, 56], np.uint32)[::2]
         dmat = xgb.DMatrix(np.array([[]]))
-        dmat.set_uint_info('root_index', z)
-        from_view = dmat.get_uint_info('root_index')
+        dmat.set_uint_info('group', z)
+        from_view = dmat.get_uint_info('group_ptr')
         dmat = xgb.DMatrix(np.array([[]]))
-        dmat.set_uint_info('root_index', z + 0)
-        from_array = dmat.get_uint_info('root_index')
+        dmat.set_uint_info('group', z + 0)
+        from_array = dmat.get_uint_info('group_ptr')
         assert (from_view.shape == from_array.shape)
         assert (from_view == from_array).all()
@@ -142,7 +142,7 @@ class TestDMatrix(unittest.TestCase):
         dtrain.get_float_info('label')
         dtrain.get_float_info('weight')
         dtrain.get_float_info('base_margin')
-        dtrain.get_uint_info('root_index')
+        dtrain.get_uint_info('group_ptr')

    def test_sparse_dmatrix_csr(self):
        nrow = 100

======== changed file ========

@@ -161,8 +161,6 @@ class TestRanking(unittest.TestCase):
         """
         Retrieve the group number from the dmatrix
         """
-        # control that should work
-        self.dtrain.get_uint_info('root_index')
         # test the new getter
         self.dtrain.get_uint_info('group_ptr')