[dask] Return GPU Series when input is from cuDF. (#5710)
* Refactor predict function.
This commit is contained in:
parent
91c646392d
commit
35e2205256
@ -105,15 +105,8 @@ except ImportError:
|
||||
|
||||
# cudf
|
||||
try:
|
||||
from cudf import DataFrame as CUDF_DataFrame
|
||||
from cudf import Series as CUDF_Series
|
||||
from cudf import concat as CUDF_concat
|
||||
CUDF_INSTALLED = True
|
||||
except ImportError:
|
||||
CUDF_DataFrame = object
|
||||
CUDF_Series = object
|
||||
CUDF_MultiIndex = object
|
||||
CUDF_INSTALLED = False
|
||||
CUDF_concat = None
|
||||
|
||||
# sklearn
|
||||
|
||||
@ -17,8 +17,7 @@ import scipy.sparse
|
||||
|
||||
from .compat import (
|
||||
STRING_TYPES, DataFrame, py_str,
|
||||
PANDAS_INSTALLED, CUDF_INSTALLED,
|
||||
CUDF_DataFrame,
|
||||
PANDAS_INSTALLED,
|
||||
os_fspath, os_PathLike, lazy_isinstance)
|
||||
from .libpath import find_lib_path
|
||||
|
||||
@ -282,8 +281,8 @@ def _convert_unknown_data(data, meta=None, meta_type=None):
|
||||
|
||||
# Either object has cuda array interface or contains columns with interfaces
|
||||
def _has_cuda_array_interface(data):
|
||||
return hasattr(data, '__cuda_array_interface__') or (
|
||||
CUDF_INSTALLED and isinstance(data, CUDF_DataFrame))
|
||||
return hasattr(data, '__cuda_array_interface__') or \
|
||||
lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame')
|
||||
|
||||
|
||||
def _cudf_array_interfaces(df):
|
||||
@ -508,7 +507,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
def set_interface_info(self, field, data):
|
||||
"""Set info type property into DMatrix."""
|
||||
# If we are passed a dataframe, extract the series
|
||||
if CUDF_INSTALLED and isinstance(data, CUDF_DataFrame):
|
||||
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
|
||||
if len(data.columns) != 1:
|
||||
raise ValueError(
|
||||
'Expecting meta-info to contain a single column')
|
||||
|
||||
@ -25,7 +25,7 @@ from .compat import distributed_get_worker, distributed_wait, distributed_comm
|
||||
from .compat import da, dd, delayed, get_client
|
||||
from .compat import sparse, scipy_sparse
|
||||
from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
|
||||
from .compat import CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_concat
|
||||
from .compat import CUDF_concat
|
||||
from .compat import lazy_isinstance
|
||||
|
||||
from .core import DMatrix, Booster, _expect
|
||||
@ -97,7 +97,8 @@ def concat(value): # pylint: disable=too-many-return-statements
|
||||
return sparse.concatenate(value, axis=0)
|
||||
if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)):
|
||||
return pandas_concat(value, axis=0)
|
||||
if CUDF_INSTALLED and isinstance(value[0], (CUDF_DataFrame, CUDF_Series)):
|
||||
if lazy_isinstance(value[0], 'cudf.core.dataframe', 'DataFrame') or \
|
||||
lazy_isinstance(value[0], 'cudf.core.series', 'Series'):
|
||||
return CUDF_concat(value, axis=0)
|
||||
if lazy_isinstance(value[0], 'cupy.core.core', 'ndarray'):
|
||||
import cupy # pylint: disable=import-error
|
||||
@ -461,6 +462,25 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
|
||||
return list(filter(lambda ret: ret is not None, results))[0]
|
||||
|
||||
|
||||
def _direct_predict_impl(client, data, predict_fn):
|
||||
if isinstance(data, da.Array):
|
||||
predictions = client.submit(
|
||||
da.map_blocks,
|
||||
predict_fn, data, False, drop_axis=1,
|
||||
dtype=numpy.float32
|
||||
).result()
|
||||
return predictions
|
||||
if isinstance(data, dd.DataFrame):
|
||||
predictions = client.submit(
|
||||
dd.map_partitions,
|
||||
predict_fn, data, True,
|
||||
meta=dd.utils.make_meta({'prediction': 'f4'})
|
||||
).result()
|
||||
return predictions.iloc[:, 0]
|
||||
raise TypeError('data of type: ' + str(type(data)) +
|
||||
' is not supported by direct prediction')
|
||||
|
||||
|
||||
def predict(client, model, data, *args, missing=numpy.nan):
|
||||
'''Run prediction with a trained booster.
|
||||
|
||||
@ -502,26 +522,19 @@ def predict(client, model, data, *args, missing=numpy.nan):
|
||||
|
||||
def mapped_predict(partition, is_df):
|
||||
worker = distributed_get_worker()
|
||||
booster.set_param({'nthread': worker.nthreads})
|
||||
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
|
||||
predt = booster.predict(m, *args, validate_features=False)
|
||||
if is_df:
|
||||
predt = DataFrame(predt, columns=['prediction'])
|
||||
if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
|
||||
import cudf # pylint: disable=import-error
|
||||
predt = cudf.DataFrame(predt, columns=['prediction'])
|
||||
else:
|
||||
predt = DataFrame(predt, columns=['prediction'])
|
||||
return predt
|
||||
|
||||
if isinstance(data, da.Array):
|
||||
predictions = client.submit(
|
||||
da.map_blocks,
|
||||
mapped_predict, data, False, drop_axis=1,
|
||||
dtype=numpy.float32
|
||||
).result()
|
||||
return predictions
|
||||
if isinstance(data, dd.DataFrame):
|
||||
predictions = client.submit(
|
||||
dd.map_partitions,
|
||||
mapped_predict, data, True,
|
||||
meta=dd.utils.make_meta({'prediction': 'f4'})
|
||||
).result()
|
||||
return predictions.iloc[:, 0]
|
||||
if isinstance(data, (da.Array, dd.DataFrame)):
|
||||
return _direct_predict_impl(client, data, mapped_predict)
|
||||
|
||||
# Prediction on dask DMatrix.
|
||||
worker_map = data.worker_map
|
||||
@ -644,20 +657,7 @@ def inplace_predict(client, model, data,
|
||||
dtype=numpy.float32)
|
||||
return prediction
|
||||
|
||||
if isinstance(data, da.Array):
|
||||
predictions = client.submit(
|
||||
da.map_blocks,
|
||||
mapped_predict, data, False, drop_axis=1,
|
||||
dtype=numpy.float32
|
||||
).result()
|
||||
return predictions
|
||||
if isinstance(data, dd.DataFrame):
|
||||
predictions = client.submit(
|
||||
dd.map_partitions,
|
||||
mapped_predict, data, True,
|
||||
meta=dd.utils.make_meta({'prediction': 'f4'})
|
||||
).result()
|
||||
return predictions.iloc[:, 0]
|
||||
return _direct_predict_impl(client, data, mapped_predict)
|
||||
|
||||
|
||||
def _evaluation_matrices(client, validation_set, sample_weights, missing):
|
||||
|
||||
@ -44,10 +44,10 @@ class TestDistributedGPU(unittest.TestCase):
|
||||
out = dxgb.train(client, {'tree_method': 'gpu_hist'},
|
||||
dtrain=dtrain,
|
||||
evals=[(dtrain, 'X')],
|
||||
num_boost_round=2)
|
||||
num_boost_round=4)
|
||||
|
||||
assert isinstance(out['booster'], dxgb.Booster)
|
||||
assert len(out['history']['X']['rmse']) == 2
|
||||
assert len(out['history']['X']['rmse']) == 4
|
||||
|
||||
predictions = dxgb.predict(client, out, dtrain).compute()
|
||||
assert isinstance(predictions, np.ndarray)
|
||||
@ -62,6 +62,20 @@ class TestDistributedGPU(unittest.TestCase):
|
||||
cupy.testing.assert_allclose(single_node, predictions)
|
||||
cupy.testing.assert_allclose(single_node, series_predictions)
|
||||
|
||||
predt = dxgb.predict(client, out, X)
|
||||
assert isinstance(predt, dd.Series)
|
||||
|
||||
def is_df(part):
|
||||
assert isinstance(part, cudf.DataFrame), part
|
||||
return part
|
||||
|
||||
predt.map_partitions(
|
||||
is_df,
|
||||
meta=dd.utils.make_meta({'prediction': 'f4'}))
|
||||
|
||||
cupy.testing.assert_allclose(
|
||||
predt.values.compute(), single_node)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
@pytest.mark.mgpu
|
||||
def test_dask_array(self):
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# coding: utf-8
|
||||
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
|
||||
from xgboost.compat import CUDF_INSTALLED, DASK_INSTALLED
|
||||
from xgboost.compat import DASK_INSTALLED
|
||||
|
||||
|
||||
def no_sklearn():
|
||||
@ -46,6 +46,12 @@ def no_dask_cuda():
|
||||
|
||||
|
||||
def no_cudf():
|
||||
try:
|
||||
import cudf # noqa
|
||||
CUDF_INSTALLED = True
|
||||
except ImportError:
|
||||
CUDF_INSTALLED = False
|
||||
|
||||
return {'condition': not CUDF_INSTALLED,
|
||||
'reason': 'CUDF is not installed'}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user