add ntree limit

This commit is contained in:
tqchen 2014-09-01 15:10:19 -07:00
parent 4c451de90b
commit 4592e500cb
10 changed files with 53 additions and 23 deletions

View File

@ -11,7 +11,8 @@ setClass("xgb.Booster")
#' value of sum of functions, when outputmargin=TRUE, the prediction is #' value of sum of functions, when outputmargin=TRUE, the prediction is
#' untransformed margin value. In logistic regression, outputmargin=T will #' untransformed margin value. In logistic regression, outputmargin=T will
#' output value before logistic transformation. #' output value before logistic transformation.
#' #' @param ntreelimit limit number of trees used in prediction, this parameter is only valid for gbtree, but not for gblinear.
#' set it to be value bigger than 0
#' @examples #' @examples
#' data(iris) #' data(iris)
#' bst <- xgboost(as.matrix(iris[,1:4]),as.numeric(iris[,5]), nrounds = 2) #' bst <- xgboost(as.matrix(iris[,1:4]),as.numeric(iris[,5]), nrounds = 2)
@ -19,11 +20,18 @@ setClass("xgb.Booster")
#' @export #' @export
#' #'
setMethod("predict", signature = "xgb.Booster", setMethod("predict", signature = "xgb.Booster",
definition = function(object, newdata, outputmargin = FALSE) { definition = function(object, newdata, outputmargin = FALSE, ntreelimit = NULL) {
if (class(newdata) != "xgb.DMatrix") { if (class(newdata) != "xgb.DMatrix") {
newdata <- xgb.DMatrix(newdata) newdata <- xgb.DMatrix(newdata)
} }
ret <- .Call("XGBoosterPredict_R", object, newdata, as.integer(outputmargin), PACKAGE = "xgboost") if (is.null(ntreelimit)) {
ntreelimit <- 0
} else {
if (ntreelimit < 1){
stop("predict: ntreelimit must be greater equal than 1")
}
}
ret <- .Call("XGBoosterPredict_R", object, newdata, as.integer(outputmargin), as.integer(ntreelimit), PACKAGE = "xgboost")
return(ret) return(ret)
}) })

View File

