[Breaking] Don't drop trees during DART prediction by default (#5115)
* Simplify DropTrees calling logic
* Add `training` parameter for prediction method.
* [Breaking]: Add `training` to C API.
* Change for R and Python custom objective.
* Correct comment.

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
Parent: 7b65698187
Commit: f100b8d878
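For context, the user-visible change in a minimal sketch (hedged: the data paths are placeholders for the agaricus demo files shipped with xgboost):

    import xgboost as xgb

    dtrain = xgb.DMatrix('agaricus.txt.train')
    dtest = xgb.DMatrix('agaricus.txt.test')
    param = {'booster': 'dart', 'objective': 'binary:logistic',
             'rate_drop': 0.5}
    bst = xgb.train(param, dtrain, num_boost_round=50)

    # After this commit: prediction evaluates all trees by default, so
    # test-set results are correct without any ntree_limit workaround.
    preds = bst.predict(dtest)

    # The new flag opts back into dropout, as done internally during training.
    preds_dropout = bst.predict(dtest, training=True)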
@@ -145,7 +145,7 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj = NULL) {
   if (is.null(obj)) {
     .Call(XGBoosterUpdateOneIter_R, booster_handle, as.integer(iter), dtrain)
   } else {
-    pred <- predict(booster_handle, dtrain)
+    pred <- predict(booster_handle, dtrain, training = TRUE)
     gpair <- obj(pred, dtrain)
     .Call(XGBoosterBoostOneIter_R, booster_handle, dtrain, gpair$grad, gpair$hess)
   }
@@ -288,7 +288,7 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
 #' @export
 predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE, ntreelimit = NULL,
                                 predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE,
-                                reshape = FALSE, ...) {
+                                reshape = FALSE, training = FALSE, ...) {

   object <- xgb.Booster.complete(object, saveraw = FALSE)
   if (!inherits(newdata, "xgb.DMatrix"))
@@ -307,7 +307,8 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
   option <- 0L + 1L * as.logical(outputmargin) + 2L * as.logical(predleaf) + 4L * as.logical(predcontrib) +
     8L * as.logical(approxcontrib) + 16L * as.logical(predinteraction)

-  ret <- .Call(XGBoosterPredict_R, object$handle, newdata, option[1], as.integer(ntreelimit))
+  ret <- .Call(XGBoosterPredict_R, object$handle, newdata, option[1],
+               as.integer(ntreelimit), as.integer(training))

   n_ret <- length(ret)
   n_row <- nrow(newdata)
@@ -24,7 +24,7 @@ extern SEXP XGBoosterGetAttr_R(SEXP, SEXP);
 extern SEXP XGBoosterLoadModelFromRaw_R(SEXP, SEXP);
 extern SEXP XGBoosterLoadModel_R(SEXP, SEXP);
 extern SEXP XGBoosterModelToRaw_R(SEXP);
-extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP);
+extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP, SEXP);
 extern SEXP XGBoosterSaveModel_R(SEXP, SEXP);
 extern SEXP XGBoosterSetAttr_R(SEXP, SEXP, SEXP);
 extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP);
@@ -50,7 +50,7 @@ static const R_CallMethodDef CallEntries[] = {
   {"XGBoosterLoadModelFromRaw_R", (DL_FUNC) &XGBoosterLoadModelFromRaw_R, 2},
   {"XGBoosterLoadModel_R", (DL_FUNC) &XGBoosterLoadModel_R, 2},
   {"XGBoosterModelToRaw_R", (DL_FUNC) &XGBoosterModelToRaw_R, 1},
-  {"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 4},
+  {"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 5},
   {"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2},
   {"XGBoosterSetAttr_R", (DL_FUNC) &XGBoosterSetAttr_R, 3},
   {"XGBoosterSetParam_R", (DL_FUNC) &XGBoosterSetParam_R, 3},
@@ -295,24 +295,26 @@ SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) {
     vec_sptr.push_back(vec_names[i].c_str());
   }
   CHECK_CALL(XGBoosterEvalOneIter(R_ExternalPtrAddr(handle),
-                                  asInteger(iter),
-                                  BeginPtr(vec_dmats),
-                                  BeginPtr(vec_sptr),
-                                  len, &ret));
+                                 asInteger(iter),
+                                 BeginPtr(vec_dmats),
+                                 BeginPtr(vec_sptr),
+                                 len, &ret));
   R_API_END();
   return mkString(ret);
 }

-SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask, SEXP ntree_limit) {
+SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask,
+                        SEXP ntree_limit, SEXP training) {
   SEXP ret;
   R_API_BEGIN();
   bst_ulong olen;
   const float *res;
   CHECK_CALL(XGBoosterPredict(R_ExternalPtrAddr(handle),
-                              R_ExternalPtrAddr(dmat),
-                              asInteger(option_mask),
-                              asInteger(ntree_limit),
-                              &olen, &res));
+                              R_ExternalPtrAddr(dmat),
+                              asInteger(option_mask),
+                              asInteger(ntree_limit),
+                              0,
+                              &olen, &res));
   ret = PROTECT(allocVector(REALSXP, olen));
   for (size_t i = 0; i < olen; ++i) {
     REAL(ret)[i] = res[i];
@@ -148,8 +148,10 @@ XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evn
  * \param dmat data matrix
  * \param option_mask output_margin:1 predict_leaf:2
  * \param ntree_limit limit number of trees used in prediction
+ * \param training Whether the prediction value is used for training.
  */
-XGB_DLL SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask, SEXP ntree_limit);
+XGB_DLL SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask,
+                                SEXP ntree_limit, SEXP training);
 /*!
  * \brief load model from existing file
  * \param handle handle
@@ -166,7 +166,7 @@ test_that("SHAPs sum to predictions, with or without DART", {
                nrounds = nrounds)

   pr <- function(...)
-    predict(fit, newdata = d, ntreelimit = nrounds, ...)
+    predict(fit, newdata = d, ...)
   pred <- pr()
   shap <- pr(predcontrib = T)
   shapi <- pr(predinteraction = T)
@@ -108,12 +108,4 @@ Sample Script
            'skip_drop': 0.5}
   num_round = 50
   bst = xgb.train(param, dtrain, num_round)
-  # make prediction
-  # ntree_limit must not be 0
-  preds = bst.predict(dtest, ntree_limit=num_round)
-
-.. note:: Specify ``ntree_limit`` when predicting with test sets
-
-  By default, ``bst.predict()`` will perform dropouts on trees. To obtain
-  correct results on test sets, disable dropouts by specifying
-  a nonzero value for ``ntree_limit``.
+  preds = bst.predict(dtest)
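The tutorial change above amounts to the following before/after, shown here as a sketch using the variables from the sample script:

    # Before: dropout ran at prediction time unless ntree_limit was nonzero,
    # so the tutorial had to force a nonzero limit for test data.
    preds = bst.predict(dtest, ntree_limit=num_round)

    # After: no dropout at prediction time; the plain call is correct.
    preds = bst.predict(dtest)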
@@ -416,6 +416,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
  *    4:output feature contributions to individual predictions
  * \param ntree_limit limit number of trees used for prediction, this is only valid for boosted trees
  *    when the parameter is set to 0, we will use all the trees
+ * \param training Whether the prediction value is used for training.
  * \param out_len used to store length of returning result
  * \param out_result used to set a pointer to array
  * \return 0 when success, -1 when failure happens
@@ -424,6 +425,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
                              DMatrixHandle dmat,
                              int option_mask,
                              unsigned ntree_limit,
+                             int training,
                              bst_ulong *out_len,
                              const float **out_result);
 /*
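The extra `training` argument sits between `ntree_limit` and `out_len`, so every caller of the C API must pass it explicitly. A minimal ctypes sketch of the new call shape (assuming `_LIB` is the loaded xgboost shared library and `bst`, `dmat` are existing Booster/DMatrix objects, as in the Python wrapper changes below):

    import ctypes

    length = ctypes.c_uint64(0)                 # bst_ulong out_len
    preds = ctypes.POINTER(ctypes.c_float)()    # const float *out_result
    _LIB.XGBoosterPredict(bst.handle, dmat.handle,
                          ctypes.c_int(0),      # option_mask: plain prediction
                          ctypes.c_uint(0),     # ntree_limit: all trees
                          ctypes.c_int(0),      # training: 0, no dropout
                          ctypes.byref(length),
                          ctypes.byref(preds))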
@@ -84,6 +84,7 @@ class GradientBooster : public Model, public Configurable {
   */
  virtual void PredictBatch(DMatrix* dmat,
                            HostDeviceVector<bst_float>* out_preds,
+                           bool training,
                            unsigned ntree_limit = 0) = 0;
  /*!
   * \brief online prediction function, predict score for one instance at a time
@@ -96,6 +96,7 @@ class Learner : public Model, public Configurable, public rabit::Serializable {
                        bool output_margin,
                        HostDeviceVector<bst_float> *out_preds,
                        unsigned ntree_limit = 0,
+                       bool training = false,
                        bool pred_leaf = false,
                        bool pred_contribs = false,
                        bool approx_contribs = false,
@@ -556,7 +556,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
   DMatrixHandle dmat = (DMatrixHandle) jdmat;
   bst_ulong len;
   float *result;
-  int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, &len, (const float **) &result);
+  int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, 0,
+                             &len, (const float **) &result);
   if (len) {
     jsize jlen = (jsize) len;
     jfloatArray jarray = jenv->NewFloatArray(jlen);
@@ -196,7 +196,8 @@ def ctypes2numpy(cptr, length, dtype):
         np.uint32: ctypes.c_uint,
     }
     if dtype not in NUMPY_TO_CTYPES_MAPPING:
-        raise RuntimeError('Supported types: {}'.format(NUMPY_TO_CTYPES_MAPPING.keys()))
+        raise RuntimeError('Supported types: {}'.format(
+            NUMPY_TO_CTYPES_MAPPING.keys()))
     ctype = NUMPY_TO_CTYPES_MAPPING[dtype]
     if not isinstance(cptr, ctypes.POINTER(ctype)):
         raise RuntimeError('expected {} pointer'.format(ctype))
@@ -1275,14 +1276,16 @@ class Booster(object):

         """
         if not isinstance(dtrain, DMatrix):
-            raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
+            raise TypeError('invalid training matrix: {}'.format(
+                type(dtrain).__name__))
         self._validate_features(dtrain)

         if fobj is None:
-            _check_call(_LIB.XGBoosterUpdateOneIter(self.handle, ctypes.c_int(iteration),
+            _check_call(_LIB.XGBoosterUpdateOneIter(self.handle,
+                                                    ctypes.c_int(iteration),
                                                     dtrain.handle))
         else:
-            pred = self.predict(dtrain)
+            pred = self.predict(dtrain, training=True)
             grad, hess = fobj(pred, dtrain)
             self.boost(dtrain, grad, hess)
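Why `update()` now passes `training=True`: a custom objective computes gradients against the booster's current predictions, and for DART those must be the dropout-adjusted values that the built-in objectives see. A sketch with an illustrative logistic objective (`dtrain` is a labeled DMatrix as in earlier examples; the objective itself is not part of this change):

    import numpy as np
    import xgboost as xgb

    def logistic_obj(preds, dtrain):
        # Under this commit, preds come from predict(dtrain, training=True),
        # i.e. with DART dropout applied during training iterations.
        labels = dtrain.get_label()
        probs = 1.0 / (1.0 + np.exp(-preds))
        return probs - labels, probs * (1.0 - probs)

    bst = xgb.train({'booster': 'dart', 'rate_drop': 0.5}, dtrain,
                    num_boost_round=10, obj=logistic_obj)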
@@ -1332,22 +1335,25 @@ class Booster(object):
         """
         for d in evals:
             if not isinstance(d[0], DMatrix):
-                raise TypeError('expected DMatrix, got {}'.format(type(d[0]).__name__))
+                raise TypeError('expected DMatrix, got {}'.format(
+                    type(d[0]).__name__))
             if not isinstance(d[1], STRING_TYPES):
-                raise TypeError('expected string, got {}'.format(type(d[1]).__name__))
+                raise TypeError('expected string, got {}'.format(
+                    type(d[1]).__name__))
             self._validate_features(d[0])

         dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
         evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
         msg = ctypes.c_char_p()
-        _check_call(_LIB.XGBoosterEvalOneIter(self.handle, ctypes.c_int(iteration),
+        _check_call(_LIB.XGBoosterEvalOneIter(self.handle,
+                                              ctypes.c_int(iteration),
                                               dmats, evnames,
                                               c_bst_ulong(len(evals)),
                                               ctypes.byref(msg)))
         res = msg.value.decode()
         if feval is not None:
             for dmat, evname in evals:
-                feval_ret = feval(self.predict(dmat), dmat)
+                feval_ret = feval(self.predict(dmat, training=False), dmat)
                 if isinstance(feval_ret, list):
                     for name, val in feval_ret:
                         res += '\t%s-%s:%f' % (evname, name, val)
@@ -1378,27 +1384,24 @@ class Booster(object):
         self._validate_features(data)
         return self.eval_set([(data, name)], iteration)

-    def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False,
-                pred_contribs=False, approx_contribs=False, pred_interactions=False,
-                validate_features=True):
+    def predict(self,
+                data,
+                output_margin=False,
+                ntree_limit=0,
+                pred_leaf=False,
+                pred_contribs=False,
+                approx_contribs=False,
+                pred_interactions=False,
+                validate_features=True,
+                training=False):
         """Predict with data.

         .. note:: This function is not thread safe.

           For each booster object, predict can only be called from one thread.
-          If you want to run prediction using multiple thread, call ``bst.copy()`` to make copies
-          of model object and then call ``predict()``.
-
-        .. note:: Using ``predict()`` with DART booster
-
-          If the booster object is DART type, ``predict()`` will perform dropouts, i.e. only
-          some of the trees will be evaluated. This will produce incorrect results if ``data`` is
-          not the training data. To obtain correct results on test sets, set ``ntree_limit`` to
-          a nonzero value, e.g.
-
-          .. code-block:: python
-
-            preds = bst.predict(dtest, ntree_limit=num_round)
+          If you want to run prediction using multiple thread, call
+          ``bst.copy()`` to make copies of model object and then call
+          ``predict()``.

         Parameters
         ----------
@@ -1409,38 +1412,53 @@ class Booster(object):
             Whether to output the raw untransformed margin value.

         ntree_limit : int
-            Limit number of trees in the prediction; defaults to 0 (use all trees).
+            Limit number of trees in the prediction; defaults to 0 (use all
+            trees).

         pred_leaf : bool
-            When this option is on, the output will be a matrix of (nsample, ntrees)
-            with each record indicating the predicted leaf index of each sample in each tree.
-            Note that the leaf index of a tree is unique per tree, so you may find leaf 1
-            in both tree 1 and tree 0.
+            When this option is on, the output will be a matrix of (nsample,
+            ntrees) with each record indicating the predicted leaf index of
+            each sample in each tree.  Note that the leaf index of a tree is
+            unique per tree, so you may find leaf 1 in both tree 1 and tree 0.

         pred_contribs : bool
-            When this is True the output will be a matrix of size (nsample, nfeats + 1)
-            with each record indicating the feature contributions (SHAP values) for that
-            prediction. The sum of all feature contributions is equal to the raw untransformed
-            margin value of the prediction. Note the final column is the bias term.
+            When this is True the output will be a matrix of size (nsample,
+            nfeats + 1) with each record indicating the feature contributions
+            (SHAP values) for that prediction. The sum of all feature
+            contributions is equal to the raw untransformed margin value of the
+            prediction. Note the final column is the bias term.

         approx_contribs : bool
             Approximate the contributions of each feature

         pred_interactions : bool
-            When this is True the output will be a matrix of size (nsample, nfeats + 1, nfeats + 1)
-            indicating the SHAP interaction values for each pair of features. The sum of each
-            row (or column) of the interaction values equals the corresponding SHAP value (from
-            pred_contribs), and the sum of the entire matrix equals the raw untransformed margin
-            value of the prediction. Note the last row and column correspond to the bias term.
+            When this is True the output will be a matrix of size (nsample,
+            nfeats + 1, nfeats + 1) indicating the SHAP interaction values for
+            each pair of features. The sum of each row (or column) of the
+            interaction values equals the corresponding SHAP value (from
+            pred_contribs), and the sum of the entire matrix equals the raw
+            untransformed margin value of the prediction. Note the last row and
+            column correspond to the bias term.

         validate_features : bool
             When this is True, validate that the Booster's and data's
             feature_names are identical.  Otherwise, it is assumed that the
             feature_names are the same.

+        training : bool
+            Whether the prediction value is used for training.  This can affect
+            `dart` booster, which performs dropouts during training iterations.
+
+            .. note:: Using ``predict()`` with DART booster
+
+              If the booster object is DART type, ``predict()`` will not perform
+              dropouts, i.e. all the trees will be evaluated. If you want to
+              obtain result with dropouts, provide `training=True`.
+
         Returns
         -------
         prediction : numpy array

         """
         option_mask = 0x00
         if output_margin:
@@ -1466,6 +1484,7 @@ class Booster(object):
         _check_call(_LIB.XGBoosterPredict(self.handle, data.handle,
                                           ctypes.c_int(option_mask),
                                           ctypes.c_uint(ntree_limit),
+                                          ctypes.c_int(training),
                                           ctypes.byref(length),
                                           ctypes.byref(preds)))
         preds = ctypes2numpy(preds, length.value, np.float32)
@@ -1476,11 +1495,16 @@ class Booster(object):
         chunk_size = int(preds.size / nrow)

         if pred_interactions:
-            ngroup = int(chunk_size / ((data.num_col() + 1) * (data.num_col() + 1)))
+            ngroup = int(chunk_size / ((data.num_col() + 1) *
+                                       (data.num_col() + 1)))
             if ngroup == 1:
-                preds = preds.reshape(nrow, data.num_col() + 1, data.num_col() + 1)
+                preds = preds.reshape(nrow,
+                                      data.num_col() + 1,
+                                      data.num_col() + 1)
             else:
-                preds = preds.reshape(nrow, ngroup, data.num_col() + 1, data.num_col() + 1)
+                preds = preds.reshape(nrow, ngroup,
+                                      data.num_col() + 1,
+                                      data.num_col() + 1)
         elif pred_contribs:
             ngroup = int(chunk_size / (data.num_col() + 1))
             if ngroup == 1:
@@ -570,10 +570,11 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
                              DMatrixHandle dmat,
                              int option_mask,
                              unsigned ntree_limit,
+                             int32_t training,
                              xgboost::bst_ulong *len,
                              const bst_float **out_result) {
-  std::vector<bst_float>&preds =
-    XGBAPIThreadLocalStore::Get()->ret_vec_float;
+  std::vector<bst_float>& preds =
+      XGBAPIThreadLocalStore::Get()->ret_vec_float;
   API_BEGIN();
   CHECK_HANDLE();
   auto *bst = static_cast<Learner*>(handle);
@@ -582,6 +583,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
                     static_cast<std::shared_ptr<DMatrix>*>(dmat)->get(),
                     (option_mask & 1) != 0,
                     &tmp_preds, ntree_limit,
+                    static_cast<bool>(training),
                     (option_mask & 2) != 0,
                     (option_mask & 4) != 0,
                     (option_mask & 8) != 0,
@@ -127,6 +127,7 @@ class GBLinear : public GradientBooster {

   void PredictBatch(DMatrix *p_fmat,
                     HostDeviceVector<bst_float> *out_preds,
+                    bool training,
                     unsigned ntree_limit) override {
     monitor_.Start("PredictBatch");
     CHECK_EQ(ntree_limit, 0U)
@@ -414,17 +414,36 @@ class Dart : public GBTree {
     out["dart_train_param"] = toJson(dparam_);
   }

-  // predict the leaf scores with dropout if ntree_limit = 0
   void PredictBatch(DMatrix* p_fmat,
-                    HostDeviceVector<bst_float>* out_preds,
+                    HostDeviceVector<bst_float>* p_out_preds,
+                    bool training,
                     unsigned ntree_limit) override {
-    DropTrees(ntree_limit);
-    PredLoopInternal<Dart>(p_fmat, &out_preds->HostVector(), 0, ntree_limit, true);
+    DropTrees(training);
+    int num_group = model_.learner_model_param_->num_output_group;
+    ntree_limit *= num_group;
+    if (ntree_limit == 0 || ntree_limit > model_.trees.size()) {
+      ntree_limit = static_cast<unsigned>(model_.trees.size());
+    }
+    size_t n = num_group * p_fmat->Info().num_row_;
+    const auto &base_margin = p_fmat->Info().base_margin_.ConstHostVector();
+    auto& out_preds = p_out_preds->HostVector();
+    out_preds.resize(n);
+    if (base_margin.size() != 0) {
+      CHECK_EQ(out_preds.size(), n);
+      std::copy(base_margin.begin(), base_margin.end(), out_preds.begin());
+    } else {
+      std::fill(out_preds.begin(), out_preds.end(),
+                model_.learner_model_param_->base_score);
+    }
+
+    PredLoopSpecalize(p_fmat, &out_preds, num_group, 0,
+                      ntree_limit, training);
   }

   void PredictInstance(const SparsePage::Inst &inst,
-                       std::vector<bst_float> *out_preds, unsigned ntree_limit) override {
-    DropTrees(1);
+                       std::vector<bst_float> *out_preds,
+                       unsigned ntree_limit) override {
+    DropTrees(false);
     if (thread_temp_.size() == 0) {
       thread_temp_.resize(1, RegTree::FVec());
       thread_temp_[0].Init(model_.learner_model_param_->num_feature);
@@ -465,46 +484,13 @@ class Dart : public GBTree {

 protected:
   friend class GBTree;
-  // internal prediction loop
-  // add predictions to out_preds
-  template<typename Derived>
-  inline void PredLoopInternal(
-      DMatrix* p_fmat,
-      std::vector<bst_float>* out_preds,
-      unsigned tree_begin,
-      unsigned ntree_limit,
-      bool init_out_preds) {
-    int num_group = model_.learner_model_param_->num_output_group;
-    ntree_limit *= num_group;
-    if (ntree_limit == 0 || ntree_limit > model_.trees.size()) {
-      ntree_limit = static_cast<unsigned>(model_.trees.size());
-    }
-
-    if (init_out_preds) {
-      size_t n = num_group * p_fmat->Info().num_row_;
-      const auto& base_margin =
-          p_fmat->Info().base_margin_.ConstHostVector();
-      out_preds->resize(n);
-      if (base_margin.size() != 0) {
-        CHECK_EQ(out_preds->size(), n);
-        std::copy(base_margin.begin(), base_margin.end(), out_preds->begin());
-      } else {
-        std::fill(out_preds->begin(), out_preds->end(),
-                  model_.learner_model_param_->base_score);
-      }
-    }
-    PredLoopSpecalize<Derived>(p_fmat, out_preds, num_group, tree_begin,
-                               ntree_limit);
-  }
-
-  template<typename Derived>
   inline void PredLoopSpecalize(
       DMatrix* p_fmat,
       std::vector<bst_float>* out_preds,
       int num_group,
       unsigned tree_begin,
-      unsigned tree_end) {
+      unsigned tree_end,
+      bool training) {
     const int nthread = omp_get_max_threads();
     CHECK_EQ(num_group, model_.learner_model_param_->num_output_group);
     InitThreadTemp(nthread);
@@ -513,13 +499,12 @@ class Dart : public GBTree {
         << "size_leaf_vector is enforced to 0 so far";
     CHECK_EQ(preds.size(), p_fmat->Info().num_row_ * num_group);
     // start collecting the prediction
-    auto* self = static_cast<Derived*>(this);
     for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
       constexpr int kUnroll = 8;
       const auto nsize = static_cast<bst_omp_uint>(batch.Size());
       const bst_omp_uint rest = nsize % kUnroll;
       if (nsize >= kUnroll) {
-#pragma omp parallel for schedule(static)
+        #pragma omp parallel for schedule(static)
         for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
           const int tid = omp_get_thread_num();
           RegTree::FVec& feats = thread_temp_[tid];
@@ -535,7 +520,7 @@ class Dart : public GBTree {
           for (int gid = 0; gid < num_group; ++gid) {
             const size_t offset = ridx[k] * num_group + gid;
             preds[offset] +=
-                self->PredValue(inst[k], gid, &feats, tree_begin, tree_end);
+                this->PredValue(inst[k], gid, &feats, tree_begin, tree_end);
           }
         }
       }
@@ -548,7 +533,7 @@ class Dart : public GBTree {
         for (int gid = 0; gid < num_group; ++gid) {
           const size_t offset = ridx * num_group + gid;
           preds[offset] +=
-              self->PredValue(inst, gid,
+              this->PredValue(inst, gid,
                               &feats, tree_begin, tree_end);
         }
       }
@@ -569,11 +554,9 @@ class Dart : public GBTree {
   }

   // predict the leaf scores without dropped trees
-  inline bst_float PredValue(const SparsePage::Inst &inst,
-                             int bst_group,
-                             RegTree::FVec *p_feats,
-                             unsigned tree_begin,
-                             unsigned tree_end) {
+  bst_float PredValue(const SparsePage::Inst &inst, int bst_group,
+                      RegTree::FVec *p_feats, unsigned tree_begin,
+                      unsigned tree_end) const {
     bst_float psum = 0.0f;
     p_feats->Fill(inst);
     for (size_t i = tree_begin; i < tree_end; ++i) {
@@ -590,9 +573,12 @@ class Dart : public GBTree {
   }

   // select which trees to drop
-  inline void DropTrees(unsigned ntree_limit_drop) {
+  // passing clear=True will clear selection
+  inline void DropTrees(bool is_training) {
     idx_drop_.clear();
-    if (ntree_limit_drop > 0) return;
+    if (!is_training) {
+      return;
+    }

     std::uniform_real_distribution<> runif(0.0, 1.0);
     auto& rnd = common::GlobalRandom();
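The reworked `DropTrees` gates dropout on the training flag alone. In Python-flavoured pseudocode the selection logic reads roughly as follows (a sketch of the uniform `sample_type` path only; `rate_drop` stands in for the corresponding `dparam_` field):

    import random

    def drop_trees(is_training, n_trees, rate_drop):
        """Return indices of trees to skip for this prediction pass."""
        idx_drop = []
        if not is_training:
            # Inference: never drop trees; the full ensemble is evaluated.
            return idx_drop
        for i in range(n_trees):
            # Training: drop each tree independently with probability rate_drop.
            if random.random() < rate_drop:
                idx_drop.append(i)
        return idx_drop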
@@ -205,6 +205,7 @@ class GBTree : public GradientBooster {

   void PredictBatch(DMatrix* p_fmat,
                     HostDeviceVector<bst_float>* out_preds,
+                    bool training,
                     unsigned ntree_limit) override {
     CHECK(configured_);
     GetPredictor(out_preds, p_fmat)->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);
@@ -694,7 +694,7 @@ class LearnerImpl : public Learner {
     this->ValidateDMatrix(train);

     monitor_.Start("PredictRaw");
-    this->PredictRaw(train, &preds_[train]);
+    this->PredictRaw(train, &preds_[train], true);
     monitor_.Stop("PredictRaw");
     TrainingObserver::Instance().Observe(preds_[train], "Predictions");
@@ -735,7 +735,7 @@ class LearnerImpl : public Learner {
     for (size_t i = 0; i < data_sets.size(); ++i) {
       DMatrix * dmat = data_sets[i];
       this->ValidateDMatrix(dmat);
-      this->PredictRaw(data_sets[i], &preds_[dmat]);
+      this->PredictRaw(data_sets[i], &preds_[dmat], false);
       obj_->EvalTransform(&preds_[dmat]);
       for (auto& ev : metrics_) {
         os << '\t' << data_names[i] << '-' << ev->Name() << ':'
@@ -799,6 +799,7 @@ class LearnerImpl : public Learner {

   void Predict(DMatrix* data, bool output_margin,
                HostDeviceVector<bst_float>* out_preds, unsigned ntree_limit,
+               bool training,
                bool pred_leaf, bool pred_contribs, bool approx_contribs,
                bool pred_interactions) override {
     int multiple_predictions = static_cast<int>(pred_leaf) +
@@ -814,7 +815,7 @@ class LearnerImpl : public Learner {
     } else if (pred_leaf) {
       gbm_->PredictLeaf(data, &out_preds->HostVector(), ntree_limit);
     } else {
-      this->PredictRaw(data, out_preds, ntree_limit);
+      this->PredictRaw(data, out_preds, training, ntree_limit);
       if (!output_margin) {
         obj_->PredTransform(out_preds);
       }
@@ -832,13 +833,15 @@ class LearnerImpl : public Learner {
   * \param out_preds output vector that stores the prediction
   * \param ntree_limit limit number of trees used for boosted tree
   *   predictor, when it equals 0, this means we are using all the trees
+  * \param training allow dropout when the DART booster is being used
   */
  void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
+                 bool training,
                  unsigned ntree_limit = 0) const {
    CHECK(gbm_ != nullptr)
        << "Predict must happen after Load or configuration";
    this->ValidateDMatrix(data);
-    gbm_->PredictBatch(data, out_preds, ntree_limit);
+    gbm_->PredictBatch(data, out_preds, training, ntree_limit);
  }

  void ConfigureObjective(LearnerTrainParam const& old, Args* p_args) {
@@ -18,7 +18,7 @@ DMLC_REGISTRY_FILE_TAG(cpu_predictor);

 class CPUPredictor : public Predictor {
  protected:
-  static bst_float PredValue(const SparsePage::Inst& inst,
+  static bst_float PredValue(const SparsePage::Inst& inst,
                              const std::vector<std::unique_ptr<RegTree>>& trees,
                              const std::vector<int>& tree_info, int bst_group,
                              RegTree::FVec* p_feats,
@@ -175,13 +175,15 @@ class CPUPredictor : public Predictor {
     this->PredLoopInternal(dmat, &out_preds->HostVector(), model,
                            tree_begin, ntree_limit);

-    auto cache_emtry = this->FindCache(dmat);
-    if (cache_emtry == cache_->cend()) { return; }
-    if (cache_emtry->second.predictions.Size() == 0) {
+    auto cache_entry = this->FindCache(dmat);
+    if (cache_entry == cache_->cend()) {
+      return;
+    }
+    if (cache_entry->second.predictions.Size() == 0) {
       // See comment in GPUPredictor::PredictBatch.
-      InitOutPredictions(cache_emtry->second.data->Info(),
-                         &(cache_emtry->second.predictions), model);
-      cache_emtry->second.predictions.Copy(*out_preds);
+      InitOutPredictions(cache_entry->second.data->Info(),
+                         &(cache_entry->second.predictions), model);
+      cache_entry->second.predictions.Copy(*out_preds);
     }
   }
@@ -2,6 +2,8 @@
 #include <dmlc/filesystem.h>
 #include <xgboost/generic_parameters.h>

 #include "xgboost/base.h"
+#include "xgboost/host_device_vector.h"
+#include "xgboost/learner.h"
 #include "../helpers.h"
 #include "../../../src/gbm/gbtree.h"
@@ -18,7 +20,7 @@ TEST(GBTree, SelectTreeMethod) {
   mparam.num_output_group = 1;

   std::vector<std::shared_ptr<DMatrix> > caches;
-  std::unique_ptr<GradientBooster> p_gbm{
+  std::unique_ptr<GradientBooster> p_gbm {
     GradientBooster::Create("gbtree", &generic_param, &mparam, caches)};
   auto& gbtree = dynamic_cast<gbm::GBTree&> (*p_gbm);
@@ -175,4 +177,41 @@ TEST(Dart, Json_IO) {
   ASSERT_TRUE(IsA<Object>(model["model"]["gbtree"]));
   ASSERT_NE(get<Array>(model["model"]["weight_drop"]).size(), 0);
 }
+
+TEST(Dart, Prediction) {
+  size_t constexpr kRows = 16, kCols = 10;
+
+  auto pp_dmat = CreateDMatrix(kRows, kCols, 0);
+  auto& p_mat = *pp_dmat;
+
+  std::vector<bst_float> labels (kRows);
+  for (size_t i = 0; i < kRows; ++i) {
+    labels[i] = i % 2;
+  }
+  p_mat->Info().SetInfo("label", labels.data(), DataType::kFloat32, kRows);
+
+  auto learner = std::unique_ptr<Learner>(Learner::Create({p_mat}));
+  learner->SetParam("booster", "dart");
+  learner->SetParam("rate_drop", "0.5");
+  learner->Configure();
+
+  for (size_t i = 0; i < 16; ++i) {
+    learner->UpdateOneIter(i, p_mat.get());
+  }
+
+  HostDeviceVector<float> predts_training;
+  learner->Predict(p_mat.get(), false, &predts_training, 0, true);
+  HostDeviceVector<float> predts_inference;
+  learner->Predict(p_mat.get(), false, &predts_inference, 0, false);
+
+  auto& h_predts_training = predts_training.ConstHostVector();
+  auto& h_predts_inference = predts_inference.ConstHostVector();
+  ASSERT_EQ(h_predts_training.size(), h_predts_inference.size());
+  for (size_t i = 0; i < predts_inference.Size(); ++i) {
+    // Inference doesn't drop tree.
+    ASSERT_GT(std::abs(h_predts_training[i] - h_predts_inference[i]), kRtEps);
+  }
+
+  delete pp_dmat;
+}
 }  // namespace xgboost
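The same invariant is easy to check from Python; a hedged pytest-style analogue of the C++ test above (random data, not part of the commit):

    import numpy as np
    import xgboost as xgb

    def test_dart_prediction():
        rows, cols = 16, 10
        X = np.random.randn(rows, cols)
        y = np.arange(rows) % 2
        dtrain = xgb.DMatrix(X, label=y)
        bst = xgb.train({'booster': 'dart', 'rate_drop': 0.5}, dtrain,
                        num_boost_round=16)
        inference = bst.predict(dtrain)                # all trees evaluated
        training = bst.predict(dtrain, training=True)  # dropout applied
        # Mirrors the C++ assertion: with rate_drop=0.5 the two passes differ.
        assert np.all(np.abs(training - inference) > 1e-6)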
@@ -159,7 +159,6 @@ TEST(Learner, Json_ModelIO) {

   {
     std::unique_ptr<Learner> learner { Learner::Create({p_dmat}) };
-    learner->SetParam("verbosity", "3");
     for (int32_t iter = 0; iter < kIters; ++iter) {
       learner->UpdateOneIter(iter, p_dmat.get());
     }
@@ -54,7 +54,7 @@ TEST(Logging, Basic) {
   ASSERT_NE(output.find("Test Log Console"), std::string::npos);

   args["silent"] = "False";
-  args["verbosity"] = "1";  // restore
+  args["verbosity"] = "2";  // restore
   ConsoleLogger::Configure({args.cbegin(), args.cend()});
 }
@@ -44,7 +44,8 @@ class TestModels(unittest.TestCase):
     def test_dart(self):
         dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
         dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
-        param = {'max_depth': 5, 'objective': 'binary:logistic', 'booster': 'dart', 'verbosity': 1}
+        param = {'max_depth': 5, 'objective': 'binary:logistic',
+                 'eval_metric': 'logloss', 'booster': 'dart', 'verbosity': 1}
         # specify validations set to watch performance
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]
         num_round = 2
@@ -52,7 +53,8 @@ class TestModels(unittest.TestCase):
         # this is prediction
         preds = bst.predict(dtest, ntree_limit=num_round)
         labels = dtest.get_label()
-        err = sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
+        err = sum(1 for i in range(len(preds))
+                  if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
         # error must be smaller than 10%
         assert err < 0.1
@@ -68,18 +70,31 @@ class TestModels(unittest.TestCase):
         # assert they are the same
         assert np.sum(np.abs(preds2 - preds)) == 0

+        def my_logloss(preds, dtrain):
+            labels = dtrain.get_label()
+            return 'logloss', np.sum(
+                np.log(np.where(labels, preds, 1 - preds)))
+
+        # check whether custom evaluation metrics work
+        bst = xgb.train(param, dtrain, num_round, watchlist,
+                        feval=my_logloss)
+        preds3 = bst.predict(dtest, ntree_limit=num_round)
+        assert all(preds3 == preds)
+
         # check whether sample_type and normalize_type work
         num_round = 50
         param['verbosity'] = 0
         param['learning_rate'] = 0.1
         param['rate_drop'] = 0.1
         preds_list = []
-        for p in [[p0, p1] for p0 in ['uniform', 'weighted'] for p1 in ['tree', 'forest']]:
+        for p in [[p0, p1] for p0 in ['uniform', 'weighted']
+                  for p1 in ['tree', 'forest']]:
             param['sample_type'] = p[0]
             param['normalize_type'] = p[1]
             bst = xgb.train(param, dtrain, num_round, watchlist)
             preds = bst.predict(dtest, ntree_limit=num_round)
-            err = sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
+            err = sum(1 for i in range(len(preds))
+                      if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
             assert err < 0.1
             preds_list.append(preds)
@@ -135,7 +135,7 @@ class TestRanking(unittest.TestCase):
         # specify validations set to watch performance
         watchlist = [(self.dtest, 'eval'), (self.dtrain, 'train')]
         bst = xgboost.train(self.params, self.dtrain, num_boost_round=2500,
-                            early_stopping_rounds=10, evals=watchlist)
+                            early_stopping_rounds=10, evals=watchlist)
         assert bst.best_score > 0.98

     def test_cv(self):