From 8dfe7b3686d2cd087b79cf127ae47db4365661a8 Mon Sep 17 00:00:00 2001 From: "Jason E. Aten, Ph.D" Date: Sat, 25 Apr 2020 19:48:42 -0400 Subject: [PATCH] Clarify meaning of `training` parameter in XGBoosterPredict() (#5604) Co-authored-by: Hyunsu Cho Co-authored-by: Jiaming Yuan --- include/xgboost/c_api.h | 9 ++++++++- src/c_api/c_api.cc | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 237858bdc..5b335ede1 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -418,7 +418,14 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle, * 4:output feature contributions to individual predictions * \param ntree_limit limit number of trees used for prediction, this is only valid for boosted trees * when the parameter is set to 0, we will use all the trees - * \param training Whether the prediction value is used for training. + * \param training Whether the prediction function is used as part of a training loop. + * Prediction can be run in 2 scenarios: + * 1. Given data matrix X, obtain prediction y_pred from the model. + * 2. Obtain the prediction for computing gradients. For example, DART booster performs dropout + * during training, and the prediction result will be different from the one obtained by normal + * inference step due to dropped trees. + * Set training=false for the first scenario. Set training=true for the second scenario. + * The second scenario applies when you are defining a custom objective function. * \param out_len used to store length of returning result * \param out_result used to set a pointer to array * \return 0 when success, -1 when failure happens diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 754e27ef0..0ba795a39 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -426,7 +426,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle, DMatrixHandle dmat, int option_mask, unsigned ntree_limit, - int32_t training, + int training, xgboost::bst_ulong *len, const bst_float **out_result) { API_BEGIN();