fix ntreelimit

This commit is contained in:
tqchen 2014-09-02 15:05:49 -07:00
parent 5177fa02e4
commit e4817bb4c3

View File

@ -371,7 +371,7 @@ class Booster:
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals)) return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
def eval(self, mat, name = 'eval', it = 0): def eval(self, mat, name = 'eval', it = 0):
return self.eval_set( [(mat,name)], it) return self.eval_set( [(mat,name)], it)
def predict(self, data, output_margin=False): def predict(self, data, output_margin=False, ntree_limit=0):
""" """
predict with data predict with data
Args: Args:
@ -379,12 +379,14 @@ class Booster:
the dmatrix storing the input the dmatrix storing the input
output_margin: bool output_margin: bool
whether output raw margin value that is untransformed whether output raw margin value that is untransformed
ntree_limit: limit number of trees in prediction, default to 0, 0 means using all the trees
Returns: Returns:
numpy array of prediction numpy array of prediction
""" """
length = ctypes.c_ulong() length = ctypes.c_ulong()
preds = xglib.XGBoosterPredict(self.handle, data.handle, preds = xglib.XGBoosterPredict(self.handle, data.handle,
int(output_margin), ctypes.byref(length)) int(output_margin), ntree_limit, ctypes.byref(length))
return ctypes2numpy(preds, length.value, 'float32') return ctypes2numpy(preds, length.value, 'float32')
def save_model(self, fname): def save_model(self, fname):
""" save model to file """ save model to file