Merge pull request #586 from Far0n/sklearn_wrapper
sklearn_wrapper additions; fixes #420
This commit is contained in:
commit
7f559235be
@ -212,10 +212,12 @@ class XGBModel(XGBModelBase):
|
||||
self.best_iteration = self._Booster.best_iteration
|
||||
return self
|
||||
|
||||
def predict(self, data, output_margin=False, ntree_limit=0):
    # pylint: disable=invalid-name
    """Predict with `data` using the fitted booster.

    Parameters
    ----------
    data : array_like
        Feature matrix to predict on; converted to a DMatrix internally.
    output_margin : bool, optional (default=False)
        Whether to output the raw, untransformed margin values instead of
        the transformed prediction.
    ntree_limit : int, optional (default=0)
        Limit the number of trees used in the prediction; 0 means use all
        trees (forwarded unchanged to ``Booster.predict``).

    Returns
    -------
    prediction : numpy array
    """
    # Wrap the input in a DMatrix so missing values are handled the same
    # way as during training (self.missing).
    test_dmatrix = DMatrix(data, missing=self.missing)
    return self.booster().predict(test_dmatrix,
                                  output_margin=output_margin,
                                  ntree_limit=ntree_limit)
|
||||
|
||||
def evals_result(self):
|
||||
"""Return the evaluation results.
|
||||
@ -366,9 +368,11 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
|
||||
return self
|
||||
|
||||
def predict(self, data):
|
||||
def predict(self, data, output_margin=False, ntree_limit=0):
|
||||
test_dmatrix = DMatrix(data, missing=self.missing)
|
||||
class_probs = self.booster().predict(test_dmatrix)
|
||||
class_probs = self.booster().predict(test_dmatrix,
|
||||
output_margin=output_margin,
|
||||
ntree_limit=ntree_limit)
|
||||
if len(class_probs.shape) > 1:
|
||||
column_indexes = np.argmax(class_probs, axis=1)
|
||||
else:
|
||||
@ -376,9 +380,11 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
column_indexes[class_probs > 0.5] = 1
|
||||
return self._le.inverse_transform(column_indexes)
|
||||
|
||||
def predict_proba(self, data):
|
||||
def predict_proba(self, data, output_margin=False, ntree_limit=0):
|
||||
test_dmatrix = DMatrix(data, missing=self.missing)
|
||||
class_probs = self.booster().predict(test_dmatrix)
|
||||
class_probs = self.booster().predict(test_dmatrix,
|
||||
output_margin=output_margin,
|
||||
ntree_limit=ntree_limit)
|
||||
if self.objective == "multi:softprob":
|
||||
return class_probs
|
||||
else:
|
||||
|
||||