fix ntreelimit
This commit is contained in:
parent
5177fa02e4
commit
e4817bb4c3
@ -371,7 +371,7 @@ class Booster:
|
|||||||
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
|
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
|
||||||
def eval(self, mat, name = 'eval', it = 0):
|
def eval(self, mat, name = 'eval', it = 0):
|
||||||
return self.eval_set( [(mat,name)], it)
|
return self.eval_set( [(mat,name)], it)
|
||||||
def predict(self, data, output_margin=False):
|
def predict(self, data, output_margin=False, ntree_limit=0):
|
||||||
"""
|
"""
|
||||||
predict with data
|
predict with data
|
||||||
Args:
|
Args:
|
||||||
@ -379,12 +379,14 @@ class Booster:
|
|||||||
the dmatrix storing the input
|
the dmatrix storing the input
|
||||||
output_margin: bool
|
output_margin: bool
|
||||||
whether output raw margin value that is untransformed
|
whether output raw margin value that is untransformed
|
||||||
|
|
||||||
|
ntree_limit: limit number of trees in prediction, default to 0, 0 means using all the trees
|
||||||
Returns:
|
Returns:
|
||||||
numpy array of prediction
|
numpy array of prediction
|
||||||
"""
|
"""
|
||||||
length = ctypes.c_ulong()
|
length = ctypes.c_ulong()
|
||||||
preds = xglib.XGBoosterPredict(self.handle, data.handle,
|
preds = xglib.XGBoosterPredict(self.handle, data.handle,
|
||||||
int(output_margin), ctypes.byref(length))
|
int(output_margin), ntree_limit, ctypes.byref(length))
|
||||||
return ctypes2numpy(preds, length.value, 'float32')
|
return ctypes2numpy(preds, length.value, 'float32')
|
||||||
def save_model(self, fname):
|
def save_model(self, fname):
|
||||||
""" save model to file
|
""" save model to file
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user