[Breaking] Don't drop trees during DART prediction by default (#5115)
* Simplify `DropTrees` calling logic.
* Add a `training` parameter to the prediction method.
* [Breaking] Add `training` to the C API.
* Update the R and Python custom-objective paths accordingly.
* Correct a comment.

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
@@ -127,6 +127,7 @@ class GBLinear : public GradientBooster {
|
||||
|
||||
void PredictBatch(DMatrix *p_fmat,
|
||||
HostDeviceVector<bst_float> *out_preds,
|
||||
bool training,
|
||||
unsigned ntree_limit) override {
|
||||
monitor_.Start("PredictBatch");
|
||||
CHECK_EQ(ntree_limit, 0U)
|
||||
|
||||
@@ -414,17 +414,36 @@ class Dart : public GBTree {
|
||||
out["dart_train_param"] = toJson(dparam_);
|
||||
}
|
||||
|
||||
// predict the leaf scores with dropout if ntree_limit = 0
|
||||
void PredictBatch(DMatrix* p_fmat,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
HostDeviceVector<bst_float>* p_out_preds,
|
||||
bool training,
|
||||
unsigned ntree_limit) override {
|
||||
DropTrees(ntree_limit);
|
||||
PredLoopInternal<Dart>(p_fmat, &out_preds->HostVector(), 0, ntree_limit, true);
|
||||
DropTrees(training);
|
||||
int num_group = model_.learner_model_param_->num_output_group;
|
||||
ntree_limit *= num_group;
|
||||
if (ntree_limit == 0 || ntree_limit > model_.trees.size()) {
|
||||
ntree_limit = static_cast<unsigned>(model_.trees.size());
|
||||
}
|
||||
size_t n = num_group * p_fmat->Info().num_row_;
|
||||
const auto &base_margin = p_fmat->Info().base_margin_.ConstHostVector();
|
||||
auto& out_preds = p_out_preds->HostVector();
|
||||
out_preds.resize(n);
|
||||
if (base_margin.size() != 0) {
|
||||
CHECK_EQ(out_preds.size(), n);
|
||||
std::copy(base_margin.begin(), base_margin.end(), out_preds.begin());
|
||||
} else {
|
||||
std::fill(out_preds.begin(), out_preds.end(),
|
||||
model_.learner_model_param_->base_score);
|
||||
}
|
||||
|
||||
PredLoopSpecalize(p_fmat, &out_preds, num_group, 0,
|
||||
ntree_limit, training);
|
||||
}
|
||||
|
||||
void PredictInstance(const SparsePage::Inst &inst,
|
||||
std::vector<bst_float> *out_preds, unsigned ntree_limit) override {
|
||||
DropTrees(1);
|
||||
std::vector<bst_float> *out_preds,
|
||||
unsigned ntree_limit) override {
|
||||
DropTrees(false);
|
||||
if (thread_temp_.size() == 0) {
|
||||
thread_temp_.resize(1, RegTree::FVec());
|
||||
thread_temp_[0].Init(model_.learner_model_param_->num_feature);
|
||||
@@ -465,46 +484,13 @@ class Dart : public GBTree {
|
||||
|
||||
|
||||
protected:
|
||||
friend class GBTree;
|
||||
// internal prediction loop
|
||||
// add predictions to out_preds
|
||||
template<typename Derived>
|
||||
inline void PredLoopInternal(
|
||||
DMatrix* p_fmat,
|
||||
std::vector<bst_float>* out_preds,
|
||||
unsigned tree_begin,
|
||||
unsigned ntree_limit,
|
||||
bool init_out_preds) {
|
||||
int num_group = model_.learner_model_param_->num_output_group;
|
||||
ntree_limit *= num_group;
|
||||
if (ntree_limit == 0 || ntree_limit > model_.trees.size()) {
|
||||
ntree_limit = static_cast<unsigned>(model_.trees.size());
|
||||
}
|
||||
|
||||
if (init_out_preds) {
|
||||
size_t n = num_group * p_fmat->Info().num_row_;
|
||||
const auto& base_margin =
|
||||
p_fmat->Info().base_margin_.ConstHostVector();
|
||||
out_preds->resize(n);
|
||||
if (base_margin.size() != 0) {
|
||||
CHECK_EQ(out_preds->size(), n);
|
||||
std::copy(base_margin.begin(), base_margin.end(), out_preds->begin());
|
||||
} else {
|
||||
std::fill(out_preds->begin(), out_preds->end(),
|
||||
model_.learner_model_param_->base_score);
|
||||
}
|
||||
}
|
||||
PredLoopSpecalize<Derived>(p_fmat, out_preds, num_group, tree_begin,
|
||||
ntree_limit);
|
||||
}
|
||||
|
||||
template<typename Derived>
|
||||
inline void PredLoopSpecalize(
|
||||
DMatrix* p_fmat,
|
||||
std::vector<bst_float>* out_preds,
|
||||
int num_group,
|
||||
unsigned tree_begin,
|
||||
unsigned tree_end) {
|
||||
unsigned tree_end,
|
||||
bool training) {
|
||||
const int nthread = omp_get_max_threads();
|
||||
CHECK_EQ(num_group, model_.learner_model_param_->num_output_group);
|
||||
InitThreadTemp(nthread);
|
||||
@@ -513,13 +499,12 @@ class Dart : public GBTree {
|
||||
<< "size_leaf_vector is enforced to 0 so far";
|
||||
CHECK_EQ(preds.size(), p_fmat->Info().num_row_ * num_group);
|
||||
// start collecting the prediction
|
||||
auto* self = static_cast<Derived*>(this);
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
constexpr int kUnroll = 8;
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
const bst_omp_uint rest = nsize % kUnroll;
|
||||
if (nsize >= kUnroll) {
|
||||
#pragma omp parallel for schedule(static)
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
|
||||
const int tid = omp_get_thread_num();
|
||||
RegTree::FVec& feats = thread_temp_[tid];
|
||||
@@ -535,7 +520,7 @@ class Dart : public GBTree {
|
||||
for (int gid = 0; gid < num_group; ++gid) {
|
||||
const size_t offset = ridx[k] * num_group + gid;
|
||||
preds[offset] +=
|
||||
self->PredValue(inst[k], gid, &feats, tree_begin, tree_end);
|
||||
this->PredValue(inst[k], gid, &feats, tree_begin, tree_end);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -548,7 +533,7 @@ class Dart : public GBTree {
|
||||
for (int gid = 0; gid < num_group; ++gid) {
|
||||
const size_t offset = ridx * num_group + gid;
|
||||
preds[offset] +=
|
||||
self->PredValue(inst, gid,
|
||||
this->PredValue(inst, gid,
|
||||
&feats, tree_begin, tree_end);
|
||||
}
|
||||
}
|
||||
@@ -569,11 +554,9 @@ class Dart : public GBTree {
|
||||
}
|
||||
|
||||
// predict the leaf scores without dropped trees
|
||||
inline bst_float PredValue(const SparsePage::Inst &inst,
|
||||
int bst_group,
|
||||
RegTree::FVec *p_feats,
|
||||
unsigned tree_begin,
|
||||
unsigned tree_end) {
|
||||
bst_float PredValue(const SparsePage::Inst &inst, int bst_group,
|
||||
RegTree::FVec *p_feats, unsigned tree_begin,
|
||||
unsigned tree_end) const {
|
||||
bst_float psum = 0.0f;
|
||||
p_feats->Fill(inst);
|
||||
for (size_t i = tree_begin; i < tree_end; ++i) {
|
||||
@@ -590,9 +573,12 @@ class Dart : public GBTree {
|
||||
}
|
||||
|
||||
// select which trees to drop
|
||||
inline void DropTrees(unsigned ntree_limit_drop) {
|
||||
// passing clear=True will clear selection
|
||||
inline void DropTrees(bool is_training) {
|
||||
idx_drop_.clear();
|
||||
if (ntree_limit_drop > 0) return;
|
||||
if (!is_training) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::uniform_real_distribution<> runif(0.0, 1.0);
|
||||
auto& rnd = common::GlobalRandom();
|
||||
|
||||
@@ -205,6 +205,7 @@ class GBTree : public GradientBooster {
|
||||
|
||||
void PredictBatch(DMatrix* p_fmat,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
bool training,
|
||||
unsigned ntree_limit) override {
|
||||
CHECK(configured_);
|
||||
GetPredictor(out_preds, p_fmat)->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);
|
||||
|
||||
Reference in New Issue
Block a user