[Breaking] Fix .predict() method and add .predict_proba() in xgboost.dask.DaskXGBClassifier (#5986)

This commit is contained in:
jameskrach
2020-08-11 04:11:28 -04:00
committed by GitHub
parent 6f7112a848
commit bd6b7f4aa7
2 changed files with 46 additions and 9 deletions

View File

@@ -1079,13 +1079,34 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
return self.client.sync(self._fit_async, X, y, sample_weights,
eval_set, sample_weight_eval_set, verbose)
async def _predict_async(self, data):
async def _predict_proba_async(self, data):
_assert_dask_support()
test_dmatrix = await DaskDMatrix(client=self.client, data=data,
missing=self.missing)
pred_probs = await predict(client=self.client,
model=self.get_booster(), data=test_dmatrix)
return pred_probs
def predict_proba(self, data): # pylint: disable=arguments-differ,missing-docstring
_assert_dask_support()
return self.client.sync(self._predict_proba_async, data)
async def _predict_async(self, data):
_assert_dask_support()
test_dmatrix = await DaskDMatrix(client=self.client, data=data,
missing=self.missing)
pred_probs = await predict(client=self.client,
model=self.get_booster(), data=test_dmatrix)
if self.n_classes_ == 2:
preds = (pred_probs > 0.5).astype(int)
else:
preds = da.argmax(pred_probs, axis=1)
return preds
def predict(self, data): # pylint: disable=arguments-differ
_assert_dask_support()
return self.client.sync(self._predict_async, data)