Merge pull request #586 from Far0n/sklearn_wrapper

sklearn_wrapper additions; fixes #420
This commit is contained in:
Yuan (Terry) Tang 2015-11-02 12:07:12 -06:00
commit 7f559235be

View File

@@ -212,10 +212,12 @@ class XGBModel(XGBModelBase):
         self.best_iteration = self._Booster.best_iteration
         return self

-    def predict(self, data):
+    def predict(self, data, output_margin=False, ntree_limit=0):
         # pylint: disable=missing-docstring,invalid-name
         test_dmatrix = DMatrix(data, missing=self.missing)
-        return self.booster().predict(test_dmatrix)
+        return self.booster().predict(test_dmatrix,
+                                      output_margin=output_margin,
+                                      ntree_limit=ntree_limit)
def evals_result(self): def evals_result(self):
"""Return the evaluation results. """Return the evaluation results.
@@ -366,9 +368,11 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         return self

-    def predict(self, data):
+    def predict(self, data, output_margin=False, ntree_limit=0):
         test_dmatrix = DMatrix(data, missing=self.missing)
-        class_probs = self.booster().predict(test_dmatrix)
+        class_probs = self.booster().predict(test_dmatrix,
+                                             output_margin=output_margin,
+                                             ntree_limit=ntree_limit)
         if len(class_probs.shape) > 1:
             column_indexes = np.argmax(class_probs, axis=1)
         else:
@@ -376,9 +380,11 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             column_indexes[class_probs > 0.5] = 1
         return self._le.inverse_transform(column_indexes)

-    def predict_proba(self, data):
+    def predict_proba(self, data, output_margin=False, ntree_limit=0):
         test_dmatrix = DMatrix(data, missing=self.missing)
-        class_probs = self.booster().predict(test_dmatrix)
+        class_probs = self.booster().predict(test_dmatrix,
+                                             output_margin=output_margin,
+                                             ntree_limit=ntree_limit)
         if self.objective == "multi:softprob":
             return class_probs
         else: