[Breaking] Remove num roots. (#5059)

This commit is contained in:
Jiaming Yuan
2019-12-05 21:58:43 +08:00
committed by GitHub
parent f3d8536702
commit 64af1ecf86
23 changed files with 87 additions and 189 deletions

View File

@@ -117,8 +117,7 @@ class GBLinear : public GradientBooster {
// add base margin
void PredictInstance(const SparsePage::Inst &inst,
std::vector<bst_float> *out_preds,
unsigned ntree_limit,
unsigned root_index) override {
unsigned ntree_limit) override {
const int ngroup = model_.param.num_output_group;
for (int gid = 0; gid < ngroup; ++gid) {
this->Pred(inst, dmlc::BeginPtr(*out_preds), gid, base_margin_);

View File

@@ -322,10 +322,8 @@ class Dart : public GBTree {
PredLoopInternal<Dart>(p_fmat, &out_preds->HostVector(), 0, ntree_limit, true);
}
void PredictInstance(const SparsePage::Inst& inst,
std::vector<bst_float>* out_preds,
unsigned ntree_limit,
unsigned root_index) override {
void PredictInstance(const SparsePage::Inst &inst,
std::vector<bst_float> *out_preds, unsigned ntree_limit) override {
DropTrees(1);
if (thread_temp_.size() == 0) {
thread_temp_.resize(1, RegTree::FVec());
@@ -338,9 +336,9 @@ class Dart : public GBTree {
}
// loop over output groups
for (int gid = 0; gid < model_.param.num_output_group; ++gid) {
(*out_preds)[gid]
= PredValue(inst, gid, root_index,
&thread_temp_[0], 0, ntree_limit) + model_.base_margin;
(*out_preds)[gid] =
PredValue(inst, gid, &thread_temp_[0], 0, ntree_limit) +
model_.base_margin;
}
}
@@ -411,7 +409,6 @@ class Dart : public GBTree {
int num_group,
unsigned tree_begin,
unsigned tree_end) {
const MetaInfo& info = p_fmat->Info();
const int nthread = omp_get_max_threads();
CHECK_EQ(num_group, model_.param.num_output_group);
InitThreadTemp(nthread);
@@ -442,8 +439,7 @@ class Dart : public GBTree {
for (int gid = 0; gid < num_group; ++gid) {
const size_t offset = ridx[k] * num_group + gid;
preds[offset] +=
self->PredValue(inst[k], gid, info.GetRoot(ridx[k]),
&feats, tree_begin, tree_end);
self->PredValue(inst[k], gid, &feats, tree_begin, tree_end);
}
}
}
@@ -455,7 +451,7 @@ class Dart : public GBTree {
for (int gid = 0; gid < num_group; ++gid) {
const size_t offset = ridx * num_group + gid;
preds[offset] +=
self->PredValue(inst, gid, info.GetRoot(ridx),
self->PredValue(inst, gid,
&feats, tree_begin, tree_end);
}
}
@@ -478,7 +474,6 @@ class Dart : public GBTree {
// predict the leaf scores without dropped trees
inline bst_float PredValue(const SparsePage::Inst &inst,
int bst_group,
unsigned root_index,
RegTree::FVec *p_feats,
unsigned tree_begin,
unsigned tree_end) {
@@ -488,7 +483,7 @@ class Dart : public GBTree {
if (model_.tree_info[i] == bst_group) {
bool drop = (std::binary_search(idx_drop_.begin(), idx_drop_.end(), i));
if (!drop) {
int tid = model_.trees[i]->GetLeafIndex(*p_feats, root_index);
int tid = model_.trees[i]->GetLeafIndex(*p_feats);
psum += weight_drop_[i] * (*model_.trees[i])[tid].LeafValue();
}
}

View File

@@ -196,12 +196,11 @@ class GBTree : public GradientBooster {
}
void PredictInstance(const SparsePage::Inst& inst,
std::vector<bst_float>* out_preds,
unsigned ntree_limit,
unsigned root_index) override {
std::vector<bst_float>* out_preds,
unsigned ntree_limit) override {
CHECK(configured_);
cpu_predictor_->PredictInstance(inst, out_preds, model_,
ntree_limit, root_index);
ntree_limit);
}
void PredictLeaf(DMatrix* p_fmat,

View File

@@ -18,8 +18,8 @@ namespace gbm {
struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
/*! \brief number of trees */
int num_trees;
/*! \brief number of roots */
int num_roots;
/*! \brief (Deprecated) number of roots */
int deprecated_num_roots;
/*! \brief number of features to be used by trees */
int num_feature;
/*! \brief pad this space, for backward compatibility reason.*/
@@ -50,8 +50,6 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
.describe(
"Number of output groups to be predicted,"
" used for multi-class classification.");
DMLC_DECLARE_FIELD(num_roots).set_lower_bound(1).set_default(1).describe(
"Tree updater sequence.");
DMLC_DECLARE_FIELD(num_feature)
.set_lower_bound(0)
.describe("Number of features used for training and prediction.");