add ntree limit

This commit is contained in:
tqchen
2014-09-01 15:10:19 -07:00
parent 4c451de90b
commit 4592e500cb
10 changed files with 53 additions and 23 deletions

View File

@@ -192,15 +192,16 @@ class Booster:
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
def eval(self, mat, name = 'eval', it = 0):
return self.eval_set( [(mat,name)], it)
def predict(self, data, output_margin=False):
def predict(self, data, output_margin=False, ntree_limit=0):
"""
predict with data
data: the dmatrix storing the input
output_margin: whether output raw margin value that is untransformed
ntree_limit: limit number of trees in prediction, default to 0, 0 means using all the trees
"""
length = ctypes.c_ulong()
preds = xglib.XGBoosterPredict(self.handle, data.handle,
int(output_margin), ctypes.byref(length))
int(output_margin), ntree_limit, ctypes.byref(length))
return ctypes2numpy(preds, length.value, 'float32')
def save_model(self, fname):
""" save model to file """

View File

@@ -25,9 +25,9 @@ class Booster: public learner::BoostLearner {
this->init_model = false;
this->SetCacheData(mats);
}
const float *Pred(const DataMatrix &dmat, int output_margin, bst_ulong *len) {
inline const float *Pred(const DataMatrix &dmat, int output_margin, unsigned ntree_limit, bst_ulong *len) {
this->CheckInitModel();
this->Predict(dmat, output_margin != 0, &this->preds_);
this->Predict(dmat, output_margin != 0, &this->preds_, ntree_limit);
*len = static_cast<bst_ulong>(this->preds_.size());
return &this->preds_[0];
}
@@ -249,8 +249,8 @@ extern "C"{
bst->eval_str = bst->EvalOneIter(iter, mats, names);
return bst->eval_str.c_str();
}
const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, bst_ulong *len) {
return static_cast<Booster*>(handle)->Pred(*static_cast<DataMatrix*>(dmat), output_margin, len);
const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, unsigned ntree_limit, bst_ulong *len) {
return static_cast<Booster*>(handle)->Pred(*static_cast<DataMatrix*>(dmat), output_margin, ntree_limit, len);
}
void XGBoosterLoadModel(void *handle, const char *fname) {
static_cast<Booster*>(handle)->LoadModel(fname);

View File

@@ -165,9 +165,11 @@ extern "C" {
* \param handle handle
* \param dmat data matrix
* \param output_margin whether only output raw margin value
* \param ntree_limit limit number of trees used for prediction, this is only valid for boosted trees
* when the parameter is set to 0, we will use all the trees
* \param len used to store length of returning result
*/
XGB_DLL const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, bst_ulong *len);
XGB_DLL const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, unsigned ntree_limit, bst_ulong *len);
/*!
* \brief load model from existing file
* \param handle handle