diff --git a/R-package/R/utils.R b/R-package/R/utils.R index 1528d90bb..75482e10a 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -145,7 +145,7 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj = NULL) { if (is.null(obj)) { .Call(XGBoosterUpdateOneIter_R, booster_handle, as.integer(iter), dtrain) } else { - pred <- predict(booster_handle, dtrain) + pred <- predict(booster_handle, dtrain, training = TRUE) gpair <- obj(pred, dtrain) .Call(XGBoosterBoostOneIter_R, booster_handle, dtrain, gpair$grad, gpair$hess) } diff --git a/R-package/R/xgb.Booster.R b/R-package/R/xgb.Booster.R index 3fc232edf..f18632500 100644 --- a/R-package/R/xgb.Booster.R +++ b/R-package/R/xgb.Booster.R @@ -288,7 +288,7 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) { #' @export predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE, ntreelimit = NULL, predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE, - reshape = FALSE, ...) { + reshape = FALSE, training = FALSE, ...) { object <- xgb.Booster.complete(object, saveraw = FALSE) if (!inherits(newdata, "xgb.DMatrix")) @@ -307,7 +307,8 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA option <- 0L + 1L * as.logical(outputmargin) + 2L * as.logical(predleaf) + 4L * as.logical(predcontrib) + 8L * as.logical(approxcontrib) + 16L * as.logical(predinteraction) - ret <- .Call(XGBoosterPredict_R, object$handle, newdata, option[1], as.integer(ntreelimit)) + ret <- .Call(XGBoosterPredict_R, object$handle, newdata, option[1], + as.integer(ntreelimit), as.integer(training)) n_ret <- length(ret) n_row <- nrow(newdata) diff --git a/R-package/src/init.c b/R-package/src/init.c index 1020d279e..82b853217 100644 --- a/R-package/src/init.c +++ b/R-package/src/init.c @@ -24,7 +24,7 @@ extern SEXP XGBoosterGetAttr_R(SEXP, SEXP); extern SEXP XGBoosterLoadModelFromRaw_R(SEXP, SEXP); extern SEXP XGBoosterLoadModel_R(SEXP, SEXP); extern SEXP XGBoosterModelToRaw_R(SEXP); -extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP); +extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP, SEXP); extern SEXP XGBoosterSaveModel_R(SEXP, SEXP); extern SEXP XGBoosterSetAttr_R(SEXP, SEXP, SEXP); extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP); @@ -50,7 +50,7 @@ static const R_CallMethodDef CallEntries[] = { {"XGBoosterLoadModelFromRaw_R", (DL_FUNC) &XGBoosterLoadModelFromRaw_R, 2}, {"XGBoosterLoadModel_R", (DL_FUNC) &XGBoosterLoadModel_R, 2}, {"XGBoosterModelToRaw_R", (DL_FUNC) &XGBoosterModelToRaw_R, 1}, - {"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 4}, + {"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 5}, {"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2}, {"XGBoosterSetAttr_R", (DL_FUNC) &XGBoosterSetAttr_R, 3}, {"XGBoosterSetParam_R", (DL_FUNC) &XGBoosterSetParam_R, 3}, diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index eb296dc14..c929ba204 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -295,24 +295,26 @@ SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) { vec_sptr.push_back(vec_names[i].c_str()); } CHECK_CALL(XGBoosterEvalOneIter(R_ExternalPtrAddr(handle), - asInteger(iter), - BeginPtr(vec_dmats), - BeginPtr(vec_sptr), - len, &ret)); + asInteger(iter), + BeginPtr(vec_dmats), + BeginPtr(vec_sptr), + len, &ret)); R_API_END(); return mkString(ret); } -SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask, SEXP ntree_limit) { +SEXP XGBoosterPredict_R(SEXP 
handle, SEXP dmat, SEXP option_mask, + SEXP ntree_limit, SEXP training) { SEXP ret; R_API_BEGIN(); bst_ulong olen; const float *res; CHECK_CALL(XGBoosterPredict(R_ExternalPtrAddr(handle), - R_ExternalPtrAddr(dmat), - asInteger(option_mask), - asInteger(ntree_limit), - &olen, &res)); + R_ExternalPtrAddr(dmat), + asInteger(option_mask), + asInteger(ntree_limit), + 0, + &olen, &res)); ret = PROTECT(allocVector(REALSXP, olen)); for (size_t i = 0; i < olen; ++i) { REAL(ret)[i] = res[i]; diff --git a/R-package/src/xgboost_R.h b/R-package/src/xgboost_R.h index 272a64177..764050fd8 100644 --- a/R-package/src/xgboost_R.h +++ b/R-package/src/xgboost_R.h @@ -148,8 +148,10 @@ XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evn * \param dmat data matrix * \param option_mask output_margin:1 predict_leaf:2 * \param ntree_limit limit number of trees used in prediction + * \param training Whether the prediction value is used for training. */ -XGB_DLL SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask, SEXP ntree_limit); +XGB_DLL SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask, + SEXP ntree_limit, SEXP training); /*! * \brief load model from existing file * \param handle handle diff --git a/R-package/tests/testthat/test_helpers.R b/R-package/tests/testthat/test_helpers.R index 38319ae30..a71ce4692 100644 --- a/R-package/tests/testthat/test_helpers.R +++ b/R-package/tests/testthat/test_helpers.R @@ -166,7 +166,7 @@ test_that("SHAPs sum to predictions, with or without DART", { nrounds = nrounds) pr <- function(...) - predict(fit, newdata = d, ntreelimit = nrounds, ...) + predict(fit, newdata = d, ...) pred <- pr() shap <- pr(predcontrib = T) shapi <- pr(predinteraction = T) diff --git a/doc/tutorials/dart.rst b/doc/tutorials/dart.rst index a5fc60b2c..a660e9983 100644 --- a/doc/tutorials/dart.rst +++ b/doc/tutorials/dart.rst @@ -108,12 +108,4 @@ Sample Script 'skip_drop': 0.5} num_round = 50 bst = xgb.train(param, dtrain, num_round) - # make prediction - # ntree_limit must not be 0 - preds = bst.predict(dtest, ntree_limit=num_round) - -.. note:: Specify ``ntree_limit`` when predicting with test sets - - By default, ``bst.predict()`` will perform dropouts on trees. To obtain - correct results on test sets, disable dropouts by specifying - a nonzero value for ``ntree_limit``. + preds = bst.predict(dtest) diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 3147f1d6a..f9c0a0ffc 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -416,6 +416,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle, * 4:output feature contributions to individual predictions * \param ntree_limit limit number of trees used for prediction, this is only valid for boosted trees * when the parameter is set to 0, we will use all the trees + * \param training Whether the prediction value is used for training. 
* \param out_len used to store length of returning result * \param out_result used to set a pointer to array * \return 0 when success, -1 when failure happens @@ -424,6 +425,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle, DMatrixHandle dmat, int option_mask, unsigned ntree_limit, + int training, bst_ulong *out_len, const float **out_result); /* diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h index fde8d2e0d..90645371b 100644 --- a/include/xgboost/gbm.h +++ b/include/xgboost/gbm.h @@ -84,6 +84,7 @@ class GradientBooster : public Model, public Configurable { */ virtual void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds, + bool training, unsigned ntree_limit = 0) = 0; /*! * \brief online prediction function, predict score for one instance at a time diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index 853f6bbc7..1f78382c7 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -96,6 +96,7 @@ class Learner : public Model, public Configurable, public rabit::Serializable { bool output_margin, HostDeviceVector<bst_float> *out_preds, unsigned ntree_limit = 0, + bool training = false, bool pred_leaf = false, bool pred_contribs = false, bool approx_contribs = false, diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index b6b9a8377..5528b1e83 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -556,7 +556,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict DMatrixHandle dmat = (DMatrixHandle) jdmat; bst_ulong len; float *result; - int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, &len, (const float **) &result); + int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, 0, + &len, (const float **) &result); if (len) { jsize jlen = (jsize) len; jfloatArray jarray = jenv->NewFloatArray(jlen); diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index d16dc67a6..97252db0e 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -196,7 +196,8 @@ def ctypes2numpy(cptr, length, dtype): np.uint32: ctypes.c_uint, } if dtype not in NUMPY_TO_CTYPES_MAPPING: - raise RuntimeError('Supported types: {}'.format(NUMPY_TO_CTYPES_MAPPING.keys())) + raise RuntimeError('Supported types: {}'.format( + NUMPY_TO_CTYPES_MAPPING.keys())) ctype = NUMPY_TO_CTYPES_MAPPING[dtype] if not isinstance(cptr, ctypes.POINTER(ctype)): raise RuntimeError('expected {} pointer'.format(ctype)) @@ -1275,14 +1276,16 @@ class Booster(object): """ if not isinstance(dtrain, DMatrix): - raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__)) + raise TypeError('invalid training matrix: {}'.format( + type(dtrain).__name__)) self._validate_features(dtrain) if fobj is None: - _check_call(_LIB.XGBoosterUpdateOneIter(self.handle, ctypes.c_int(iteration), + _check_call(_LIB.XGBoosterUpdateOneIter(self.handle, + ctypes.c_int(iteration), dtrain.handle)) else: - pred = self.predict(dtrain) + pred = self.predict(dtrain, training=True) grad, hess = fobj(pred, dtrain) self.boost(dtrain, grad, hess) @@ -1332,22 +1335,25 @@ class Booster(object): """ for d in evals: if not isinstance(d[0], DMatrix): - raise TypeError('expected DMatrix, got {}'.format(type(d[0]).__name__)) + raise TypeError('expected DMatrix, got {}'.format( + type(d[0]).__name__)) if not isinstance(d[1], STRING_TYPES): - raise TypeError('expected string,
got {}'.format(type(d[1]).__name__)) + raise TypeError('expected string, got {}'.format( + type(d[1]).__name__)) self._validate_features(d[0]) dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals]) evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals]) msg = ctypes.c_char_p() - _check_call(_LIB.XGBoosterEvalOneIter(self.handle, ctypes.c_int(iteration), + _check_call(_LIB.XGBoosterEvalOneIter(self.handle, + ctypes.c_int(iteration), dmats, evnames, c_bst_ulong(len(evals)), ctypes.byref(msg))) res = msg.value.decode() if feval is not None: for dmat, evname in evals: - feval_ret = feval(self.predict(dmat), dmat) + feval_ret = feval(self.predict(dmat, training=False), dmat) if isinstance(feval_ret, list): for name, val in feval_ret: res += '\t%s-%s:%f' % (evname, name, val) @@ -1378,27 +1384,24 @@ class Booster(object): self._validate_features(data) return self.eval_set([(data, name)], iteration) - def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False, - pred_contribs=False, approx_contribs=False, pred_interactions=False, - validate_features=True): + def predict(self, + data, + output_margin=False, + ntree_limit=0, + pred_leaf=False, + pred_contribs=False, + approx_contribs=False, + pred_interactions=False, + validate_features=True, + training=False): """Predict with data. .. note:: This function is not thread safe. For each booster object, predict can only be called from one thread. - If you want to run prediction using multiple thread, call ``bst.copy()`` to make copies - of model object and then call ``predict()``. - - .. note:: Using ``predict()`` with DART booster - - If the booster object is DART type, ``predict()`` will perform dropouts, i.e. only - some of the trees will be evaluated. This will produce incorrect results if ``data`` is - not the training data. To obtain correct results on test sets, set ``ntree_limit`` to - a nonzero value, e.g. - - .. code-block:: python - - preds = bst.predict(dtest, ntree_limit=num_round) + If you want to run prediction using multiple thread, call + ``bst.copy()`` to make copies of model object and then call + ``predict()``. Parameters ---------- @@ -1409,38 +1412,53 @@ class Booster(object): Whether to output the raw untransformed margin value. ntree_limit : int - Limit number of trees in the prediction; defaults to 0 (use all trees). + Limit number of trees in the prediction; defaults to 0 (use all + trees). pred_leaf : bool - When this option is on, the output will be a matrix of (nsample, ntrees) - with each record indicating the predicted leaf index of each sample in each tree. - Note that the leaf index of a tree is unique per tree, so you may find leaf 1 - in both tree 1 and tree 0. + When this option is on, the output will be a matrix of (nsample, + ntrees) with each record indicating the predicted leaf index of + each sample in each tree. Note that the leaf index of a tree is + unique per tree, so you may find leaf 1 in both tree 1 and tree 0. pred_contribs : bool - When this is True the output will be a matrix of size (nsample, nfeats + 1) - with each record indicating the feature contributions (SHAP values) for that - prediction. The sum of all feature contributions is equal to the raw untransformed - margin value of the prediction. Note the final column is the bias term. + When this is True the output will be a matrix of size (nsample, + nfeats + 1) with each record indicating the feature contributions + (SHAP values) for that prediction. 
The sum of all feature + contributions is equal to the raw untransformed margin value of the + prediction. Note the final column is the bias term. approx_contribs : bool Approximate the contributions of each feature pred_interactions : bool - When this is True the output will be a matrix of size (nsample, nfeats + 1, nfeats + 1) - indicating the SHAP interaction values for each pair of features. The sum of each - row (or column) of the interaction values equals the corresponding SHAP value (from - pred_contribs), and the sum of the entire matrix equals the raw untransformed margin - value of the prediction. Note the last row and column correspond to the bias term. + When this is True the output will be a matrix of size (nsample, + nfeats + 1, nfeats + 1) indicating the SHAP interaction values for + each pair of features. The sum of each row (or column) of the + interaction values equals the corresponding SHAP value (from + pred_contribs), and the sum of the entire matrix equals the raw + untransformed margin value of the prediction. Note the last row and + column correspond to the bias term. validate_features : bool When this is True, validate that the Booster's and data's feature_names are identical. Otherwise, it is assumed that the feature_names are the same. + training : bool + Whether the prediction value is used for training. This can affect + the `dart` booster, which performs dropouts during training iterations. + + .. note:: Using ``predict()`` with DART booster + + If the booster object is DART type, ``predict()`` will not perform + dropouts, i.e. all the trees will be evaluated. If you want to + obtain results with dropouts, provide `training=True`. + Returns ------- prediction : numpy array + """ option_mask = 0x00 if output_margin: @@ -1466,6 +1484,7 @@ class Booster(object): _check_call(_LIB.XGBoosterPredict(self.handle, data.handle, ctypes.c_int(option_mask), ctypes.c_uint(ntree_limit), + ctypes.c_int(training), ctypes.byref(length), ctypes.byref(preds))) preds = ctypes2numpy(preds, length.value, np.float32) @@ -1476,11 +1495,16 @@ class Booster(object): chunk_size = int(preds.size / nrow) if pred_interactions: - ngroup = int(chunk_size / ((data.num_col() + 1) * (data.num_col() + 1))) + ngroup = int(chunk_size / ((data.num_col() + 1) * + (data.num_col() + 1))) if ngroup == 1: - preds = preds.reshape(nrow, data.num_col() + 1, data.num_col() + 1) + preds = preds.reshape(nrow, + data.num_col() + 1, + data.num_col() + 1) else: - preds = preds.reshape(nrow, ngroup, data.num_col() + 1, data.num_col() + 1) + preds = preds.reshape(nrow, ngroup, + data.num_col() + 1, + data.num_col() + 1) elif pred_contribs: ngroup = int(chunk_size / (data.num_col() + 1)) if ngroup == 1: diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 98a0f1dc3..d71364527 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -570,10 +570,11 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle, DMatrixHandle dmat, int option_mask, unsigned ntree_limit, + int32_t training, xgboost::bst_ulong *len, const bst_float **out_result) { - std::vector<bst_float>&preds = - XGBAPIThreadLocalStore::Get()->ret_vec_float; + std::vector<bst_float>& preds = + XGBAPIThreadLocalStore::Get()->ret_vec_float; API_BEGIN(); CHECK_HANDLE(); auto *bst = static_cast<Booster*>(handle); @@ -582,6 +583,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle, static_cast<std::shared_ptr<DMatrix>*>(dmat)->get(), (option_mask & 1) != 0, &tmp_preds, ntree_limit, + static_cast<bool>(training), (option_mask & 2) != 0, (option_mask & 4) != 0, (option_mask & 8) != 0, diff --git a/src/gbm/gblinear.cc
b/src/gbm/gblinear.cc index 46a1706e4..a0108e21a 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -127,6 +127,7 @@ class GBLinear : public GradientBooster { void PredictBatch(DMatrix *p_fmat, HostDeviceVector<bst_float> *out_preds, + bool training, unsigned ntree_limit) override { monitor_.Start("PredictBatch"); CHECK_EQ(ntree_limit, 0U) diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index f1b5065ec..07bf17f8e 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -414,17 +414,36 @@ class Dart : public GBTree { out["dart_train_param"] = toJson(dparam_); } - // predict the leaf scores with dropout if ntree_limit = 0 void PredictBatch(DMatrix* p_fmat, - HostDeviceVector<bst_float>* out_preds, + HostDeviceVector<bst_float>* p_out_preds, + bool training, unsigned ntree_limit) override { - DropTrees(ntree_limit); - PredLoopInternal<Dart>(p_fmat, &out_preds->HostVector(), 0, ntree_limit, true); + DropTrees(training); + int num_group = model_.learner_model_param_->num_output_group; + ntree_limit *= num_group; + if (ntree_limit == 0 || ntree_limit > model_.trees.size()) { + ntree_limit = static_cast<unsigned>(model_.trees.size()); + } + size_t n = num_group * p_fmat->Info().num_row_; + const auto &base_margin = p_fmat->Info().base_margin_.ConstHostVector(); + auto& out_preds = p_out_preds->HostVector(); + out_preds.resize(n); + if (base_margin.size() != 0) { + CHECK_EQ(out_preds.size(), n); + std::copy(base_margin.begin(), base_margin.end(), out_preds.begin()); + } else { + std::fill(out_preds.begin(), out_preds.end(), + model_.learner_model_param_->base_score); + } + + PredLoopSpecalize(p_fmat, &out_preds, num_group, 0, + ntree_limit, training); } void PredictInstance(const SparsePage::Inst &inst, - std::vector<bst_float> *out_preds, unsigned ntree_limit) override { - DropTrees(1); + std::vector<bst_float> *out_preds, + unsigned ntree_limit) override { + DropTrees(false); if (thread_temp_.size() == 0) { thread_temp_.resize(1, RegTree::FVec()); thread_temp_[0].Init(model_.learner_model_param_->num_feature); @@ -465,46 +484,13 @@ class Dart : public GBTree { protected: - friend class GBTree; - // internal prediction loop - // add predictions to out_preds - template <typename Derived> - inline void PredLoopInternal( - DMatrix* p_fmat, - std::vector<bst_float>* out_preds, - unsigned tree_begin, - unsigned ntree_limit, - bool init_out_preds) { - int num_group = model_.learner_model_param_->num_output_group; - ntree_limit *= num_group; - if (ntree_limit == 0 || ntree_limit > model_.trees.size()) { - ntree_limit = static_cast<unsigned>(model_.trees.size()); - } - - if (init_out_preds) { - size_t n = num_group * p_fmat->Info().num_row_; - const auto& base_margin = - p_fmat->Info().base_margin_.ConstHostVector(); - out_preds->resize(n); - if (base_margin.size() != 0) { - CHECK_EQ(out_preds->size(), n); - std::copy(base_margin.begin(), base_margin.end(), out_preds->begin()); - } else { - std::fill(out_preds->begin(), out_preds->end(), - model_.learner_model_param_->base_score); - } - } - PredLoopSpecalize<Derived>(p_fmat, out_preds, num_group, tree_begin, - ntree_limit); - } - - template <typename Derived> inline void PredLoopSpecalize( DMatrix* p_fmat, std::vector<bst_float>* out_preds, int num_group, unsigned tree_begin, - unsigned tree_end) { + unsigned tree_end, + bool training) { const int nthread = omp_get_max_threads(); CHECK_EQ(num_group, model_.learner_model_param_->num_output_group); InitThreadTemp(nthread); @@ -513,13 +499,12 @@ class Dart : public GBTree { << "size_leaf_vector is enforced to 0 so far"; CHECK_EQ(preds.size(), p_fmat->Info().num_row_ * num_group); // start collecting the prediction - auto* self =
static_cast<Derived*>(this); for (const auto &batch : p_fmat->GetBatches<SparsePage>()) { constexpr int kUnroll = 8; const auto nsize = static_cast<bst_omp_uint>(batch.Size()); const bst_omp_uint rest = nsize % kUnroll; if (nsize >= kUnroll) { - #pragma omp parallel for schedule(static) +#pragma omp parallel for schedule(static) for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) { const int tid = omp_get_thread_num(); RegTree::FVec& feats = thread_temp_[tid]; @@ -535,7 +520,7 @@ class Dart : public GBTree { for (int gid = 0; gid < num_group; ++gid) { const size_t offset = ridx[k] * num_group + gid; preds[offset] += - self->PredValue(inst[k], gid, &feats, tree_begin, tree_end); + this->PredValue(inst[k], gid, &feats, tree_begin, tree_end); } } } @@ -548,7 +533,7 @@ class Dart : public GBTree { for (int gid = 0; gid < num_group; ++gid) { const size_t offset = ridx * num_group + gid; preds[offset] += - self->PredValue(inst, gid, + this->PredValue(inst, gid, &feats, tree_begin, tree_end); } } @@ -569,11 +554,9 @@ class Dart : public GBTree { } // predict the leaf scores without dropped trees - inline bst_float PredValue(const SparsePage::Inst &inst, - int bst_group, - RegTree::FVec *p_feats, - unsigned tree_begin, - unsigned tree_end) { + bst_float PredValue(const SparsePage::Inst &inst, int bst_group, + RegTree::FVec *p_feats, unsigned tree_begin, + unsigned tree_end) const { bst_float psum = 0.0f; p_feats->Fill(inst); for (size_t i = tree_begin; i < tree_end; ++i) { @@ -590,9 +573,12 @@ class Dart : public GBTree { } // select which trees to drop - inline void DropTrees(unsigned ntree_limit_drop) { + // when not training, the selection is cleared and no tree is dropped + inline void DropTrees(bool is_training) { idx_drop_.clear(); - if (ntree_limit_drop > 0) return; + if (!is_training) { + return; + } std::uniform_real_distribution<> runif(0.0, 1.0); auto& rnd = common::GlobalRandom(); diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 09f1c4f0a..fb546969d 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -205,6 +205,7 @@ class GBTree : public GradientBooster { void PredictBatch(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_preds, + bool training, unsigned ntree_limit) override { CHECK(configured_); GetPredictor(out_preds, p_fmat)->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit); diff --git a/src/learner.cc b/src/learner.cc index 7cf21bb5f..56f2932de 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -694,7 +694,7 @@ class LearnerImpl : public Learner { this->ValidateDMatrix(train); monitor_.Start("PredictRaw"); - this->PredictRaw(train, &preds_[train]); + this->PredictRaw(train, &preds_[train], true); monitor_.Stop("PredictRaw"); TrainingObserver::Instance().Observe(preds_[train], "Predictions"); @@ -735,7 +735,7 @@ class LearnerImpl : public Learner { for (size_t i = 0; i < data_sets.size(); ++i) { DMatrix * dmat = data_sets[i]; this->ValidateDMatrix(dmat); - this->PredictRaw(data_sets[i], &preds_[dmat]); + this->PredictRaw(data_sets[i], &preds_[dmat], false); obj_->EvalTransform(&preds_[dmat]); for (auto& ev : metrics_) { os << '\t' << data_names[i] << '-' << ev->Name() << ':' @@ -799,6 +799,7 @@ class LearnerImpl : public Learner { void Predict(DMatrix* data, bool output_margin, HostDeviceVector<bst_float>* out_preds, unsigned ntree_limit, + bool training, bool pred_leaf, bool pred_contribs, bool approx_contribs, bool pred_interactions) override { int multiple_predictions = static_cast<int>(pred_leaf) + @@ -814,7 +815,7 @@ class LearnerImpl : public Learner { } else if (pred_leaf) { gbm_->PredictLeaf(data, &out_preds->HostVector(), ntree_limit); }
else { - this->PredictRaw(data, out_preds, ntree_limit); + this->PredictRaw(data, out_preds, training, ntree_limit); if (!output_margin) { obj_->PredTransform(out_preds); } @@ -832,13 +833,15 @@ class LearnerImpl : public Learner { * \param out_preds output vector that stores the prediction * \param ntree_limit limit number of trees used for boosted tree * predictor, when it equals 0, this means we are using all the trees + * \param training whether to allow dropout when the DART booster is being used */ void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds, + bool training, unsigned ntree_limit = 0) const { CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration"; this->ValidateDMatrix(data); - gbm_->PredictBatch(data, out_preds, ntree_limit); + gbm_->PredictBatch(data, out_preds, training, ntree_limit); } void ConfigureObjective(LearnerTrainParam const& old, Args* p_args) { diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index bc14f4ccd..d11822782 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -18,7 +18,7 @@ DMLC_REGISTRY_FILE_TAG(cpu_predictor); class CPUPredictor : public Predictor { protected: - static bst_float PredValue(const SparsePage::Inst& inst, + static bst_float PredValue(const SparsePage::Inst& inst, const std::vector<std::unique_ptr<RegTree>>& trees, const std::vector<int>& tree_info, int bst_group, RegTree::FVec* p_feats, @@ -175,13 +175,15 @@ class CPUPredictor : public Predictor { this->PredLoopInternal(dmat, &out_preds->HostVector(), model, tree_begin, ntree_limit); - auto cache_emtry = this->FindCache(dmat); - if (cache_emtry == cache_->cend()) { return; } - if (cache_emtry->second.predictions.Size() == 0) { + auto cache_entry = this->FindCache(dmat); + if (cache_entry == cache_->cend()) { + return; + } + if (cache_entry->second.predictions.Size() == 0) { // See comment in GPUPredictor::PredictBatch.
- InitOutPredictions(cache_emtry->second.data->Info(), - &(cache_emtry->second.predictions), model); - cache_emtry->second.predictions.Copy(*out_preds); + InitOutPredictions(cache_entry->second.data->Info(), + &(cache_entry->second.predictions), model); + cache_entry->second.predictions.Copy(*out_preds); } } diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index 58c8c44f2..8533b8302 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -2,6 +2,8 @@ #include <gtest/gtest.h> #include <xgboost/generic_parameters.h> +#include "xgboost/base.h" +#include "xgboost/host_device_vector.h" #include "xgboost/learner.h" #include "../helpers.h" #include "../../../src/gbm/gbtree.h" @@ -18,7 +20,7 @@ TEST(GBTree, SelectTreeMethod) { mparam.num_output_group = 1; std::vector<std::shared_ptr<DMatrix> > caches; - std::unique_ptr<GradientBooster> p_gbm{ + std::unique_ptr<GradientBooster> p_gbm { GradientBooster::Create("gbtree", &generic_param, &mparam, caches)}; auto& gbtree = dynamic_cast<gbm::GBTree&> (*p_gbm); @@ -175,4 +177,41 @@ TEST(Dart, Json_IO) { ASSERT_TRUE(IsA<Object>(model["model"]["gbtree"])); ASSERT_NE(get<Array>(model["model"]["weight_drop"]).size(), 0); } + +TEST(Dart, Prediction) { + size_t constexpr kRows = 16, kCols = 10; + + auto pp_dmat = CreateDMatrix(kRows, kCols, 0); + auto& p_mat = *pp_dmat; + + std::vector<float> labels (kRows); + for (size_t i = 0; i < kRows; ++i) { + labels[i] = i % 2; + } + p_mat->Info().SetInfo("label", labels.data(), DataType::kFloat32, kRows); + + auto learner = std::unique_ptr<Learner>(Learner::Create({p_mat})); + learner->SetParam("booster", "dart"); + learner->SetParam("rate_drop", "0.5"); + learner->Configure(); + + for (size_t i = 0; i < 16; ++i) { + learner->UpdateOneIter(i, p_mat.get()); + } + + HostDeviceVector<float> predts_training; + learner->Predict(p_mat.get(), false, &predts_training, 0, true); + HostDeviceVector<float> predts_inference; + learner->Predict(p_mat.get(), false, &predts_inference, 0, false); + + auto& h_predts_training = predts_training.ConstHostVector(); + auto& h_predts_inference = predts_inference.ConstHostVector(); + ASSERT_EQ(h_predts_training.size(), h_predts_inference.size()); + for (size_t i = 0; i < predts_inference.Size(); ++i) { + // Inference doesn't drop trees.
+ ASSERT_GT(std::abs(h_predts_training[i] - h_predts_inference[i]), kRtEps); + } + + delete pp_dmat; +} } // namespace xgboost diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index dddf2756c..1260a8327 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -159,7 +159,6 @@ TEST(Learner, Json_ModelIO) { { std::unique_ptr learner { Learner::Create({p_dmat}) }; - learner->SetParam("verbosity", "3"); for (int32_t iter = 0; iter < kIters; ++iter) { learner->UpdateOneIter(iter, p_dmat.get()); } diff --git a/tests/cpp/test_logging.cc b/tests/cpp/test_logging.cc index c36134ab1..6bb4291f0 100644 --- a/tests/cpp/test_logging.cc +++ b/tests/cpp/test_logging.cc @@ -54,7 +54,7 @@ TEST(Logging, Basic) { ASSERT_NE(output.find("Test Log Console"), std::string::npos); args["silent"] = "False"; - args["verbosity"] = "1"; // restore + args["verbosity"] = "2"; // restore ConsoleLogger::Configure({args.cbegin(), args.cend()}); } diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index 267078f22..bb9799eb3 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -44,7 +44,8 @@ class TestModels(unittest.TestCase): def test_dart(self): dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train') dtest = xgb.DMatrix(dpath + 'agaricus.txt.test') - param = {'max_depth': 5, 'objective': 'binary:logistic', 'booster': 'dart', 'verbosity': 1} + param = {'max_depth': 5, 'objective': 'binary:logistic', + 'eval_metric': 'logloss', 'booster': 'dart', 'verbosity': 1} # specify validations set to watch performance watchlist = [(dtest, 'eval'), (dtrain, 'train')] num_round = 2 @@ -52,7 +53,8 @@ class TestModels(unittest.TestCase): # this is prediction preds = bst.predict(dtest, ntree_limit=num_round) labels = dtest.get_label() - err = sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds)) + err = sum(1 for i in range(len(preds)) + if int(preds[i] > 0.5) != labels[i]) / float(len(preds)) # error must be smaller than 10% assert err < 0.1 @@ -68,18 +70,31 @@ class TestModels(unittest.TestCase): # assert they are the same assert np.sum(np.abs(preds2 - preds)) == 0 + def my_logloss(preds, dtrain): + labels = dtrain.get_label() + return 'logloss', np.sum( + np.log(np.where(labels, preds, 1 - preds))) + + # check whether custom evaluation metrics work + bst = xgb.train(param, dtrain, num_round, watchlist, + feval=my_logloss) + preds3 = bst.predict(dtest, ntree_limit=num_round) + assert all(preds3 == preds) + # check whether sample_type and normalize_type work num_round = 50 param['verbosity'] = 0 param['learning_rate'] = 0.1 param['rate_drop'] = 0.1 preds_list = [] - for p in [[p0, p1] for p0 in ['uniform', 'weighted'] for p1 in ['tree', 'forest']]: + for p in [[p0, p1] for p0 in ['uniform', 'weighted'] + for p1 in ['tree', 'forest']]: param['sample_type'] = p[0] param['normalize_type'] = p[1] bst = xgb.train(param, dtrain, num_round, watchlist) preds = bst.predict(dtest, ntree_limit=num_round) - err = sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds)) + err = sum(1 for i in range(len(preds)) + if int(preds[i] > 0.5) != labels[i]) / float(len(preds)) assert err < 0.1 preds_list.append(preds) diff --git a/tests/python/test_ranking.py b/tests/python/test_ranking.py index 51e5e18a9..23a518073 100644 --- a/tests/python/test_ranking.py +++ b/tests/python/test_ranking.py @@ -135,7 +135,7 @@ class TestRanking(unittest.TestCase): # specify validations set to watch performance 
watchlist = [(self.dtest, 'eval'), (self.dtrain, 'train')] bst = xgboost.train(self.params, self.dtrain, num_boost_round=2500, - early_stopping_rounds=10, evals=watchlist) + early_stopping_rounds=10, evals=watchlist) assert bst.best_score > 0.98 def test_cv(self):
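
A minimal usage sketch (not part of the patch) of the behaviour these changes introduce for the Python package: with a ``dart`` booster, ``predict()`` no longer performs dropout unless ``training=True`` is passed, so the old ``ntree_limit`` workaround from ``doc/tutorials/dart.rst`` is no longer needed. The data and parameter values below are assumptions for illustration only.

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    # Toy binary-classification data (assumption; any DMatrix works here).
    X = np.random.randn(256, 8)
    y = (X[:, 0] > 0).astype(float)
    dtrain = xgb.DMatrix(X, label=y)

    params = {'booster': 'dart', 'objective': 'binary:logistic',
              'max_depth': 3, 'rate_drop': 0.5}
    bst = xgb.train(params, dtrain, num_boost_round=20)

    # Inference path: no dropout is applied, so repeated calls are
    # deterministic and ntree_limit is no longer required.
    p1 = bst.predict(dtrain)
    p2 = bst.predict(dtrain)
    assert np.allclose(p1, p2)

    # Training path: training=True re-enables dropout, which is what
    # update() uses internally when a custom objective is supplied.
    p_drop = bst.predict(dtrain, training=True)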