diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index 34c4bfde7..a0a88af47 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -371,7 +371,7 @@ class Booster: return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals)) def eval(self, mat, name = 'eval', it = 0): return self.eval_set( [(mat,name)], it) - def predict(self, data, output_margin=False): + def predict(self, data, output_margin=False, ntree_limit=0): """ predict with data Args: @@ -379,12 +379,14 @@ class Booster: the dmatrix storing the input output_margin: bool whether output raw margin value that is untransformed + + ntree_limit: limit number of trees in prediction, default to 0, 0 means using all the trees Returns: numpy array of prediction """ length = ctypes.c_ulong() preds = xglib.XGBoosterPredict(self.handle, data.handle, - int(output_margin), ctypes.byref(length)) + int(output_margin), ntree_limit, ctypes.byref(length)) return ctypes2numpy(preds, length.value, 'float32') def save_model(self, fname): """ save model to file