From e4817bb4c3f0b8d395e5343382e1cba5fe2ec577 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 2 Sep 2014 15:05:49 -0700 Subject: [PATCH] fix ntreelimit --- wrapper/xgboost.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index 34c4bfde7..a0a88af47 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -371,7 +371,7 @@ class Booster: return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals)) def eval(self, mat, name = 'eval', it = 0): return self.eval_set( [(mat,name)], it) - def predict(self, data, output_margin=False): + def predict(self, data, output_margin=False, ntree_limit=0): """ predict with data Args: @@ -379,12 +379,14 @@ class Booster: the dmatrix storing the input output_margin: bool whether output raw margin value that is untransformed + + ntree_limit: limit number of trees in prediction, default to 0, 0 means using all the trees Returns: numpy array of prediction """ length = ctypes.c_ulong() preds = xglib.XGBoosterPredict(self.handle, data.handle, - int(output_margin), ctypes.byref(length)) + int(output_margin), ntree_limit, ctypes.byref(length)) return ctypes2numpy(preds, length.value, 'float32') def save_model(self, fname): """ save model to file