Fix dask predict shape infer. (#5989)

This commit is contained in:
Jiaming Yuan
2020-08-08 14:29:22 +08:00
committed by GitHub
parent 9c6e791e64
commit 801e6b6800
2 changed files with 44 additions and 19 deletions

View File

@@ -738,7 +738,8 @@ async def _predict_async(client: Client, model, data, *args,
predt = booster.predict(data=local_x,
validate_features=local_x.num_row() != 0,
*args)
ret = (delayed(predt), order)
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
ret = ((delayed(predt), columns), order)
predictions.append(ret)
return predictions
@@ -775,8 +776,10 @@ async def _predict_async(client: Client, model, data, *args,
# See https://docs.dask.org/en/latest/array-creation.html
arrays = []
for i, shape in enumerate(shapes):
arrays.append(da.from_delayed(results[i], shape=(shape[0], ),
dtype=numpy.float32))
arrays.append(da.from_delayed(
results[i][0], shape=(shape[0],)
if results[i][1] == 1 else (shape[0], results[i][1]),
dtype=numpy.float32))
predictions = await da.concatenate(arrays, axis=0)
return predictions
@@ -978,6 +981,7 @@ class DaskScikitLearnBase(XGBModel):
def client(self, clt):
self._client = clt
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
['estimators', 'model'])
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
@@ -1032,9 +1036,6 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
['estimators', 'model']
)
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
# pylint: disable=missing-docstring
_client = None
async def _fit_async(self, X, y,
sample_weights=None,
eval_set=None,