Merge pull request #586 from Far0n/sklearn_wrapper
sklearn_wrapper additions fixed #420
This commit is contained in:
commit
7f559235be
@ -212,10 +212,12 @@ class XGBModel(XGBModelBase):
|
|||||||
self.best_iteration = self._Booster.best_iteration
|
self.best_iteration = self._Booster.best_iteration
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def predict(self, data):
|
def predict(self, data, output_margin=False, ntree_limit=0):
|
||||||
# pylint: disable=missing-docstring,invalid-name
|
# pylint: disable=missing-docstring,invalid-name
|
||||||
test_dmatrix = DMatrix(data, missing=self.missing)
|
test_dmatrix = DMatrix(data, missing=self.missing)
|
||||||
return self.booster().predict(test_dmatrix)
|
return self.booster().predict(test_dmatrix,
|
||||||
|
output_margin=output_margin,
|
||||||
|
ntree_limit=ntree_limit)
|
||||||
|
|
||||||
def evals_result(self):
|
def evals_result(self):
|
||||||
"""Return the evaluation results.
|
"""Return the evaluation results.
|
||||||
@ -366,9 +368,11 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def predict(self, data):
|
def predict(self, data, output_margin=False, ntree_limit=0):
|
||||||
test_dmatrix = DMatrix(data, missing=self.missing)
|
test_dmatrix = DMatrix(data, missing=self.missing)
|
||||||
class_probs = self.booster().predict(test_dmatrix)
|
class_probs = self.booster().predict(test_dmatrix,
|
||||||
|
output_margin=output_margin,
|
||||||
|
ntree_limit=ntree_limit)
|
||||||
if len(class_probs.shape) > 1:
|
if len(class_probs.shape) > 1:
|
||||||
column_indexes = np.argmax(class_probs, axis=1)
|
column_indexes = np.argmax(class_probs, axis=1)
|
||||||
else:
|
else:
|
||||||
@ -376,9 +380,11 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
column_indexes[class_probs > 0.5] = 1
|
column_indexes[class_probs > 0.5] = 1
|
||||||
return self._le.inverse_transform(column_indexes)
|
return self._le.inverse_transform(column_indexes)
|
||||||
|
|
||||||
def predict_proba(self, data):
|
def predict_proba(self, data, output_margin=False, ntree_limit=0):
|
||||||
test_dmatrix = DMatrix(data, missing=self.missing)
|
test_dmatrix = DMatrix(data, missing=self.missing)
|
||||||
class_probs = self.booster().predict(test_dmatrix)
|
class_probs = self.booster().predict(test_dmatrix,
|
||||||
|
output_margin=output_margin,
|
||||||
|
ntree_limit=ntree_limit)
|
||||||
if self.objective == "multi:softprob":
|
if self.objective == "multi:softprob":
|
||||||
return class_probs
|
return class_probs
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user