[Breaking] Fix .predict() method and add .predict_proba() in xgboost.dask.DaskXGBClassifier (#5986)
This commit is contained in:
Parent: 6f7112a848
Commit: bd6b7f4aa7
@@ -1079,13 +1079,34 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
|||||||
return self.client.sync(self._fit_async, X, y, sample_weights,
|
return self.client.sync(self._fit_async, X, y, sample_weights,
|
||||||
eval_set, sample_weight_eval_set, verbose)
|
eval_set, sample_weight_eval_set, verbose)
|
||||||
|
|
||||||
async def _predict_async(self, data):
|
async def _predict_proba_async(self, data):
|
||||||
|
_assert_dask_support()
|
||||||
|
|
||||||
test_dmatrix = await DaskDMatrix(client=self.client, data=data,
|
test_dmatrix = await DaskDMatrix(client=self.client, data=data,
|
||||||
missing=self.missing)
|
missing=self.missing)
|
||||||
pred_probs = await predict(client=self.client,
|
pred_probs = await predict(client=self.client,
|
||||||
model=self.get_booster(), data=test_dmatrix)
|
model=self.get_booster(), data=test_dmatrix)
|
||||||
return pred_probs
|
return pred_probs
|
||||||
|
|
||||||
|
def predict_proba(self, data): # pylint: disable=arguments-differ,missing-docstring
|
||||||
|
_assert_dask_support()
|
||||||
|
return self.client.sync(self._predict_proba_async, data)
|
||||||
|
|
||||||
|
async def _predict_async(self, data):
|
||||||
|
_assert_dask_support()
|
||||||
|
|
||||||
|
test_dmatrix = await DaskDMatrix(client=self.client, data=data,
|
||||||
|
missing=self.missing)
|
||||||
|
pred_probs = await predict(client=self.client,
|
||||||
|
model=self.get_booster(), data=test_dmatrix)
|
||||||
|
|
||||||
|
if self.n_classes_ == 2:
|
||||||
|
preds = (pred_probs > 0.5).astype(int)
|
||||||
|
else:
|
||||||
|
preds = da.argmax(pred_probs, axis=1)
|
||||||
|
|
||||||
|
return preds
|
||||||
|
|
||||||
def predict(self, data): # pylint: disable=arguments-differ
|
def predict(self, data): # pylint: disable=arguments-differ
|
||||||
_assert_dask_support()
|
_assert_dask_support()
|
||||||
return self.client.sync(self._predict_async, data)
|
return self.client.sync(self._predict_async, data)
|
||||||
|
|||||||
@@ -165,12 +165,12 @@ def test_dask_missing_value_cls():
|
|||||||
missing=0.0)
|
missing=0.0)
|
||||||
cls.client = client
|
cls.client = client
|
||||||
cls.fit(X, y, eval_set=[(X, y)])
|
cls.fit(X, y, eval_set=[(X, y)])
|
||||||
dd_predt = cls.predict(X).compute()
|
dd_pred_proba = cls.predict_proba(X).compute()
|
||||||
|
|
||||||
np_X = X.compute()
|
np_X = X.compute()
|
||||||
np_predt = cls.get_booster().predict(
|
np_pred_proba = cls.get_booster().predict(
|
||||||
xgb.DMatrix(np_X, missing=0.0))
|
xgb.DMatrix(np_X, missing=0.0))
|
||||||
np.testing.assert_allclose(np_predt, dd_predt)
|
np.testing.assert_allclose(np_pred_proba, dd_pred_proba)
|
||||||
|
|
||||||
cls = xgb.dask.DaskXGBClassifier()
|
cls = xgb.dask.DaskXGBClassifier()
|
||||||
assert hasattr(cls, 'missing')
|
assert hasattr(cls, 'missing')
|
||||||
@@ -209,7 +209,7 @@ def test_dask_classifier():
|
|||||||
classifier.fit(X, y, eval_set=[(X, y)])
|
classifier.fit(X, y, eval_set=[(X, y)])
|
||||||
prediction = classifier.predict(X)
|
prediction = classifier.predict(X)
|
||||||
|
|
||||||
assert prediction.ndim == 2
|
assert prediction.ndim == 1
|
||||||
assert prediction.shape[0] == kRows
|
assert prediction.shape[0] == kRows
|
||||||
|
|
||||||
history = classifier.evals_result()
|
history = classifier.evals_result()
|
||||||
@@ -222,7 +222,18 @@ def test_dask_classifier():
|
|||||||
assert len(list(history['validation_0'])) == 1
|
assert len(list(history['validation_0'])) == 1
|
||||||
assert len(history['validation_0']['merror']) == 2
|
assert len(history['validation_0']['merror']) == 2
|
||||||
|
|
||||||
|
# Test .predict_proba()
|
||||||
|
probas = classifier.predict_proba(X)
|
||||||
assert classifier.n_classes_ == 10
|
assert classifier.n_classes_ == 10
|
||||||
|
assert probas.ndim == 2
|
||||||
|
assert probas.shape[0] == kRows
|
||||||
|
assert probas.shape[1] == 10
|
||||||
|
|
||||||
|
cls_booster = classifier.get_booster()
|
||||||
|
single_node_proba = cls_booster.inplace_predict(X.compute())
|
||||||
|
|
||||||
|
np.testing.assert_allclose(single_node_proba,
|
||||||
|
probas.compute())
|
||||||
|
|
||||||
# Test with dataframe.
|
# Test with dataframe.
|
||||||
X_d = dd.from_dask_array(X)
|
X_d = dd.from_dask_array(X)
|
||||||
@@ -232,7 +243,7 @@ def test_dask_classifier():
|
|||||||
assert classifier.n_classes_ == 10
|
assert classifier.n_classes_ == 10
|
||||||
prediction = classifier.predict(X_d)
|
prediction = classifier.predict(X_d)
|
||||||
|
|
||||||
assert prediction.ndim == 2
|
assert prediction.ndim == 1
|
||||||
assert prediction.shape[0] == kRows
|
assert prediction.shape[0] == kRows
|
||||||
|
|
||||||
|
|
||||||
@@ -407,7 +418,7 @@ async def run_dask_classifier_asyncio(scheduler_address):
|
|||||||
await classifier.fit(X, y, eval_set=[(X, y)])
|
await classifier.fit(X, y, eval_set=[(X, y)])
|
||||||
prediction = await classifier.predict(X)
|
prediction = await classifier.predict(X)
|
||||||
|
|
||||||
assert prediction.ndim == 2
|
assert prediction.ndim == 1
|
||||||
assert prediction.shape[0] == kRows
|
assert prediction.shape[0] == kRows
|
||||||
|
|
||||||
history = classifier.evals_result()
|
history = classifier.evals_result()
|
||||||
@@ -420,7 +431,13 @@ async def run_dask_classifier_asyncio(scheduler_address):
|
|||||||
assert len(list(history['validation_0'])) == 1
|
assert len(list(history['validation_0'])) == 1
|
||||||
assert len(history['validation_0']['merror']) == 2
|
assert len(history['validation_0']['merror']) == 2
|
||||||
|
|
||||||
|
# Test .predict_proba()
|
||||||
|
probas = await classifier.predict_proba(X)
|
||||||
assert classifier.n_classes_ == 10
|
assert classifier.n_classes_ == 10
|
||||||
|
assert probas.ndim == 2
|
||||||
|
assert probas.shape[0] == kRows
|
||||||
|
assert probas.shape[1] == 10
|
||||||
|
|
||||||
|
|
||||||
# Test with dataframe.
|
# Test with dataframe.
|
||||||
X_d = dd.from_dask_array(X)
|
X_d = dd.from_dask_array(X)
|
||||||
@@ -430,9 +447,8 @@ async def run_dask_classifier_asyncio(scheduler_address):
|
|||||||
assert classifier.n_classes_ == 10
|
assert classifier.n_classes_ == 10
|
||||||
prediction = await classifier.predict(X_d)
|
prediction = await classifier.predict(X_d)
|
||||||
|
|
||||||
assert prediction.ndim == 2
|
assert prediction.ndim == 1
|
||||||
assert prediction.shape[0] == kRows
|
assert prediction.shape[0] == kRows
|
||||||
assert prediction.shape[1] == 10
|
|
||||||
|
|
||||||
|
|
||||||
def test_with_asyncio():
|
def test_with_asyncio():
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user