[Breaking] Rename data to X in predict_proba. (#6555)

Newer scikit-learn versions pass the input data as a keyword argument, and
`X` is the keyword name they use.

* Use pip to install latest Python graphviz on Windows CI.
This commit is contained in:
Jiaming Yuan 2020-12-28 21:36:03 +08:00 committed by GitHub
parent cb207a355d
commit 610ee632cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 14 deletions

View File

@ -1321,10 +1321,10 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
feature_weights=feature_weights, feature_weights=feature_weights,
callbacks=callbacks) callbacks=callbacks)
async def _predict_proba_async(self, data, output_margin=False, async def _predict_proba_async(self, X, output_margin=False,
base_margin=None): base_margin=None):
test_dmatrix = await DaskDMatrix( test_dmatrix = await DaskDMatrix(
client=self.client, data=data, base_margin=base_margin, client=self.client, data=X, base_margin=base_margin,
missing=self.missing missing=self.missing
) )
pred_probs = await predict(client=self.client, pred_probs = await predict(client=self.client,
@ -1334,11 +1334,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
return pred_probs return pred_probs
# pylint: disable=arguments-differ,missing-docstring # pylint: disable=arguments-differ,missing-docstring
def predict_proba(self, data, output_margin=False, base_margin=None): def predict_proba(self, X, output_margin=False, base_margin=None):
_assert_dask_support() _assert_dask_support()
return self.client.sync( return self.client.sync(
self._predict_proba_async, self._predict_proba_async,
data, X=X,
output_margin=output_margin, output_margin=output_margin,
base_margin=base_margin base_margin=base_margin
) )

View File

@ -1000,10 +1000,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
return self._le.inverse_transform(column_indexes) return self._le.inverse_transform(column_indexes)
return column_indexes return column_indexes
def predict_proba(self, data, ntree_limit=None, validate_features=False, def predict_proba(self, X, ntree_limit=None, validate_features=False,
base_margin=None): base_margin=None):
""" """ Predict the probability of each `X` example being of a given class.
Predict the probability of each `data` example being of a given class.
.. note:: This function is not thread safe .. note:: This function is not thread safe
@ -1013,21 +1012,22 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
Parameters Parameters
---------- ----------
data : array_like X : array_like
Feature matrix. Feature matrix.
ntree_limit : int ntree_limit : int
Limit number of trees in the prediction; defaults to best_ntree_limit if defined Limit number of trees in the prediction; defaults to best_ntree_limit if
(i.e. it has been trained with early stopping), otherwise 0 (use all trees). defined (i.e. it has been trained with early stopping), otherwise 0 (use all
trees).
validate_features : bool validate_features : bool
When this is True, validate that the Booster's and data's feature_names are identical. When this is True, validate that the Booster's and data's feature_names are
Otherwise, it is assumed that the feature_names are the same. identical. Otherwise, it is assumed that the feature_names are the same.
Returns Returns
------- -------
prediction : numpy array prediction : numpy array
a numpy array with the probability of each data example being of a given class. a numpy array with the probability of each data example being of a given class.
""" """
test_dmatrix = DMatrix(data, base_margin=base_margin, test_dmatrix = DMatrix(X, base_margin=base_margin,
missing=self.missing, nthread=self.n_jobs) missing=self.missing, nthread=self.n_jobs)
if ntree_limit is None: if ntree_limit is None:
ntree_limit = getattr(self, "best_ntree_limit", 0) ntree_limit = getattr(self, "best_ntree_limit", 0)

View File

@ -9,7 +9,6 @@ dependencies:
- scikit-learn - scikit-learn
- pandas - pandas
- pytest - pytest
- python-graphviz
- boto3 - boto3
- hypothesis - hypothesis
- jsonschema - jsonschema
@ -17,3 +16,4 @@ dependencies:
- pip: - pip:
- cupy-cuda101 - cupy-cuda101
- modin[all] - modin[all]
- graphviz