[Breaking] Don't drop trees during DART prediction by default (#5115)
* Simplify DropTrees calling logic * Add `training` parameter for prediction method. * [Breaking]: Add `training` to C API. * Change for R and Python custom objective. * Correct comment. Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu> Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
@@ -570,10 +570,11 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
DMatrixHandle dmat,
|
||||
int option_mask,
|
||||
unsigned ntree_limit,
|
||||
int32_t training,
|
||||
xgboost::bst_ulong *len,
|
||||
const bst_float **out_result) {
|
||||
std::vector<bst_float>&preds =
|
||||
XGBAPIThreadLocalStore::Get()->ret_vec_float;
|
||||
std::vector<bst_float>& preds =
|
||||
XGBAPIThreadLocalStore::Get()->ret_vec_float;
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto *bst = static_cast<Learner*>(handle);
|
||||
@@ -582,6 +583,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
static_cast<std::shared_ptr<DMatrix>*>(dmat)->get(),
|
||||
(option_mask & 1) != 0,
|
||||
&tmp_preds, ntree_limit,
|
||||
static_cast<bool>(training),
|
||||
(option_mask & 2) != 0,
|
||||
(option_mask & 4) != 0,
|
||||
(option_mask & 8) != 0,
|
||||
|
||||
@@ -127,6 +127,7 @@ class GBLinear : public GradientBooster {
|
||||
|
||||
void PredictBatch(DMatrix *p_fmat,
|
||||
HostDeviceVector<bst_float> *out_preds,
|
||||
bool training,
|
||||
unsigned ntree_limit) override {
|
||||
monitor_.Start("PredictBatch");
|
||||
CHECK_EQ(ntree_limit, 0U)
|
||||
|
||||
@@ -414,17 +414,36 @@ class Dart : public GBTree {
|
||||
out["dart_train_param"] = toJson(dparam_);
|
||||
}
|
||||
|
||||
// predict the leaf scores with dropout if ntree_limit = 0
|
||||
void PredictBatch(DMatrix* p_fmat,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
HostDeviceVector<bst_float>* p_out_preds,
|
||||
bool training,
|
||||
unsigned ntree_limit) override {
|
||||
DropTrees(ntree_limit);
|
||||
PredLoopInternal<Dart>(p_fmat, &out_preds->HostVector(), 0, ntree_limit, true);
|
||||
DropTrees(training);
|
||||
int num_group = model_.learner_model_param_->num_output_group;
|
||||
ntree_limit *= num_group;
|
||||
if (ntree_limit == 0 || ntree_limit > model_.trees.size()) {
|
||||
ntree_limit = static_cast<unsigned>(model_.trees.size());
|
||||
}
|
||||
size_t n = num_group * p_fmat->Info().num_row_;
|
||||
const auto &base_margin = p_fmat->Info().base_margin_.ConstHostVector();
|
||||
auto& out_preds = p_out_preds->HostVector();
|
||||
out_preds.resize(n);
|
||||
if (base_margin.size() != 0) {
|
||||
CHECK_EQ(out_preds.size(), n);
|
||||
std::copy(base_margin.begin(), base_margin.end(), out_preds.begin());
|
||||
} else {
|
||||
std::fill(out_preds.begin(), out_preds.end(),
|
||||
model_.learner_model_param_->base_score);
|
||||
}
|
||||
|
||||
PredLoopSpecalize(p_fmat, &out_preds, num_group, 0,
|
||||
ntree_limit, training);
|
||||
}
|
||||
|
||||
void PredictInstance(const SparsePage::Inst &inst,
|
||||
std::vector<bst_float> *out_preds, unsigned ntree_limit) override {
|
||||
DropTrees(1);
|
||||
std::vector<bst_float> *out_preds,
|
||||
unsigned ntree_limit) override {
|
||||
DropTrees(false);
|
||||
if (thread_temp_.size() == 0) {
|
||||
thread_temp_.resize(1, RegTree::FVec());
|
||||
thread_temp_[0].Init(model_.learner_model_param_->num_feature);
|
||||
@@ -465,46 +484,13 @@ class Dart : public GBTree {
|
||||
|
||||
|
||||
protected:
|
||||
friend class GBTree;
|
||||
// internal prediction loop
|
||||
// add predictions to out_preds
|
||||
template<typename Derived>
|
||||
inline void PredLoopInternal(
|
||||
DMatrix* p_fmat,
|
||||
std::vector<bst_float>* out_preds,
|
||||
unsigned tree_begin,
|
||||
unsigned ntree_limit,
|
||||
bool init_out_preds) {
|
||||
int num_group = model_.learner_model_param_->num_output_group;
|
||||
ntree_limit *= num_group;
|
||||
if (ntree_limit == 0 || ntree_limit > model_.trees.size()) {
|
||||
ntree_limit = static_cast<unsigned>(model_.trees.size());
|
||||
}
|
||||
|
||||
if (init_out_preds) {
|
||||
size_t n = num_group * p_fmat->Info().num_row_;
|
||||
const auto& base_margin =
|
||||
p_fmat->Info().base_margin_.ConstHostVector();
|
||||
out_preds->resize(n);
|
||||
if (base_margin.size() != 0) {
|
||||
CHECK_EQ(out_preds->size(), n);
|
||||
std::copy(base_margin.begin(), base_margin.end(), out_preds->begin());
|
||||
} else {
|
||||
std::fill(out_preds->begin(), out_preds->end(),
|
||||
model_.learner_model_param_->base_score);
|
||||
}
|
||||
}
|
||||
PredLoopSpecalize<Derived>(p_fmat, out_preds, num_group, tree_begin,
|
||||
ntree_limit);
|
||||
}
|
||||
|
||||
template<typename Derived>
|
||||
inline void PredLoopSpecalize(
|
||||
DMatrix* p_fmat,
|
||||
std::vector<bst_float>* out_preds,
|
||||
int num_group,
|
||||
unsigned tree_begin,
|
||||
unsigned tree_end) {
|
||||
unsigned tree_end,
|
||||
bool training) {
|
||||
const int nthread = omp_get_max_threads();
|
||||
CHECK_EQ(num_group, model_.learner_model_param_->num_output_group);
|
||||
InitThreadTemp(nthread);
|
||||
@@ -513,13 +499,12 @@ class Dart : public GBTree {
|
||||
<< "size_leaf_vector is enforced to 0 so far";
|
||||
CHECK_EQ(preds.size(), p_fmat->Info().num_row_ * num_group);
|
||||
// start collecting the prediction
|
||||
auto* self = static_cast<Derived*>(this);
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
constexpr int kUnroll = 8;
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
const bst_omp_uint rest = nsize % kUnroll;
|
||||
if (nsize >= kUnroll) {
|
||||
#pragma omp parallel for schedule(static)
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
|
||||
const int tid = omp_get_thread_num();
|
||||
RegTree::FVec& feats = thread_temp_[tid];
|
||||
@@ -535,7 +520,7 @@ class Dart : public GBTree {
|
||||
for (int gid = 0; gid < num_group; ++gid) {
|
||||
const size_t offset = ridx[k] * num_group + gid;
|
||||
preds[offset] +=
|
||||
self->PredValue(inst[k], gid, &feats, tree_begin, tree_end);
|
||||
this->PredValue(inst[k], gid, &feats, tree_begin, tree_end);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -548,7 +533,7 @@ class Dart : public GBTree {
|
||||
for (int gid = 0; gid < num_group; ++gid) {
|
||||
const size_t offset = ridx * num_group + gid;
|
||||
preds[offset] +=
|
||||
self->PredValue(inst, gid,
|
||||
this->PredValue(inst, gid,
|
||||
&feats, tree_begin, tree_end);
|
||||
}
|
||||
}
|
||||
@@ -569,11 +554,9 @@ class Dart : public GBTree {
|
||||
}
|
||||
|
||||
// predict the leaf scores without dropped trees
|
||||
inline bst_float PredValue(const SparsePage::Inst &inst,
|
||||
int bst_group,
|
||||
RegTree::FVec *p_feats,
|
||||
unsigned tree_begin,
|
||||
unsigned tree_end) {
|
||||
bst_float PredValue(const SparsePage::Inst &inst, int bst_group,
|
||||
RegTree::FVec *p_feats, unsigned tree_begin,
|
||||
unsigned tree_end) const {
|
||||
bst_float psum = 0.0f;
|
||||
p_feats->Fill(inst);
|
||||
for (size_t i = tree_begin; i < tree_end; ++i) {
|
||||
@@ -590,9 +573,12 @@ class Dart : public GBTree {
|
||||
}
|
||||
|
||||
// select which trees to drop
|
||||
inline void DropTrees(unsigned ntree_limit_drop) {
|
||||
// passing clear=True will clear selection
|
||||
inline void DropTrees(bool is_training) {
|
||||
idx_drop_.clear();
|
||||
if (ntree_limit_drop > 0) return;
|
||||
if (!is_training) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::uniform_real_distribution<> runif(0.0, 1.0);
|
||||
auto& rnd = common::GlobalRandom();
|
||||
|
||||
@@ -205,6 +205,7 @@ class GBTree : public GradientBooster {
|
||||
|
||||
void PredictBatch(DMatrix* p_fmat,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
bool training,
|
||||
unsigned ntree_limit) override {
|
||||
CHECK(configured_);
|
||||
GetPredictor(out_preds, p_fmat)->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);
|
||||
|
||||
@@ -694,7 +694,7 @@ class LearnerImpl : public Learner {
|
||||
this->ValidateDMatrix(train);
|
||||
|
||||
monitor_.Start("PredictRaw");
|
||||
this->PredictRaw(train, &preds_[train]);
|
||||
this->PredictRaw(train, &preds_[train], true);
|
||||
monitor_.Stop("PredictRaw");
|
||||
TrainingObserver::Instance().Observe(preds_[train], "Predictions");
|
||||
|
||||
@@ -735,7 +735,7 @@ class LearnerImpl : public Learner {
|
||||
for (size_t i = 0; i < data_sets.size(); ++i) {
|
||||
DMatrix * dmat = data_sets[i];
|
||||
this->ValidateDMatrix(dmat);
|
||||
this->PredictRaw(data_sets[i], &preds_[dmat]);
|
||||
this->PredictRaw(data_sets[i], &preds_[dmat], false);
|
||||
obj_->EvalTransform(&preds_[dmat]);
|
||||
for (auto& ev : metrics_) {
|
||||
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
|
||||
@@ -799,6 +799,7 @@ class LearnerImpl : public Learner {
|
||||
|
||||
void Predict(DMatrix* data, bool output_margin,
|
||||
HostDeviceVector<bst_float>* out_preds, unsigned ntree_limit,
|
||||
bool training,
|
||||
bool pred_leaf, bool pred_contribs, bool approx_contribs,
|
||||
bool pred_interactions) override {
|
||||
int multiple_predictions = static_cast<int>(pred_leaf) +
|
||||
@@ -814,7 +815,7 @@ class LearnerImpl : public Learner {
|
||||
} else if (pred_leaf) {
|
||||
gbm_->PredictLeaf(data, &out_preds->HostVector(), ntree_limit);
|
||||
} else {
|
||||
this->PredictRaw(data, out_preds, ntree_limit);
|
||||
this->PredictRaw(data, out_preds, training, ntree_limit);
|
||||
if (!output_margin) {
|
||||
obj_->PredTransform(out_preds);
|
||||
}
|
||||
@@ -832,13 +833,15 @@ class LearnerImpl : public Learner {
|
||||
* \param out_preds output vector that stores the prediction
|
||||
* \param ntree_limit limit number of trees used for boosted tree
|
||||
* predictor, when it equals 0, this means we are using all the trees
|
||||
* \param training allow dropout when the DART booster is being used
|
||||
*/
|
||||
void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
|
||||
bool training,
|
||||
unsigned ntree_limit = 0) const {
|
||||
CHECK(gbm_ != nullptr)
|
||||
<< "Predict must happen after Load or configuration";
|
||||
this->ValidateDMatrix(data);
|
||||
gbm_->PredictBatch(data, out_preds, ntree_limit);
|
||||
gbm_->PredictBatch(data, out_preds, training, ntree_limit);
|
||||
}
|
||||
|
||||
void ConfigureObjective(LearnerTrainParam const& old, Args* p_args) {
|
||||
|
||||
@@ -18,7 +18,7 @@ DMLC_REGISTRY_FILE_TAG(cpu_predictor);
|
||||
|
||||
class CPUPredictor : public Predictor {
|
||||
protected:
|
||||
static bst_float PredValue(const SparsePage::Inst& inst,
|
||||
static bst_float PredValue(const SparsePage::Inst& inst,
|
||||
const std::vector<std::unique_ptr<RegTree>>& trees,
|
||||
const std::vector<int>& tree_info, int bst_group,
|
||||
RegTree::FVec* p_feats,
|
||||
@@ -175,13 +175,15 @@ class CPUPredictor : public Predictor {
|
||||
this->PredLoopInternal(dmat, &out_preds->HostVector(), model,
|
||||
tree_begin, ntree_limit);
|
||||
|
||||
auto cache_emtry = this->FindCache(dmat);
|
||||
if (cache_emtry == cache_->cend()) { return; }
|
||||
if (cache_emtry->second.predictions.Size() == 0) {
|
||||
auto cache_entry = this->FindCache(dmat);
|
||||
if (cache_entry == cache_->cend()) {
|
||||
return;
|
||||
}
|
||||
if (cache_entry->second.predictions.Size() == 0) {
|
||||
// See comment in GPUPredictor::PredictBatch.
|
||||
InitOutPredictions(cache_emtry->second.data->Info(),
|
||||
&(cache_emtry->second.predictions), model);
|
||||
cache_emtry->second.predictions.Copy(*out_preds);
|
||||
InitOutPredictions(cache_entry->second.data->Info(),
|
||||
&(cache_entry->second.predictions), model);
|
||||
cache_entry->second.predictions.Copy(*out_preds);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user