[Breaking] Remove num roots. (#5059)
This commit is contained in:
@@ -117,8 +117,7 @@ class GBLinear : public GradientBooster {
|
||||
// add base margin
|
||||
void PredictInstance(const SparsePage::Inst &inst,
|
||||
std::vector<bst_float> *out_preds,
|
||||
unsigned ntree_limit,
|
||||
unsigned root_index) override {
|
||||
unsigned ntree_limit) override {
|
||||
const int ngroup = model_.param.num_output_group;
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
this->Pred(inst, dmlc::BeginPtr(*out_preds), gid, base_margin_);
|
||||
|
||||
@@ -322,10 +322,8 @@ class Dart : public GBTree {
|
||||
PredLoopInternal<Dart>(p_fmat, &out_preds->HostVector(), 0, ntree_limit, true);
|
||||
}
|
||||
|
||||
void PredictInstance(const SparsePage::Inst& inst,
|
||||
std::vector<bst_float>* out_preds,
|
||||
unsigned ntree_limit,
|
||||
unsigned root_index) override {
|
||||
void PredictInstance(const SparsePage::Inst &inst,
|
||||
std::vector<bst_float> *out_preds, unsigned ntree_limit) override {
|
||||
DropTrees(1);
|
||||
if (thread_temp_.size() == 0) {
|
||||
thread_temp_.resize(1, RegTree::FVec());
|
||||
@@ -338,9 +336,9 @@ class Dart : public GBTree {
|
||||
}
|
||||
// loop over output groups
|
||||
for (int gid = 0; gid < model_.param.num_output_group; ++gid) {
|
||||
(*out_preds)[gid]
|
||||
= PredValue(inst, gid, root_index,
|
||||
&thread_temp_[0], 0, ntree_limit) + model_.base_margin;
|
||||
(*out_preds)[gid] =
|
||||
PredValue(inst, gid, &thread_temp_[0], 0, ntree_limit) +
|
||||
model_.base_margin;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -411,7 +409,6 @@ class Dart : public GBTree {
|
||||
int num_group,
|
||||
unsigned tree_begin,
|
||||
unsigned tree_end) {
|
||||
const MetaInfo& info = p_fmat->Info();
|
||||
const int nthread = omp_get_max_threads();
|
||||
CHECK_EQ(num_group, model_.param.num_output_group);
|
||||
InitThreadTemp(nthread);
|
||||
@@ -442,8 +439,7 @@ class Dart : public GBTree {
|
||||
for (int gid = 0; gid < num_group; ++gid) {
|
||||
const size_t offset = ridx[k] * num_group + gid;
|
||||
preds[offset] +=
|
||||
self->PredValue(inst[k], gid, info.GetRoot(ridx[k]),
|
||||
&feats, tree_begin, tree_end);
|
||||
self->PredValue(inst[k], gid, &feats, tree_begin, tree_end);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -455,7 +451,7 @@ class Dart : public GBTree {
|
||||
for (int gid = 0; gid < num_group; ++gid) {
|
||||
const size_t offset = ridx * num_group + gid;
|
||||
preds[offset] +=
|
||||
self->PredValue(inst, gid, info.GetRoot(ridx),
|
||||
self->PredValue(inst, gid,
|
||||
&feats, tree_begin, tree_end);
|
||||
}
|
||||
}
|
||||
@@ -478,7 +474,6 @@ class Dart : public GBTree {
|
||||
// predict the leaf scores without dropped trees
|
||||
inline bst_float PredValue(const SparsePage::Inst &inst,
|
||||
int bst_group,
|
||||
unsigned root_index,
|
||||
RegTree::FVec *p_feats,
|
||||
unsigned tree_begin,
|
||||
unsigned tree_end) {
|
||||
@@ -488,7 +483,7 @@ class Dart : public GBTree {
|
||||
if (model_.tree_info[i] == bst_group) {
|
||||
bool drop = (std::binary_search(idx_drop_.begin(), idx_drop_.end(), i));
|
||||
if (!drop) {
|
||||
int tid = model_.trees[i]->GetLeafIndex(*p_feats, root_index);
|
||||
int tid = model_.trees[i]->GetLeafIndex(*p_feats);
|
||||
psum += weight_drop_[i] * (*model_.trees[i])[tid].LeafValue();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -196,12 +196,11 @@ class GBTree : public GradientBooster {
|
||||
}
|
||||
|
||||
void PredictInstance(const SparsePage::Inst& inst,
|
||||
std::vector<bst_float>* out_preds,
|
||||
unsigned ntree_limit,
|
||||
unsigned root_index) override {
|
||||
std::vector<bst_float>* out_preds,
|
||||
unsigned ntree_limit) override {
|
||||
CHECK(configured_);
|
||||
cpu_predictor_->PredictInstance(inst, out_preds, model_,
|
||||
ntree_limit, root_index);
|
||||
ntree_limit);
|
||||
}
|
||||
|
||||
void PredictLeaf(DMatrix* p_fmat,
|
||||
|
||||
@@ -18,8 +18,8 @@ namespace gbm {
|
||||
struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
|
||||
/*! \brief number of trees */
|
||||
int num_trees;
|
||||
/*! \brief number of roots */
|
||||
int num_roots;
|
||||
/*! \brief (Deprecated) number of roots */
|
||||
int deprecated_num_roots;
|
||||
/*! \brief number of features to be used by trees */
|
||||
int num_feature;
|
||||
/*! \brief pad this space, for backward compatibility reason.*/
|
||||
@@ -50,8 +50,6 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
|
||||
.describe(
|
||||
"Number of output groups to be predicted,"
|
||||
" used for multi-class classification.");
|
||||
DMLC_DECLARE_FIELD(num_roots).set_lower_bound(1).set_default(1).describe(
|
||||
"Tree updater sequence.");
|
||||
DMLC_DECLARE_FIELD(num_feature)
|
||||
.set_lower_bound(0)
|
||||
.describe("Number of features used for training and prediction.");
|
||||
|
||||
Reference in New Issue
Block a user