@ -247,12 +247,13 @@ extern "C" {
&vec_dmats[0], &vec_sptr[0], len)); &vec_dmats[0], &vec_sptr[0], len));
_WrapperEnd(); _WrapperEnd();
} }
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin) { SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin, SEXP ntree_limit) {
_WrapperBegin(); _WrapperBegin();
bst_ulong olen; bst_ulong olen;
const float *res = XGBoosterPredict(R_ExternalPtrAddr(handle), const float *res = XGBoosterPredict(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(dmat), R_ExternalPtrAddr(dmat),
asInteger(output_margin), asInteger(output_margin),
asInteger(ntree_limit),
&olen); &olen);
SEXP ret = PROTECT(allocVector(REALSXP, olen)); SEXP ret = PROTECT(allocVector(REALSXP, olen));
for (size_t i = 0; i < olen; ++i) { for (size_t i = 0; i < olen; ++i) {

View File

@ -107,8 +107,9 @@ extern "C" {
* \param handle handle * \param handle handle
* \param dmat data matrix * \param dmat data matrix
* \param output_margin whether only output raw margin value * \param output_margin whether only output raw margin value
* \param ntree_limit limit number of trees used in prediction
*/ */
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin); SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin, SEXP ntree_limit);
/*! /*!
* \brief load model from existing file * \brief load model from existing file
* \param handle handle * \param handle handle

View File

@ -105,7 +105,10 @@ class GBLinear : public IGradBooster {
virtual void Predict(IFMatrix *p_fmat, virtual void Predict(IFMatrix *p_fmat,
int64_t buffer_offset, int64_t buffer_offset,
const BoosterInfo &info, const BoosterInfo &info,
std::vector<float> *out_preds) { std::vector<float> *out_preds,
unsigned ntree_limit = 0) {
utils::Check(ntree_limit == 0,
"GBLinear::Predict ntrees is only valid for gbtree predictor");
std::vector<float> &preds = *out_preds; std::vector<float> &preds = *out_preds;
preds.resize(0); preds.resize(0);
// start collecting the prediction // start collecting the prediction

View File

@ -57,11 +57,14 @@ class IGradBooster {
* the size of buffer is set by convention using IGradBooster.SetParam("num_pbuffer","size") * the size of buffer is set by convention using IGradBooster.SetParam("num_pbuffer","size")
* \param info extra side information that may be needed for prediction * \param info extra side information that may be needed for prediction
* \param out_preds output vector to hold the predictions * \param out_preds output vector to hold the predictions
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
*/ */
virtual void Predict(IFMatrix *p_fmat, virtual void Predict(IFMatrix *p_fmat,
int64_t buffer_offset, int64_t buffer_offset,
const BoosterInfo &info, const BoosterInfo &info,
std::vector<float> *out_preds) = 0; std::vector<float> *out_preds,
unsigned ntree_limit = 0) = 0;
/*! /*!
* \brief dump the model in text format * \brief dump the model in text format
* \param fmap feature map that may help give interpretations of feature * \param fmap feature map that may help give interpretations of feature

View File

@ -105,7 +105,8 @@ class GBTree : public IGradBooster {
virtual void Predict(IFMatrix *p_fmat, virtual void Predict(IFMatrix *p_fmat,
int64_t buffer_offset, int64_t buffer_offset,
const BoosterInfo &info, const BoosterInfo &info,
std::vector<float> *out_preds) { std::vector<float> *out_preds,
unsigned ntree_limit = 0) {
int nthread; int nthread;
#pragma omp parallel #pragma omp parallel
{ {
@ -137,7 +138,8 @@ class GBTree : public IGradBooster {
this->Pred(batch[i], this->Pred(batch[i],
buffer_offset < 0 ? -1 : buffer_offset + ridx, buffer_offset < 0 ? -1 : buffer_offset + ridx,
gid, info.GetRoot(ridx), &feats, gid, info.GetRoot(ridx), &feats,
&preds[ridx * mparam.num_output_group + gid], stride); &preds[ridx * mparam.num_output_group + gid], stride,
ntree_limit);
} }
} }
} }
@ -212,14 +214,16 @@ class GBTree : public IGradBooster {
int bst_group, int bst_group,
unsigned root_index, unsigned root_index,
tree::RegTree::FVec *p_feats, tree::RegTree::FVec *p_feats,
float *out_pred, size_t stride) { float *out_pred, size_t stride, unsigned ntree_limit) {
size_t itop = 0; size_t itop = 0;
float psum = 0.0f; float psum = 0.0f;
// sum of leaf vector // sum of leaf vector
std::vector<float> vec_psum(mparam.size_leaf_vector, 0.0f); std::vector<float> vec_psum(mparam.size_leaf_vector, 0.0f);
const int64_t bid = mparam.BufferOffset(buffer_index, bst_group); const int64_t bid = mparam.BufferOffset(buffer_index, bst_group);
// number of valid trees
unsigned treeleft = ntree_limit == 0 ? std::numeric_limits<unsigned>::max() : ntree_limit;
// load buffered results if any // load buffered results if any
if (bid >= 0) { if (bid >= 0 && ntree_limit == 0) {
itop = pred_counter[bid]; itop = pred_counter[bid];
psum = pred_buffer[bid]; psum = pred_buffer[bid];
for (int i = 0; i < mparam.size_leaf_vector; ++i) { for (int i = 0; i < mparam.size_leaf_vector; ++i) {
@ -235,12 +239,13 @@ class GBTree : public IGradBooster {
for (int j = 0; j < mparam.size_leaf_vector; ++j) { for (int j = 0; j < mparam.size_leaf_vector; ++j) {
vec_psum[j] += trees[i]->leafvec(tid)[j]; vec_psum[j] += trees[i]->leafvec(tid)[j];
} }
if(--treeleft == 0) break;
} }
} }
p_feats->Drop(inst); p_feats->Drop(inst);
} }
// updated the buffered results // updated the buffered results
if (bid >= 0) { if (bid >= 0 && ntree_limit == 0) {
pred_counter[bid] = static_cast<unsigned>(trees.size()); pred_counter[bid] = static_cast<unsigned>(trees.size());
pred_buffer[bid] = psum; pred_buffer[bid] = psum;
for (int i = 0; i < mparam.size_leaf_vector; ++i) { for (int i = 0; i < mparam.size_leaf_vector; ++i) {

View File

@ -212,11 +212,14 @@ class BoostLearner {
* \param data input data * \param data input data
* \param output_margin whether to only predict margin value instead of transformed prediction * \param output_margin whether to only predict margin value instead of transformed prediction
* \param out_preds output vector that stores the prediction * \param out_preds output vector that stores the prediction
* \param ntree_limit limit number of trees used for boosted tree
* predictor, when it equals 0, this means we are using all the trees
*/ */
inline void Predict(const DMatrix &data, inline void Predict(const DMatrix &data,
bool output_margin, bool output_margin,
std::vector<float> *out_preds) const { std::vector<float> *out_preds,
this->PredictRaw(data, out_preds); unsigned ntree_limit = 0) const {
this->PredictRaw(data, out_preds, ntree_limit);
if (!output_margin) { if (!output_margin) {
obj_->PredTransform(out_preds); obj_->PredTransform(out_preds);
} }
@ -246,11 +249,14 @@ class BoostLearner {
* \brief get un-transformed prediction * \brief get un-transformed prediction
* \param data training data matrix * \param data training data matrix
* \param out_preds output vector that stores the prediction * \param out_preds output vector that stores the prediction
* \param ntree_limit limit number of trees used for boosted tree
* predictor, when it equals 0, this means we are using all the trees
*/ */
inline void PredictRaw(const DMatrix &data, inline void PredictRaw(const DMatrix &data,
std::vector<float> *out_preds) const { std::vector<float> *out_preds,
unsigned ntree_limit = 0) const {
gbm_->Predict(data.fmat(), this->FindBufferOffset(data), gbm_->Predict(data.fmat(), this->FindBufferOffset(data),
data.info.info, out_preds); data.info.info, out_preds, ntree_limit);
// add base margin // add base margin
std::vector<float> &preds = *out_preds; std::vector<float> &preds = *out_preds;
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size()); const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size());

View File

@ -192,15 +192,16 @@ class Booster:
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals)) return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
def eval(self, mat, name = 'eval', it = 0): def eval(self, mat, name = 'eval', it = 0):
return self.eval_set( [(mat,name)], it) return self.eval_set( [(mat,name)], it)
def predict(self, data, output_margin=False): def predict(self, data, output_margin=False, ntree_limit=0):
""" """
predict with data predict with data
data: the dmatrix storing the input data: the dmatrix storing the input
output_margin: whether output raw margin value that is untransformed output_margin: whether output raw margin value that is untransformed
ntree_limit: limit number of trees in prediction, default to 0, 0 means using all the trees
""" """
length = ctypes.c_ulong() length = ctypes.c_ulong()
preds = xglib.XGBoosterPredict(self.handle, data.handle, preds = xglib.XGBoosterPredict(self.handle, data.handle,
int(output_margin), ctypes.byref(length)) int(output_margin), ntree_limit, ctypes.byref(length))
return ctypes2numpy(preds, length.value, 'float32') return ctypes2numpy(preds, length.value, 'float32')
def save_model(self, fname): def save_model(self, fname):
""" save model to file """ """ save model to file """

View File

@ -25,9 +25,9 @@ class Booster: public learner::BoostLearner {
this->init_model = false; this->init_model = false;
this->SetCacheData(mats); this->SetCacheData(mats);
} }
const float *Pred(const DataMatrix &dmat, int output_margin, bst_ulong *len) { inline const float *Pred(const DataMatrix &dmat, int output_margin, unsigned ntree_limit, bst_ulong *len) {
this->CheckInitModel(); this->CheckInitModel();
this->Predict(dmat, output_margin != 0, &this->preds_); this->Predict(dmat, output_margin != 0, &this->preds_, ntree_limit);
*len = static_cast<bst_ulong>(this->preds_.size()); *len = static_cast<bst_ulong>(this->preds_.size());
return &this->preds_[0]; return &this->preds_[0];
} }
@ -249,8 +249,8 @@ extern "C"{
bst->eval_str = bst->EvalOneIter(iter, mats, names); bst->eval_str = bst->EvalOneIter(iter, mats, names);
return bst->eval_str.c_str(); return bst->eval_str.c_str();
} }
const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, bst_ulong *len) { const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, unsigned ntree_limit, bst_ulong *len) {
return static_cast<Booster*>(handle)->Pred(*static_cast<DataMatrix*>(dmat), output_margin, len); return static_cast<Booster*>(handle)->Pred(*static_cast<DataMatrix*>(dmat), output_margin, ntree_limit, len);
} }
void XGBoosterLoadModel(void *handle, const char *fname) { void XGBoosterLoadModel(void *handle, const char *fname) {
static_cast<Booster*>(handle)->LoadModel(fname); static_cast<Booster*>(handle)->LoadModel(fname);

View File

@ -165,9 +165,11 @@ extern "C" {
* \param handle handle * \param handle handle
* \param dmat data matrix * \param dmat data matrix
* \param output_margin whether only output raw margin value * \param output_margin whether only output raw margin value
* \param ntree_limit limit number of trees used for prediction, this is only valid for boosted trees
* when the parameter is set to 0, we will use all the trees
* \param len used to store length of returning result * \param len used to store length of returning result
*/ */
XGB_DLL const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, bst_ulong *len); XGB_DLL const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, unsigned ntree_limit, bst_ulong *len);
/*! /*!
* \brief load model from existing file * \brief load model from existing file
* \param handle handle * \param handle handle