Use cudf.concat explicitly. (#4918)
* Use `cudf.concat` explicitly. * Add test.
This commit is contained in:
parent
aefb1e5c2f
commit
6c9b6f11da
@ -134,12 +134,14 @@ try:
|
|||||||
from cudf import DataFrame as CUDF_DataFrame
|
from cudf import DataFrame as CUDF_DataFrame
|
||||||
from cudf import Series as CUDF_Series
|
from cudf import Series as CUDF_Series
|
||||||
from cudf import MultiIndex as CUDF_MultiIndex
|
from cudf import MultiIndex as CUDF_MultiIndex
|
||||||
|
from cudf import concat as CUDF_concat
|
||||||
CUDF_INSTALLED = True
|
CUDF_INSTALLED = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
CUDF_DataFrame = object
|
CUDF_DataFrame = object
|
||||||
CUDF_Series = object
|
CUDF_Series = object
|
||||||
CUDF_MultiIndex = object
|
CUDF_MultiIndex = object
|
||||||
CUDF_INSTALLED = False
|
CUDF_INSTALLED = False
|
||||||
|
CUDF_concat = None
|
||||||
|
|
||||||
# sklearn
|
# sklearn
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -25,6 +25,7 @@ from .compat import distributed_get_worker, distributed_wait, distributed_comm
|
|||||||
from .compat import da, dd, delayed, get_client
|
from .compat import da, dd, delayed, get_client
|
||||||
from .compat import sparse, scipy_sparse
|
from .compat import sparse, scipy_sparse
|
||||||
from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
|
from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
|
||||||
|
from .compat import CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_concat
|
||||||
|
|
||||||
from .core import DMatrix, Booster, _expect
|
from .core import DMatrix, Booster, _expect
|
||||||
from .training import train as worker_train
|
from .training import train as worker_train
|
||||||
@ -84,6 +85,8 @@ def concat(value):
|
|||||||
return sparse.concatenate(value, axis=0)
|
return sparse.concatenate(value, axis=0)
|
||||||
if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)):
|
if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)):
|
||||||
return pandas_concat(value, axis=0)
|
return pandas_concat(value, axis=0)
|
||||||
|
if CUDF_INSTALLED and isinstance(value[0], (CUDF_DataFrame, CUDF_Series)):
|
||||||
|
return CUDF_concat(value, axis=0)
|
||||||
return dd.multi.concat(list(value), axis=0)
|
return dd.multi.concat(list(value), axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
41
tests/python-gpu/test_gpu_with_dask.py
Normal file
41
tests/python-gpu/test_gpu_with_dask.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import sys
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
if sys.platform.startswith("win"):
|
||||||
|
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from distributed.utils_test import client, loop, cluster_fixture
|
||||||
|
import dask.dataframe as dd
|
||||||
|
from xgboost import dask as dxgb
|
||||||
|
import cudf
|
||||||
|
except ImportError:
|
||||||
|
client = None
|
||||||
|
loop = None
|
||||||
|
cluster_fixture = None
|
||||||
|
pass
|
||||||
|
|
||||||
|
sys.path.append("tests/python")
|
||||||
|
from test_with_dask import generate_array
|
||||||
|
import testing as tm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_dask())
|
||||||
|
@pytest.mark.skipif(**tm.no_cudf())
|
||||||
|
def test_dask_dataframe(client):
|
||||||
|
X, y = generate_array()
|
||||||
|
|
||||||
|
X = dd.from_dask_array(X)
|
||||||
|
y = dd.from_dask_array(y)
|
||||||
|
|
||||||
|
X = X.map_partitions(cudf.from_pandas)
|
||||||
|
y = y.map_partitions(cudf.from_pandas)
|
||||||
|
|
||||||
|
dtrain = dxgb.DaskDMatrix(client, X, y)
|
||||||
|
out = dxgb.train(client, {'tree_method': 'gpu_hist'},
|
||||||
|
dtrain=dtrain,
|
||||||
|
evals=[(dtrain, 'X')],
|
||||||
|
num_boost_round=2)
|
||||||
|
|
||||||
|
assert isinstance(out['booster'], dxgb.Booster)
|
||||||
|
assert len(out['history']['X']['rmse']) == 2
|
||||||
@ -7,16 +7,19 @@ import numpy as np
|
|||||||
if sys.platform.startswith("win"):
|
if sys.platform.startswith("win"):
|
||||||
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.skipif(**tm.no_dask())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from distributed.utils_test import client, loop, cluster_fixture
|
from distributed.utils_test import client, loop, cluster_fixture
|
||||||
import dask.dataframe as dd
|
import dask.dataframe as dd
|
||||||
import dask.array as da
|
import dask.array as da
|
||||||
from xgboost.dask import DaskDMatrix
|
from xgboost.dask import DaskDMatrix
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
client = None
|
||||||
|
loop = None
|
||||||
|
cluster_fixture = None
|
||||||
pass
|
pass
|
||||||
|
|
||||||
pytestmark = pytest.mark.skipif(**tm.no_dask())
|
|
||||||
|
|
||||||
kRows = 1000
|
kRows = 1000
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user