add ntree limit
This commit is contained in:
parent
4c451de90b
commit
4592e500cb
@ -11,7 +11,8 @@ setClass("xgb.Booster")
|
|||||||
#' value of sum of functions, when outputmargin=TRUE, the prediction is
|
#' value of sum of functions, when outputmargin=TRUE, the prediction is
|
||||||
#' untransformed margin value. In logistic regression, outputmargin=T will
|
#' untransformed margin value. In logistic regression, outputmargin=T will
|
||||||
#' output value before logistic transformation.
|
#' output value before logistic transformation.
|
||||||
#'
|
#' @param ntreelimit limit on the number of trees used in prediction; this parameter is only valid for gbtree, not for gblinear.
|
||||||
|
#' set it to a value greater than 0
|
||||||
#' @examples
|
#' @examples
|
||||||
#' data(iris)
|
#' data(iris)
|
||||||
#' bst <- xgboost(as.matrix(iris[,1:4]),as.numeric(iris[,5]), nrounds = 2)
|
#' bst <- xgboost(as.matrix(iris[,1:4]),as.numeric(iris[,5]), nrounds = 2)
|
||||||
@ -19,11 +20,18 @@ setClass("xgb.Booster")
|
|||||||
#' @export
|
#' @export
|
||||||
#'
|
#'
|
||||||
setMethod("predict", signature = "xgb.Booster",
|
setMethod("predict", signature = "xgb.Booster",
|
||||||
definition = function(object, newdata, outputmargin = FALSE) {
|
definition = function(object, newdata, outputmargin = FALSE, ntreelimit = NULL) {
|
||||||
if (class(newdata) != "xgb.DMatrix") {
|
if (class(newdata) != "xgb.DMatrix") {
|
||||||
newdata <- xgb.DMatrix(newdata)
|
newdata <- xgb.DMatrix(newdata)
|
||||||
}
|
}
|
||||||
ret <- .Call("XGBoosterPredict_R", object, newdata, as.integer(outputmargin), PACKAGE = "xgboost")
|
if (is.null(ntreelimit)) {
|
||||||
|
ntreelimit <- 0
|
||||||
|
} else {
|
||||||
|
if (ntreelimit < 1){
|
||||||
|
stop("predict: ntreelimit must be greater equal than 1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ret <- .Call("XGBoosterPredict_R", object, newdata, as.integer(outputmargin), as.integer(ntreelimit), PACKAGE = "xgboost")
|
||||||
return(ret)
|
return(ret)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@ -247,12 +247,13 @@ extern "C" {
|
|||||||
&vec_dmats[0], &vec_sptr[0], len));
|
&vec_dmats[0], &vec_sptr[0], len));
|
||||||
_WrapperEnd();
|
_WrapperEnd();
|
||||||
}
|
}
|
||||||
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin) {
|
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin, SEXP ntree_limit) {
|
||||||
_WrapperBegin();
|
_WrapperBegin();
|
||||||
bst_ulong olen;
|
bst_ulong olen;
|
||||||
const float *res = XGBoosterPredict(R_ExternalPtrAddr(handle),
|
const float *res = XGBoosterPredict(R_ExternalPtrAddr(handle),
|
||||||
R_ExternalPtrAddr(dmat),
|
R_ExternalPtrAddr(dmat),
|
||||||
asInteger(output_margin),
|
asInteger(output_margin),
|
||||||
|
asInteger(ntree_limit),
|
||||||
&olen);
|
&olen);
|
||||||
SEXP ret = PROTECT(allocVector(REALSXP, olen));
|
SEXP ret = PROTECT(allocVector(REALSXP, olen));
|
||||||
for (size_t i = 0; i < olen; ++i) {
|
for (size_t i = 0; i < olen; ++i) {
|
||||||
|
|||||||
@ -107,8 +107,9 @@ extern "C" {
|
|||||||
* \param handle handle
|
* \param handle handle
|
||||||
* \param dmat data matrix
|
* \param dmat data matrix
|
||||||
* \param output_margin whether only output raw margin value
|
* \param output_margin whether only output raw margin value
|
||||||
|
* \param ntree_limit limit number of trees used in prediction
|
||||||
*/
|
*/
|
||||||
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin);
|
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin, SEXP ntree_limit);
|
||||||
/*!
|
/*!
|
||||||
* \brief load model from existing file
|
* \brief load model from existing file
|
||||||
* \param handle handle
|
* \param handle handle
|
||||||
|
|||||||
@ -105,7 +105,10 @@ class GBLinear : public IGradBooster {
|
|||||||
virtual void Predict(IFMatrix *p_fmat,
|
virtual void Predict(IFMatrix *p_fmat,
|
||||||
int64_t buffer_offset,
|
int64_t buffer_offset,
|
||||||
const BoosterInfo &info,
|
const BoosterInfo &info,
|
||||||
std::vector<float> *out_preds) {
|
std::vector<float> *out_preds,
|
||||||
|
unsigned ntree_limit = 0) {
|
||||||
|
utils::Check(ntree_limit == 0,
|
||||||
|
"GBLinear::Predict ntrees is only valid for gbtree predictor");
|
||||||
std::vector<float> &preds = *out_preds;
|
std::vector<float> &preds = *out_preds;
|
||||||
preds.resize(0);
|
preds.resize(0);
|
||||||
// start collecting the prediction
|
// start collecting the prediction
|
||||||
|
|||||||
@ -57,11 +57,14 @@ class IGradBooster {
|
|||||||
* the size of buffer is set by convention using IGradBooster.SetParam("num_pbuffer","size")
|
* the size of buffer is set by convention using IGradBooster.SetParam("num_pbuffer","size")
|
||||||
* \param info extra side information that may be needed for prediction
|
* \param info extra side information that may be needed for prediction
|
||||||
* \param out_preds output vector to hold the predictions
|
* \param out_preds output vector to hold the predictions
|
||||||
|
* \param ntree_limit limit on the number of trees used in prediction; when it equals 0, this means
|
||||||
|
* we do not limit the number of trees; this parameter is only valid for gbtree, not for gblinear
|
||||||
*/
|
*/
|
||||||
virtual void Predict(IFMatrix *p_fmat,
|
virtual void Predict(IFMatrix *p_fmat,
|
||||||
int64_t buffer_offset,
|
int64_t buffer_offset,
|
||||||
const BoosterInfo &info,
|
const BoosterInfo &info,
|
||||||
std::vector<float> *out_preds) = 0;
|
std::vector<float> *out_preds,
|
||||||
|
unsigned ntree_limit = 0) = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief dump the model in text format
|
* \brief dump the model in text format
|
||||||
* \param fmap feature map that may help give interpretations of feature
|
* \param fmap feature map that may help give interpretations of feature
|
||||||
|
|||||||
@ -105,7 +105,8 @@ class GBTree : public IGradBooster {
|
|||||||
virtual void Predict(IFMatrix *p_fmat,
|
virtual void Predict(IFMatrix *p_fmat,
|
||||||
int64_t buffer_offset,
|
int64_t buffer_offset,
|
||||||
const BoosterInfo &info,
|
const BoosterInfo &info,
|
||||||
std::vector<float> *out_preds) {
|
std::vector<float> *out_preds,
|
||||||
|
unsigned ntree_limit = 0) {
|
||||||
int nthread;
|
int nthread;
|
||||||
#pragma omp parallel
|
#pragma omp parallel
|
||||||
{
|
{
|
||||||
@ -137,7 +138,8 @@ class GBTree : public IGradBooster {
|
|||||||
this->Pred(batch[i],
|
this->Pred(batch[i],
|
||||||
buffer_offset < 0 ? -1 : buffer_offset + ridx,
|
buffer_offset < 0 ? -1 : buffer_offset + ridx,
|
||||||
gid, info.GetRoot(ridx), &feats,
|
gid, info.GetRoot(ridx), &feats,
|
||||||
&preds[ridx * mparam.num_output_group + gid], stride);
|
&preds[ridx * mparam.num_output_group + gid], stride,
|
||||||
|
ntree_limit);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -212,14 +214,16 @@ class GBTree : public IGradBooster {
|
|||||||
int bst_group,
|
int bst_group,
|
||||||
unsigned root_index,
|
unsigned root_index,
|
||||||
tree::RegTree::FVec *p_feats,
|
tree::RegTree::FVec *p_feats,
|
||||||
float *out_pred, size_t stride) {
|
float *out_pred, size_t stride, unsigned ntree_limit) {
|
||||||
size_t itop = 0;
|
size_t itop = 0;
|
||||||
float psum = 0.0f;
|
float psum = 0.0f;
|
||||||
// sum of leaf vector
|
// sum of leaf vector
|
||||||
std::vector<float> vec_psum(mparam.size_leaf_vector, 0.0f);
|
std::vector<float> vec_psum(mparam.size_leaf_vector, 0.0f);
|
||||||
const int64_t bid = mparam.BufferOffset(buffer_index, bst_group);
|
const int64_t bid = mparam.BufferOffset(buffer_index, bst_group);
|
||||||
|
// number of valid trees
|
||||||
|
unsigned treeleft = ntree_limit == 0 ? std::numeric_limits<unsigned>::max() : ntree_limit;
|
||||||
// load buffered results if any
|
// load buffered results if any
|
||||||
if (bid >= 0) {
|
if (bid >= 0 && ntree_limit == 0) {
|
||||||
itop = pred_counter[bid];
|
itop = pred_counter[bid];
|
||||||
psum = pred_buffer[bid];
|
psum = pred_buffer[bid];
|
||||||
for (int i = 0; i < mparam.size_leaf_vector; ++i) {
|
for (int i = 0; i < mparam.size_leaf_vector; ++i) {
|
||||||
@ -235,12 +239,13 @@ class GBTree : public IGradBooster {
|
|||||||
for (int j = 0; j < mparam.size_leaf_vector; ++j) {
|
for (int j = 0; j < mparam.size_leaf_vector; ++j) {
|
||||||
vec_psum[j] += trees[i]->leafvec(tid)[j];
|
vec_psum[j] += trees[i]->leafvec(tid)[j];
|
||||||
}
|
}
|
||||||
|
if(--treeleft == 0) break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
p_feats->Drop(inst);
|
p_feats->Drop(inst);
|
||||||
}
|
}
|
||||||
// updated the buffered results
|
// updated the buffered results
|
||||||
if (bid >= 0) {
|
if (bid >= 0 && ntree_limit == 0) {
|
||||||
pred_counter[bid] = static_cast<unsigned>(trees.size());
|
pred_counter[bid] = static_cast<unsigned>(trees.size());
|
||||||
pred_buffer[bid] = psum;
|
pred_buffer[bid] = psum;
|
||||||
for (int i = 0; i < mparam.size_leaf_vector; ++i) {
|
for (int i = 0; i < mparam.size_leaf_vector; ++i) {
|
||||||
|
|||||||
@ -212,11 +212,14 @@ class BoostLearner {
|
|||||||
* \param data input data
|
* \param data input data
|
||||||
* \param output_margin whether to only predict margin value instead of transformed prediction
|
* \param output_margin whether to only predict margin value instead of transformed prediction
|
||||||
* \param out_preds output vector that stores the prediction
|
* \param out_preds output vector that stores the prediction
|
||||||
|
* \param ntree_limit limit number of trees used for boosted tree
|
||||||
|
* predictor, when it equals 0, this means we are using all the trees
|
||||||
*/
|
*/
|
||||||
inline void Predict(const DMatrix &data,
|
inline void Predict(const DMatrix &data,
|
||||||
bool output_margin,
|
bool output_margin,
|
||||||
std::vector<float> *out_preds) const {
|
std::vector<float> *out_preds,
|
||||||
this->PredictRaw(data, out_preds);
|
unsigned ntree_limit = 0) const {
|
||||||
|
this->PredictRaw(data, out_preds, ntree_limit);
|
||||||
if (!output_margin) {
|
if (!output_margin) {
|
||||||
obj_->PredTransform(out_preds);
|
obj_->PredTransform(out_preds);
|
||||||
}
|
}
|
||||||
@ -246,11 +249,14 @@ class BoostLearner {
|
|||||||
* \brief get un-transformed prediction
|
* \brief get un-transformed prediction
|
||||||
* \param data training data matrix
|
* \param data training data matrix
|
||||||
* \param out_preds output vector that stores the prediction
|
* \param out_preds output vector that stores the prediction
|
||||||
|
* \param ntree_limit limit number of trees used for boosted tree
|
||||||
|
* predictor, when it equals 0, this means we are using all the trees
|
||||||
*/
|
*/
|
||||||
inline void PredictRaw(const DMatrix &data,
|
inline void PredictRaw(const DMatrix &data,
|
||||||
std::vector<float> *out_preds) const {
|
std::vector<float> *out_preds,
|
||||||
|
unsigned ntree_limit = 0) const {
|
||||||
gbm_->Predict(data.fmat(), this->FindBufferOffset(data),
|
gbm_->Predict(data.fmat(), this->FindBufferOffset(data),
|
||||||
data.info.info, out_preds);
|
data.info.info, out_preds, ntree_limit);
|
||||||
// add base margin
|
// add base margin
|
||||||
std::vector<float> &preds = *out_preds;
|
std::vector<float> &preds = *out_preds;
|
||||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size());
|
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size());
|
||||||
|
|||||||
@ -192,15 +192,16 @@ class Booster:
|
|||||||
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
|
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
|
||||||
def eval(self, mat, name = 'eval', it = 0):
|
def eval(self, mat, name = 'eval', it = 0):
|
||||||
return self.eval_set( [(mat,name)], it)
|
return self.eval_set( [(mat,name)], it)
|
||||||
def predict(self, data, output_margin=False):
|
def predict(self, data, output_margin=False, ntree_limit=0):
|
||||||
"""
|
"""
|
||||||
predict with data
|
predict with data
|
||||||
data: the dmatrix storing the input
|
data: the dmatrix storing the input
|
||||||
output_margin: whether output raw margin value that is untransformed
|
output_margin: whether output raw margin value that is untransformed
|
||||||
|
ntree_limit: limit on the number of trees used in prediction; defaults to 0, which means using all the trees
|
||||||
"""
|
"""
|
||||||
length = ctypes.c_ulong()
|
length = ctypes.c_ulong()
|
||||||
preds = xglib.XGBoosterPredict(self.handle, data.handle,
|
preds = xglib.XGBoosterPredict(self.handle, data.handle,
|
||||||
int(output_margin), ctypes.byref(length))
|
int(output_margin), ntree_limit, ctypes.byref(length))
|
||||||
return ctypes2numpy(preds, length.value, 'float32')
|
return ctypes2numpy(preds, length.value, 'float32')
|
||||||
def save_model(self, fname):
|
def save_model(self, fname):
|
||||||
""" save model to file """
|
""" save model to file """
|
||||||
|
|||||||
@ -25,9 +25,9 @@ class Booster: public learner::BoostLearner {
|
|||||||
this->init_model = false;
|
this->init_model = false;
|
||||||
this->SetCacheData(mats);
|
this->SetCacheData(mats);
|
||||||
}
|
}
|
||||||
const float *Pred(const DataMatrix &dmat, int output_margin, bst_ulong *len) {
|
inline const float *Pred(const DataMatrix &dmat, int output_margin, unsigned ntree_limit, bst_ulong *len) {
|
||||||
this->CheckInitModel();
|
this->CheckInitModel();
|
||||||
this->Predict(dmat, output_margin != 0, &this->preds_);
|
this->Predict(dmat, output_margin != 0, &this->preds_, ntree_limit);
|
||||||
*len = static_cast<bst_ulong>(this->preds_.size());
|
*len = static_cast<bst_ulong>(this->preds_.size());
|
||||||
return &this->preds_[0];
|
return &this->preds_[0];
|
||||||
}
|
}
|
||||||
@ -249,8 +249,8 @@ extern "C"{
|
|||||||
bst->eval_str = bst->EvalOneIter(iter, mats, names);
|
bst->eval_str = bst->EvalOneIter(iter, mats, names);
|
||||||
return bst->eval_str.c_str();
|
return bst->eval_str.c_str();
|
||||||
}
|
}
|
||||||
const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, bst_ulong *len) {
|
const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, unsigned ntree_limit, bst_ulong *len) {
|
||||||
return static_cast<Booster*>(handle)->Pred(*static_cast<DataMatrix*>(dmat), output_margin, len);
|
return static_cast<Booster*>(handle)->Pred(*static_cast<DataMatrix*>(dmat), output_margin, ntree_limit, len);
|
||||||
}
|
}
|
||||||
void XGBoosterLoadModel(void *handle, const char *fname) {
|
void XGBoosterLoadModel(void *handle, const char *fname) {
|
||||||
static_cast<Booster*>(handle)->LoadModel(fname);
|
static_cast<Booster*>(handle)->LoadModel(fname);
|
||||||
|
|||||||
@ -165,9 +165,11 @@ extern "C" {
|
|||||||
* \param handle handle
|
* \param handle handle
|
||||||
* \param dmat data matrix
|
* \param dmat data matrix
|
||||||
* \param output_margin whether only output raw margin value
|
* \param output_margin whether only output raw margin value
|
||||||
|
* \param ntree_limit limit on the number of trees used for prediction; this is only valid for boosted trees
|
||||||
|
* when the parameter is set to 0, we will use all the trees
|
||||||
* \param len used to store length of returning result
|
* \param len used to store length of returning result
|
||||||
*/
|
*/
|
||||||
XGB_DLL const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, bst_ulong *len);
|
XGB_DLL const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, unsigned ntree_limit, bst_ulong *len);
|
||||||
/*!
|
/*!
|
||||||
* \brief load model from existing file
|
* \brief load model from existing file
|
||||||
* \param handle handle
|
* \param handle handle
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user