[Breaking] Don't drop trees during DART prediction by default (#5115)
* Simplify DropTrees calling logic * Add `training` parameter for prediction method. * [Breaking]: Add `training` to C API. * Change for R and Python custom objective. * Correct comment. Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu> Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
@@ -570,10 +570,11 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
DMatrixHandle dmat,
|
||||
int option_mask,
|
||||
unsigned ntree_limit,
|
||||
int32_t training,
|
||||
xgboost::bst_ulong *len,
|
||||
const bst_float **out_result) {
|
||||
std::vector<bst_float>&preds =
|
||||
XGBAPIThreadLocalStore::Get()->ret_vec_float;
|
||||
std::vector<bst_float>& preds =
|
||||
XGBAPIThreadLocalStore::Get()->ret_vec_float;
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto *bst = static_cast<Learner*>(handle);
|
||||
@@ -582,6 +583,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
static_cast<std::shared_ptr<DMatrix>*>(dmat)->get(),
|
||||
(option_mask & 1) != 0,
|
||||
&tmp_preds, ntree_limit,
|
||||
static_cast<bool>(training),
|
||||
(option_mask & 2) != 0,
|
||||
(option_mask & 4) != 0,
|
||||
(option_mask & 8) != 0,
|
||||
|
||||
Reference in New Issue
Block a user