diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 541adcd84..e943cc592 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -103,6 +103,9 @@ def concat(value): def _xgb_get_client(client): '''Simple wrapper around testing None.''' + if not isinstance(client, (type(get_client()), type(None))): + raise TypeError( + _expect([type(get_client()), type(None)], type(client))) ret = get_client() if client is None else client return ret @@ -112,12 +115,6 @@ def _get_client_workers(client): return workers -def _assert_client(client): - if not isinstance(client, (type(get_client()), type(None))): - raise TypeError( - _expect([type(get_client()), type(None)], type(client))) - - class DaskDMatrix: # pylint: disable=missing-docstring, too-many-instance-attributes '''DMatrix holding on references to Dask DataFrame or Dask Array. Constructing @@ -155,7 +152,7 @@ class DaskDMatrix: feature_names=None, feature_types=None): _assert_dask_support() - _assert_client(client) + client = _xgb_get_client(client) self.feature_names = feature_names self.feature_types = feature_types @@ -177,7 +174,6 @@ class DaskDMatrix: self.has_label = label is not None self.has_weights = weight is not None - client = _xgb_get_client(client) client.sync(self.map_local_data, client, data, label, weight) async def map_local_data(self, client, data, label=None, weights=None): @@ -391,13 +387,12 @@ def train(client, params, dtrain, *args, evals=(), **kwargs): ''' _assert_dask_support() - _assert_client(client) + client = _xgb_get_client(client) if 'evals_result' in kwargs.keys(): raise ValueError( 'evals_result is not supported in dask interface.', 'The evaluation history is returned as result of training.') - client = _xgb_get_client(client) workers = list(_get_client_workers(client).keys()) rabit_args = _get_rabit_args(workers, client) @@ -452,7 +447,7 @@ def train(client, params, dtrain, *args, evals=(), **kwargs): return list(filter(lambda ret: ret is not None, results))[0] -def predict(client, model, data, *args): +def predict(client, model, data, *args, missing=numpy.nan): '''Run prediction with a trained booster. .. note:: @@ -466,32 +461,55 @@ def predict(client, model, data, *args): returned from dask if it's set to None. model: A Booster or a dictionary returned by `xgboost.dask.train`. The trained model. - data: DaskDMatrix + data: DaskDMatrix/dask.dataframe.DataFrame/dask.array.Array Input data used for prediction. + missing: float + Used when input data is not DaskDMatrix. Specify the value + considered as missing. Returns ------- - prediction: dask.array.Array + prediction: dask.array.Array/dask.dataframe.Series ''' _assert_dask_support() - _assert_client(client) + client = _xgb_get_client(client) if isinstance(model, Booster): booster = model elif isinstance(model, dict): booster = model['booster'] else: raise TypeError(_expect([Booster, dict], type(model))) + if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)): + raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], + type(data))) - if not isinstance(data, DaskDMatrix): - raise TypeError(_expect([DaskDMatrix], type(data))) + def mapped_predict(partition, is_df): + worker = distributed_get_worker() + m = DMatrix(partition, missing=missing, nthread=worker.nthreads) + predt = booster.predict(m, *args, validate_features=False) + if is_df: + predt = DataFrame(predt, columns=['prediction']) + return predt + if isinstance(data, da.Array): + predictions = client.submit( + da.map_blocks, + mapped_predict, data, False, drop_axis=1, + dtype=numpy.float32 + ).result() + return predictions + if isinstance(data, dd.DataFrame): + import dask + predictions = client.submit( + dd.map_partitions, + mapped_predict, data, True, + meta=dask.dataframe.utils.make_meta({'prediction': 'f4'}) + ).result() + return predictions.iloc[:, 0] + + # Prediction on dask DMatrix. worker_map = data.worker_map - client = _xgb_get_client(client) - - missing = data.missing - feature_names = data.feature_names - feature_types = data.feature_types def dispatched_predict(worker_id): '''Perform prediction on each worker.''' @@ -502,9 +520,9 @@ def predict(client, model, data, *args): booster.set_param({'nthread': worker.nthreads}) for part, order in list_of_parts: local_x = DMatrix(part, - feature_names=feature_names, - feature_types=feature_types, - missing=missing, + feature_names=data.feature_names, + feature_types=data.feature_types, + missing=data.missing, nthread=worker.nthreads) predt = booster.predict(data=local_x, validate_features=local_x.num_row() != 0, @@ -520,8 +538,7 @@ def predict(client, model, data, *args): list_of_parts = data.get_worker_x_ordered(worker) shapes = [] for part, order in list_of_parts: - s = part.shape - shapes.append((s, order)) + shapes.append((part.shape, order)) return shapes def map_function(func): diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index f579ee5d7..896f1881d 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -57,7 +57,13 @@ def test_from_dask_dataframe(): xgb.dask.train( client, {}, dtrain, num_boost_round=2, evals_result={}) # force prediction to be computed - prediction = prediction.compute() + from_dmatrix = prediction.compute() + + prediction = xgb.dask.predict(client, model=booster, data=X) + from_df = prediction.compute() + + assert isinstance(prediction, dd.Series) + assert np.all(from_dmatrix == from_df.to_numpy()) def test_from_dask_array(): @@ -84,6 +90,12 @@ def test_from_dask_array(): config = json.loads(booster.save_config()) assert int(config['learner']['generic_param']['nthread']) == 5 + from_arr = xgb.dask.predict( + client, model=booster, data=X) + + assert isinstance(from_arr, da.Array) + assert np.all(single_node_predt == from_arr.compute()) + def test_dask_regressor(): with LocalCluster(n_workers=5) as cluster: