[Breaking] Don't drop trees during DART prediction by default (#5115)

* Simplify DropTrees calling logic

* Add a `training` parameter to the prediction method.

* [Breaking]: Add `training` to C API.

* Update the R and Python custom objective paths to request training-time predictions.

* Correct comment.

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
Kodi Arfer 2020-01-13 08:48:30 -05:00 committed by Jiaming Yuan
parent 7b65698187
commit f100b8d878
23 changed files with 214 additions and 140 deletions
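
The user-visible effect, sketched against the updated Python API (the synthetic data and parameter values below are placeholders, not taken from this commit):

    import numpy as np
    import xgboost as xgb

    X, y = np.random.rand(200, 8), np.random.randint(0, 2, 200)
    dtrain = xgb.DMatrix(X, label=y)
    bst = xgb.train({'booster': 'dart', 'objective': 'binary:logistic',
                     'rate_drop': 0.5}, dtrain, num_boost_round=20)

    # New default: prediction evaluates every tree, no dropout is applied.
    preds = bst.predict(dtrain)

    # Dropout is applied only when a call is explicitly marked as a
    # training-time prediction, which is what the custom objective path does.
    train_preds = bst.predict(dtrain, training=True)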

View File

@ -145,7 +145,7 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj = NULL) {
if (is.null(obj)) {
.Call(XGBoosterUpdateOneIter_R, booster_handle, as.integer(iter), dtrain)
} else {
pred <- predict(booster_handle, dtrain)
pred <- predict(booster_handle, dtrain, training = TRUE)
gpair <- obj(pred, dtrain)
.Call(XGBoosterBoostOneIter_R, booster_handle, dtrain, gpair$grad, gpair$hess)
}

View File

@ -288,7 +288,7 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
#' @export
predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE, ntreelimit = NULL,
predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE,
reshape = FALSE, ...) {
reshape = FALSE, training = FALSE, ...) {
object <- xgb.Booster.complete(object, saveraw = FALSE)
if (!inherits(newdata, "xgb.DMatrix"))
@ -307,7 +307,8 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
option <- 0L + 1L * as.logical(outputmargin) + 2L * as.logical(predleaf) + 4L * as.logical(predcontrib) +
8L * as.logical(approxcontrib) + 16L * as.logical(predinteraction)
ret <- .Call(XGBoosterPredict_R, object$handle, newdata, option[1], as.integer(ntreelimit))
ret <- .Call(XGBoosterPredict_R, object$handle, newdata, option[1],
as.integer(ntreelimit), as.integer(training))
n_ret <- length(ret)
n_row <- nrow(newdata)

View File

@ -24,7 +24,7 @@ extern SEXP XGBoosterGetAttr_R(SEXP, SEXP);
extern SEXP XGBoosterLoadModelFromRaw_R(SEXP, SEXP);
extern SEXP XGBoosterLoadModel_R(SEXP, SEXP);
extern SEXP XGBoosterModelToRaw_R(SEXP);
extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterSaveModel_R(SEXP, SEXP);
extern SEXP XGBoosterSetAttr_R(SEXP, SEXP, SEXP);
extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP);
@ -50,7 +50,7 @@ static const R_CallMethodDef CallEntries[] = {
{"XGBoosterLoadModelFromRaw_R", (DL_FUNC) &XGBoosterLoadModelFromRaw_R, 2},
{"XGBoosterLoadModel_R", (DL_FUNC) &XGBoosterLoadModel_R, 2},
{"XGBoosterModelToRaw_R", (DL_FUNC) &XGBoosterModelToRaw_R, 1},
{"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 4},
{"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 5},
{"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2},
{"XGBoosterSetAttr_R", (DL_FUNC) &XGBoosterSetAttr_R, 3},
{"XGBoosterSetParam_R", (DL_FUNC) &XGBoosterSetParam_R, 3},

View File

@ -295,24 +295,26 @@ SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) {
vec_sptr.push_back(vec_names[i].c_str());
}
CHECK_CALL(XGBoosterEvalOneIter(R_ExternalPtrAddr(handle),
asInteger(iter),
BeginPtr(vec_dmats),
BeginPtr(vec_sptr),
len, &ret));
asInteger(iter),
BeginPtr(vec_dmats),
BeginPtr(vec_sptr),
len, &ret));
R_API_END();
return mkString(ret);
}
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask, SEXP ntree_limit) {
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask,
SEXP ntree_limit, SEXP training) {
SEXP ret;
R_API_BEGIN();
bst_ulong olen;
const float *res;
CHECK_CALL(XGBoosterPredict(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(dmat),
asInteger(option_mask),
asInteger(ntree_limit),
&olen, &res));
R_ExternalPtrAddr(dmat),
asInteger(option_mask),
asInteger(ntree_limit),
0,
&olen, &res));
ret = PROTECT(allocVector(REALSXP, olen));
for (size_t i = 0; i < olen; ++i) {
REAL(ret)[i] = res[i];

View File

@ -148,8 +148,10 @@ XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evn
* \param dmat data matrix
* \param option_mask output_margin:1 predict_leaf:2
* \param ntree_limit limit number of trees used in prediction
* \param training Whether the prediction value is used for training.
*/
XGB_DLL SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask, SEXP ntree_limit);
XGB_DLL SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask,
SEXP ntree_limit, SEXP training);
/*!
* \brief load model from existing file
* \param handle handle

View File

@ -166,7 +166,7 @@ test_that("SHAPs sum to predictions, with or without DART", {
nrounds = nrounds)
pr <- function(...)
predict(fit, newdata = d, ntreelimit = nrounds, ...)
predict(fit, newdata = d, ...)
pred <- pr()
shap <- pr(predcontrib = T)
shapi <- pr(predinteraction = T)

View File

@ -108,12 +108,4 @@ Sample Script
'skip_drop': 0.5}
num_round = 50
bst = xgb.train(param, dtrain, num_round)
# make prediction
# ntree_limit must not be 0
preds = bst.predict(dtest, ntree_limit=num_round)
.. note:: Specify ``ntree_limit`` when predicting with test sets
By default, ``bst.predict()`` will perform dropouts on trees. To obtain
correct results on test sets, disable dropouts by specifying
a nonzero value for ``ntree_limit``.
preds = bst.predict(dtest)
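
One consequence for the sample script above: the previously documented workaround of passing a nonzero ``ntree_limit`` and the plain call should now return the same values, since neither applies dropout at prediction time. A self-contained sketch with stand-in data and parameters (not the demo's):

    import numpy as np
    import xgboost as xgb

    X, y = np.random.rand(100, 5), np.random.randint(0, 2, 100)
    dtrain, dtest = xgb.DMatrix(X, label=y), xgb.DMatrix(X)
    num_round = 50
    bst = xgb.train({'booster': 'dart', 'objective': 'binary:logistic',
                     'rate_drop': 0.1, 'skip_drop': 0.5}, dtrain, num_round)

    # Both calls evaluate every tree; dropout no longer happens at prediction time.
    assert np.allclose(bst.predict(dtest, ntree_limit=num_round),
                       bst.predict(dtest))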

View File

@ -416,6 +416,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
* 4:output feature contributions to individual predictions
* \param ntree_limit limit number of trees used for prediction, this is only valid for boosted trees
* when the parameter is set to 0, we will use all the trees
* \param training Whether the prediction value is used for training.
* \param out_len used to store length of returning result
* \param out_result used to set a pointer to array
* \return 0 when success, -1 when failure happens
@ -424,6 +425,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
DMatrixHandle dmat,
int option_mask,
unsigned ntree_limit,
int training,
bst_ulong *out_len,
const float **out_result);
/*

View File

@ -84,6 +84,7 @@ class GradientBooster : public Model, public Configurable {
*/
virtual void PredictBatch(DMatrix* dmat,
HostDeviceVector<bst_float>* out_preds,
bool training,
unsigned ntree_limit = 0) = 0;
/*!
* \brief online prediction function, predict score for one instance at a time

View File

@ -96,6 +96,7 @@ class Learner : public Model, public Configurable, public rabit::Serializable {
bool output_margin,
HostDeviceVector<bst_float> *out_preds,
unsigned ntree_limit = 0,
bool training = false,
bool pred_leaf = false,
bool pred_contribs = false,
bool approx_contribs = false,

View File

@ -556,7 +556,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
DMatrixHandle dmat = (DMatrixHandle) jdmat;
bst_ulong len;
float *result;
int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, &len, (const float **) &result);
int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, 0,
&len, (const float **) &result);
if (len) {
jsize jlen = (jsize) len;
jfloatArray jarray = jenv->NewFloatArray(jlen);

View File

@ -196,7 +196,8 @@ def ctypes2numpy(cptr, length, dtype):
np.uint32: ctypes.c_uint,
}
if dtype not in NUMPY_TO_CTYPES_MAPPING:
raise RuntimeError('Supported types: {}'.format(NUMPY_TO_CTYPES_MAPPING.keys()))
raise RuntimeError('Supported types: {}'.format(
NUMPY_TO_CTYPES_MAPPING.keys()))
ctype = NUMPY_TO_CTYPES_MAPPING[dtype]
if not isinstance(cptr, ctypes.POINTER(ctype)):
raise RuntimeError('expected {} pointer'.format(ctype))
@ -1275,14 +1276,16 @@ class Booster(object):
"""
if not isinstance(dtrain, DMatrix):
raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
raise TypeError('invalid training matrix: {}'.format(
type(dtrain).__name__))
self._validate_features(dtrain)
if fobj is None:
_check_call(_LIB.XGBoosterUpdateOneIter(self.handle, ctypes.c_int(iteration),
_check_call(_LIB.XGBoosterUpdateOneIter(self.handle,
ctypes.c_int(iteration),
dtrain.handle))
else:
pred = self.predict(dtrain)
pred = self.predict(dtrain, training=True)
grad, hess = fobj(pred, dtrain)
self.boost(dtrain, grad, hess)
@ -1332,22 +1335,25 @@ class Booster(object):
"""
for d in evals:
if not isinstance(d[0], DMatrix):
raise TypeError('expected DMatrix, got {}'.format(type(d[0]).__name__))
raise TypeError('expected DMatrix, got {}'.format(
type(d[0]).__name__))
if not isinstance(d[1], STRING_TYPES):
raise TypeError('expected string, got {}'.format(type(d[1]).__name__))
raise TypeError('expected string, got {}'.format(
type(d[1]).__name__))
self._validate_features(d[0])
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
msg = ctypes.c_char_p()
_check_call(_LIB.XGBoosterEvalOneIter(self.handle, ctypes.c_int(iteration),
_check_call(_LIB.XGBoosterEvalOneIter(self.handle,
ctypes.c_int(iteration),
dmats, evnames,
c_bst_ulong(len(evals)),
ctypes.byref(msg)))
res = msg.value.decode()
if feval is not None:
for dmat, evname in evals:
feval_ret = feval(self.predict(dmat), dmat)
feval_ret = feval(self.predict(dmat, training=False), dmat)
if isinstance(feval_ret, list):
for name, val in feval_ret:
res += '\t%s-%s:%f' % (evname, name, val)
@ -1378,27 +1384,24 @@ class Booster(object):
self._validate_features(data)
return self.eval_set([(data, name)], iteration)
def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False,
pred_contribs=False, approx_contribs=False, pred_interactions=False,
validate_features=True):
def predict(self,
data,
output_margin=False,
ntree_limit=0,
pred_leaf=False,
pred_contribs=False,
approx_contribs=False,
pred_interactions=False,
validate_features=True,
training=False):
"""Predict with data.
.. note:: This function is not thread safe.
For each booster object, predict can only be called from one thread.
If you want to run prediction using multiple thread, call ``bst.copy()`` to make copies
of model object and then call ``predict()``.
.. note:: Using ``predict()`` with DART booster
If the booster object is DART type, ``predict()`` will perform dropouts, i.e. only
some of the trees will be evaluated. This will produce incorrect results if ``data`` is
not the training data. To obtain correct results on test sets, set ``ntree_limit`` to
a nonzero value, e.g.
.. code-block:: python
preds = bst.predict(dtest, ntree_limit=num_round)
If you want to run prediction using multiple threads, call
``bst.copy()`` to make copies of the model object and then call
``predict()``.
Parameters
----------
@ -1409,38 +1412,53 @@ class Booster(object):
Whether to output the raw untransformed margin value.
ntree_limit : int
Limit number of trees in the prediction; defaults to 0 (use all trees).
Limit number of trees in the prediction; defaults to 0 (use all
trees).
pred_leaf : bool
When this option is on, the output will be a matrix of (nsample, ntrees)
with each record indicating the predicted leaf index of each sample in each tree.
Note that the leaf index of a tree is unique per tree, so you may find leaf 1
in both tree 1 and tree 0.
When this option is on, the output will be a matrix of (nsample,
ntrees) with each record indicating the predicted leaf index of
each sample in each tree. Note that the leaf index of a tree is
unique per tree, so you may find leaf 1 in both tree 1 and tree 0.
pred_contribs : bool
When this is True the output will be a matrix of size (nsample, nfeats + 1)
with each record indicating the feature contributions (SHAP values) for that
prediction. The sum of all feature contributions is equal to the raw untransformed
margin value of the prediction. Note the final column is the bias term.
When this is True the output will be a matrix of size (nsample,
nfeats + 1) with each record indicating the feature contributions
(SHAP values) for that prediction. The sum of all feature
contributions is equal to the raw untransformed margin value of the
prediction. Note the final column is the bias term.
approx_contribs : bool
Approximate the contributions of each feature
pred_interactions : bool
When this is True the output will be a matrix of size (nsample, nfeats + 1, nfeats + 1)
indicating the SHAP interaction values for each pair of features. The sum of each
row (or column) of the interaction values equals the corresponding SHAP value (from
pred_contribs), and the sum of the entire matrix equals the raw untransformed margin
value of the prediction. Note the last row and column correspond to the bias term.
When this is True the output will be a matrix of size (nsample,
nfeats + 1, nfeats + 1) indicating the SHAP interaction values for
each pair of features. The sum of each row (or column) of the
interaction values equals the corresponding SHAP value (from
pred_contribs), and the sum of the entire matrix equals the raw
untransformed margin value of the prediction. Note the last row and
column correspond to the bias term.
validate_features : bool
When this is True, validate that the Booster's and data's
feature_names are identical. Otherwise, it is assumed that the
feature_names are the same.
training : bool
Whether the prediction value is used for training. This affects the
`dart` booster, which performs dropouts during training iterations.
.. note:: Using ``predict()`` with DART booster
If the booster object is DART type, ``predict()`` will not perform
dropouts, i.e. all the trees will be evaluated. If you want to
obtain results with dropouts, provide `training=True`.
Returns
-------
prediction : numpy array
"""
option_mask = 0x00
if output_margin:
@ -1466,6 +1484,7 @@ class Booster(object):
_check_call(_LIB.XGBoosterPredict(self.handle, data.handle,
ctypes.c_int(option_mask),
ctypes.c_uint(ntree_limit),
ctypes.c_int(training),
ctypes.byref(length),
ctypes.byref(preds)))
preds = ctypes2numpy(preds, length.value, np.float32)
@ -1476,11 +1495,16 @@ class Booster(object):
chunk_size = int(preds.size / nrow)
if pred_interactions:
ngroup = int(chunk_size / ((data.num_col() + 1) * (data.num_col() + 1)))
ngroup = int(chunk_size / ((data.num_col() + 1) *
(data.num_col() + 1)))
if ngroup == 1:
preds = preds.reshape(nrow, data.num_col() + 1, data.num_col() + 1)
preds = preds.reshape(nrow,
data.num_col() + 1,
data.num_col() + 1)
else:
preds = preds.reshape(nrow, ngroup, data.num_col() + 1, data.num_col() + 1)
preds = preds.reshape(nrow, ngroup,
data.num_col() + 1,
data.num_col() + 1)
elif pred_contribs:
ngroup = int(chunk_size / (data.num_col() + 1))
if ngroup == 1:
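
For illustration, a hedged sketch of how the new keyword interacts with a custom objective (the squared-error gradient and the synthetic data are stand-ins, not part of this change):

    import numpy as np
    import xgboost as xgb

    X, y = np.random.rand(100, 5), np.random.rand(100)
    dtrain = xgb.DMatrix(X, label=y)

    def squared_error(preds, dtrain):
        # Illustrative objective: gradient and hessian of squared error.
        return preds - dtrain.get_label(), np.ones_like(preds)

    # Booster.update() now feeds `preds` from self.predict(dtrain, training=True),
    # so a DART booster keeps applying dropout while gradients are computed.
    bst = xgb.train({'booster': 'dart', 'rate_drop': 0.1}, dtrain,
                    num_boost_round=10, obj=squared_error)

    # Plain predict() defaults to training=False: all trees, no dropout.
    preds = bst.predict(dtrain)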

View File

@ -570,10 +570,11 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
DMatrixHandle dmat,
int option_mask,
unsigned ntree_limit,
int32_t training,
xgboost::bst_ulong *len,
const bst_float **out_result) {
std::vector<bst_float>&preds =
XGBAPIThreadLocalStore::Get()->ret_vec_float;
std::vector<bst_float>& preds =
XGBAPIThreadLocalStore::Get()->ret_vec_float;
API_BEGIN();
CHECK_HANDLE();
auto *bst = static_cast<Learner*>(handle);
@ -582,6 +583,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
static_cast<std::shared_ptr<DMatrix>*>(dmat)->get(),
(option_mask & 1) != 0,
&tmp_preds, ntree_limit,
static_cast<bool>(training),
(option_mask & 2) != 0,
(option_mask & 4) != 0,
(option_mask & 8) != 0,

View File

@ -127,6 +127,7 @@ class GBLinear : public GradientBooster {
void PredictBatch(DMatrix *p_fmat,
HostDeviceVector<bst_float> *out_preds,
bool training,
unsigned ntree_limit) override {
monitor_.Start("PredictBatch");
CHECK_EQ(ntree_limit, 0U)

View File

@ -414,17 +414,36 @@ class Dart : public GBTree {
out["dart_train_param"] = toJson(dparam_);
}
// predict the leaf scores with dropout if ntree_limit = 0
void PredictBatch(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_preds,
HostDeviceVector<bst_float>* p_out_preds,
bool training,
unsigned ntree_limit) override {
DropTrees(ntree_limit);
PredLoopInternal<Dart>(p_fmat, &out_preds->HostVector(), 0, ntree_limit, true);
DropTrees(training);
int num_group = model_.learner_model_param_->num_output_group;
ntree_limit *= num_group;
if (ntree_limit == 0 || ntree_limit > model_.trees.size()) {
ntree_limit = static_cast<unsigned>(model_.trees.size());
}
size_t n = num_group * p_fmat->Info().num_row_;
const auto &base_margin = p_fmat->Info().base_margin_.ConstHostVector();
auto& out_preds = p_out_preds->HostVector();
out_preds.resize(n);
if (base_margin.size() != 0) {
CHECK_EQ(out_preds.size(), n);
std::copy(base_margin.begin(), base_margin.end(), out_preds.begin());
} else {
std::fill(out_preds.begin(), out_preds.end(),
model_.learner_model_param_->base_score);
}
PredLoopSpecalize(p_fmat, &out_preds, num_group, 0,
ntree_limit, training);
}
void PredictInstance(const SparsePage::Inst &inst,
std::vector<bst_float> *out_preds, unsigned ntree_limit) override {
DropTrees(1);
std::vector<bst_float> *out_preds,
unsigned ntree_limit) override {
DropTrees(false);
if (thread_temp_.size() == 0) {
thread_temp_.resize(1, RegTree::FVec());
thread_temp_[0].Init(model_.learner_model_param_->num_feature);
@ -465,46 +484,13 @@ class Dart : public GBTree {
protected:
friend class GBTree;
// internal prediction loop
// add predictions to out_preds
template<typename Derived>
inline void PredLoopInternal(
DMatrix* p_fmat,
std::vector<bst_float>* out_preds,
unsigned tree_begin,
unsigned ntree_limit,
bool init_out_preds) {
int num_group = model_.learner_model_param_->num_output_group;
ntree_limit *= num_group;
if (ntree_limit == 0 || ntree_limit > model_.trees.size()) {
ntree_limit = static_cast<unsigned>(model_.trees.size());
}
if (init_out_preds) {
size_t n = num_group * p_fmat->Info().num_row_;
const auto& base_margin =
p_fmat->Info().base_margin_.ConstHostVector();
out_preds->resize(n);
if (base_margin.size() != 0) {
CHECK_EQ(out_preds->size(), n);
std::copy(base_margin.begin(), base_margin.end(), out_preds->begin());
} else {
std::fill(out_preds->begin(), out_preds->end(),
model_.learner_model_param_->base_score);
}
}
PredLoopSpecalize<Derived>(p_fmat, out_preds, num_group, tree_begin,
ntree_limit);
}
template<typename Derived>
inline void PredLoopSpecalize(
DMatrix* p_fmat,
std::vector<bst_float>* out_preds,
int num_group,
unsigned tree_begin,
unsigned tree_end) {
unsigned tree_end,
bool training) {
const int nthread = omp_get_max_threads();
CHECK_EQ(num_group, model_.learner_model_param_->num_output_group);
InitThreadTemp(nthread);
@ -513,13 +499,12 @@ class Dart : public GBTree {
<< "size_leaf_vector is enforced to 0 so far";
CHECK_EQ(preds.size(), p_fmat->Info().num_row_ * num_group);
// start collecting the prediction
auto* self = static_cast<Derived*>(this);
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
constexpr int kUnroll = 8;
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
const bst_omp_uint rest = nsize % kUnroll;
if (nsize >= kUnroll) {
#pragma omp parallel for schedule(static)
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
const int tid = omp_get_thread_num();
RegTree::FVec& feats = thread_temp_[tid];
@ -535,7 +520,7 @@ class Dart : public GBTree {
for (int gid = 0; gid < num_group; ++gid) {
const size_t offset = ridx[k] * num_group + gid;
preds[offset] +=
self->PredValue(inst[k], gid, &feats, tree_begin, tree_end);
this->PredValue(inst[k], gid, &feats, tree_begin, tree_end);
}
}
}
@ -548,7 +533,7 @@ class Dart : public GBTree {
for (int gid = 0; gid < num_group; ++gid) {
const size_t offset = ridx * num_group + gid;
preds[offset] +=
self->PredValue(inst, gid,
this->PredValue(inst, gid,
&feats, tree_begin, tree_end);
}
}
@ -569,11 +554,9 @@ class Dart : public GBTree {
}
// predict the leaf scores without dropped trees
inline bst_float PredValue(const SparsePage::Inst &inst,
int bst_group,
RegTree::FVec *p_feats,
unsigned tree_begin,
unsigned tree_end) {
bst_float PredValue(const SparsePage::Inst &inst, int bst_group,
RegTree::FVec *p_feats, unsigned tree_begin,
unsigned tree_end) const {
bst_float psum = 0.0f;
p_feats->Fill(inst);
for (size_t i = tree_begin; i < tree_end; ++i) {
@ -590,9 +573,12 @@ class Dart : public GBTree {
}
// select which trees to drop
inline void DropTrees(unsigned ntree_limit_drop) {
// passing clear=True will clear selection
inline void DropTrees(bool is_training) {
idx_drop_.clear();
if (ntree_limit_drop > 0) return;
if (!is_training) {
return;
}
std::uniform_real_distribution<> runif(0.0, 1.0);
auto& rnd = common::GlobalRandom();

View File

@ -205,6 +205,7 @@ class GBTree : public GradientBooster {
void PredictBatch(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_preds,
bool training,
unsigned ntree_limit) override {
CHECK(configured_);
GetPredictor(out_preds, p_fmat)->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);

View File

@ -694,7 +694,7 @@ class LearnerImpl : public Learner {
this->ValidateDMatrix(train);
monitor_.Start("PredictRaw");
this->PredictRaw(train, &preds_[train]);
this->PredictRaw(train, &preds_[train], true);
monitor_.Stop("PredictRaw");
TrainingObserver::Instance().Observe(preds_[train], "Predictions");
@ -735,7 +735,7 @@ class LearnerImpl : public Learner {
for (size_t i = 0; i < data_sets.size(); ++i) {
DMatrix * dmat = data_sets[i];
this->ValidateDMatrix(dmat);
this->PredictRaw(data_sets[i], &preds_[dmat]);
this->PredictRaw(data_sets[i], &preds_[dmat], false);
obj_->EvalTransform(&preds_[dmat]);
for (auto& ev : metrics_) {
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
@ -799,6 +799,7 @@ class LearnerImpl : public Learner {
void Predict(DMatrix* data, bool output_margin,
HostDeviceVector<bst_float>* out_preds, unsigned ntree_limit,
bool training,
bool pred_leaf, bool pred_contribs, bool approx_contribs,
bool pred_interactions) override {
int multiple_predictions = static_cast<int>(pred_leaf) +
@ -814,7 +815,7 @@ class LearnerImpl : public Learner {
} else if (pred_leaf) {
gbm_->PredictLeaf(data, &out_preds->HostVector(), ntree_limit);
} else {
this->PredictRaw(data, out_preds, ntree_limit);
this->PredictRaw(data, out_preds, training, ntree_limit);
if (!output_margin) {
obj_->PredTransform(out_preds);
}
@ -832,13 +833,15 @@ class LearnerImpl : public Learner {
* \param out_preds output vector that stores the prediction
* \param ntree_limit limit number of trees used for boosted tree
* predictor, when it equals 0, this means we are using all the trees
* \param training allow dropout when the DART booster is being used
*/
void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
bool training,
unsigned ntree_limit = 0) const {
CHECK(gbm_ != nullptr)
<< "Predict must happen after Load or configuration";
this->ValidateDMatrix(data);
gbm_->PredictBatch(data, out_preds, ntree_limit);
gbm_->PredictBatch(data, out_preds, training, ntree_limit);
}
void ConfigureObjective(LearnerTrainParam const& old, Args* p_args) {

View File

@ -18,7 +18,7 @@ DMLC_REGISTRY_FILE_TAG(cpu_predictor);
class CPUPredictor : public Predictor {
protected:
static bst_float PredValue(const SparsePage::Inst& inst,
static bst_float PredValue(const SparsePage::Inst& inst,
const std::vector<std::unique_ptr<RegTree>>& trees,
const std::vector<int>& tree_info, int bst_group,
RegTree::FVec* p_feats,
@ -175,13 +175,15 @@ class CPUPredictor : public Predictor {
this->PredLoopInternal(dmat, &out_preds->HostVector(), model,
tree_begin, ntree_limit);
auto cache_emtry = this->FindCache(dmat);
if (cache_emtry == cache_->cend()) { return; }
if (cache_emtry->second.predictions.Size() == 0) {
auto cache_entry = this->FindCache(dmat);
if (cache_entry == cache_->cend()) {
return;
}
if (cache_entry->second.predictions.Size() == 0) {
// See comment in GPUPredictor::PredictBatch.
InitOutPredictions(cache_emtry->second.data->Info(),
&(cache_emtry->second.predictions), model);
cache_emtry->second.predictions.Copy(*out_preds);
InitOutPredictions(cache_entry->second.data->Info(),
&(cache_entry->second.predictions), model);
cache_entry->second.predictions.Copy(*out_preds);
}
}

View File

@ -2,6 +2,8 @@
#include <dmlc/filesystem.h>
#include <xgboost/generic_parameters.h>
#include "xgboost/base.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/learner.h"
#include "../helpers.h"
#include "../../../src/gbm/gbtree.h"
@ -18,7 +20,7 @@ TEST(GBTree, SelectTreeMethod) {
mparam.num_output_group = 1;
std::vector<std::shared_ptr<DMatrix> > caches;
std::unique_ptr<GradientBooster> p_gbm{
std::unique_ptr<GradientBooster> p_gbm {
GradientBooster::Create("gbtree", &generic_param, &mparam, caches)};
auto& gbtree = dynamic_cast<gbm::GBTree&> (*p_gbm);
@ -175,4 +177,41 @@ TEST(Dart, Json_IO) {
ASSERT_TRUE(IsA<Object>(model["model"]["gbtree"]));
ASSERT_NE(get<Array>(model["model"]["weight_drop"]).size(), 0);
}
TEST(Dart, Prediction) {
size_t constexpr kRows = 16, kCols = 10;
auto pp_dmat = CreateDMatrix(kRows, kCols, 0);
auto& p_mat = *pp_dmat;
std::vector<bst_float> labels (kRows);
for (size_t i = 0; i < kRows; ++i) {
labels[i] = i % 2;
}
p_mat->Info().SetInfo("label", labels.data(), DataType::kFloat32, kRows);
auto learner = std::unique_ptr<Learner>(Learner::Create({p_mat}));
learner->SetParam("booster", "dart");
learner->SetParam("rate_drop", "0.5");
learner->Configure();
for (size_t i = 0; i < 16; ++i) {
learner->UpdateOneIter(i, p_mat.get());
}
HostDeviceVector<float> predts_training;
learner->Predict(p_mat.get(), false, &predts_training, 0, true);
HostDeviceVector<float> predts_inference;
learner->Predict(p_mat.get(), false, &predts_inference, 0, false);
auto& h_predts_training = predts_training.ConstHostVector();
auto& h_predts_inference = predts_inference.ConstHostVector();
ASSERT_EQ(h_predts_training.size(), h_predts_inference.size());
for (size_t i = 0; i < predts_inference.Size(); ++i) {
// Inference doesn't drop tree.
ASSERT_GT(std::abs(h_predts_training[i] - h_predts_inference[i]), kRtEps);
}
delete pp_dmat;
}
} // namespace xgboost
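
A rough Python analogue of the test above (random data stands in for CreateDMatrix; the assertion is probabilistic in the same way as the C++ version):

    import numpy as np
    import xgboost as xgb

    kRows, kCols = 16, 10
    dtrain = xgb.DMatrix(np.random.rand(kRows, kCols),
                         label=np.arange(kRows) % 2)
    bst = xgb.train({'booster': 'dart', 'rate_drop': 0.5}, dtrain,
                    num_boost_round=16)

    # training=True applies dropout, the default does not, so the two
    # prediction vectors should differ for a DART model with rate_drop > 0.
    p_train = bst.predict(dtrain, training=True)
    p_infer = bst.predict(dtrain)
    assert not np.allclose(p_train, p_infer)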

View File

@ -159,7 +159,6 @@ TEST(Learner, Json_ModelIO) {
{
std::unique_ptr<Learner> learner { Learner::Create({p_dmat}) };
learner->SetParam("verbosity", "3");
for (int32_t iter = 0; iter < kIters; ++iter) {
learner->UpdateOneIter(iter, p_dmat.get());
}

View File

@ -54,7 +54,7 @@ TEST(Logging, Basic) {
ASSERT_NE(output.find("Test Log Console"), std::string::npos);
args["silent"] = "False";
args["verbosity"] = "1"; // restore
args["verbosity"] = "2"; // restore
ConsoleLogger::Configure({args.cbegin(), args.cend()});
}

View File

@ -44,7 +44,8 @@ class TestModels(unittest.TestCase):
def test_dart(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
param = {'max_depth': 5, 'objective': 'binary:logistic', 'booster': 'dart', 'verbosity': 1}
param = {'max_depth': 5, 'objective': 'binary:logistic',
'eval_metric': 'logloss', 'booster': 'dart', 'verbosity': 1}
# specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 2
@ -52,7 +53,8 @@ class TestModels(unittest.TestCase):
# this is prediction
preds = bst.predict(dtest, ntree_limit=num_round)
labels = dtest.get_label()
err = sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
err = sum(1 for i in range(len(preds))
if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
# error must be smaller than 10%
assert err < 0.1
@ -68,18 +70,31 @@ class TestModels(unittest.TestCase):
# assert they are the same
assert np.sum(np.abs(preds2 - preds)) == 0
def my_logloss(preds, dtrain):
labels = dtrain.get_label()
return 'logloss', np.sum(
np.log(np.where(labels, preds, 1 - preds)))
# check whether custom evaluation metrics work
bst = xgb.train(param, dtrain, num_round, watchlist,
feval=my_logloss)
preds3 = bst.predict(dtest, ntree_limit=num_round)
assert all(preds3 == preds)
# check whether sample_type and normalize_type work
num_round = 50
param['verbosity'] = 0
param['learning_rate'] = 0.1
param['rate_drop'] = 0.1
preds_list = []
for p in [[p0, p1] for p0 in ['uniform', 'weighted'] for p1 in ['tree', 'forest']]:
for p in [[p0, p1] for p0 in ['uniform', 'weighted']
for p1 in ['tree', 'forest']]:
param['sample_type'] = p[0]
param['normalize_type'] = p[1]
bst = xgb.train(param, dtrain, num_round, watchlist)
preds = bst.predict(dtest, ntree_limit=num_round)
err = sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
err = sum(1 for i in range(len(preds))
if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
assert err < 0.1
preds_list.append(preds)

View File

@ -135,7 +135,7 @@ class TestRanking(unittest.TestCase):
# specify validations set to watch performance
watchlist = [(self.dtest, 'eval'), (self.dtrain, 'train')]
bst = xgboost.train(self.params, self.dtrain, num_boost_round=2500,
early_stopping_rounds=10, evals=watchlist)
early_stopping_rounds=10, evals=watchlist)
assert bst.best_score > 0.98
def test_cv(self):