[SYCL] Add dask support for distributed (#10812)
This commit is contained in:
committed by
GitHub
parent
2a37a8880c
commit
d7599e095b
@@ -17,5 +17,6 @@ dependencies:
|
||||
- pytest
|
||||
- pytest-timeout
|
||||
- pytest-cov
|
||||
- dask
|
||||
- dpcpp_linux-64
|
||||
- onedpl-devel
|
||||
|
||||
42
tests/python-sycl/test_sycl_simple_dask.py
Normal file
42
tests/python-sycl/test_sycl_simple_dask.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from xgboost import dask as dxgb
|
||||
from xgboost import testing as tm
|
||||
|
||||
from hypothesis import given, strategies, assume, settings, note
|
||||
|
||||
import dask.array as da
|
||||
import dask.distributed
|
||||
|
||||
|
||||
def train_result(client, param, dtrain, num_rounds):
|
||||
result = dxgb.train(
|
||||
client,
|
||||
param,
|
||||
dtrain,
|
||||
num_rounds,
|
||||
verbose_eval=False,
|
||||
evals=[(dtrain, "train")],
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class TestSYCLDask:
|
||||
# The simplest test verify only one node training.
|
||||
def test_simple(self):
|
||||
cluster = dask.distributed.LocalCluster(n_workers=1)
|
||||
client = dask.distributed.Client(cluster)
|
||||
|
||||
param = {}
|
||||
param["tree_method"] = "hist"
|
||||
param["device"] = "sycl"
|
||||
param["verbosity"] = 0
|
||||
param["objective"] = "reg:squarederror"
|
||||
|
||||
# X and y must be Dask dataframes or arrays
|
||||
num_obs = 1e4
|
||||
num_features = 20
|
||||
X = da.random.random(size=(num_obs, num_features), chunks=(1000, num_features))
|
||||
y = da.random.random(size=(num_obs, 1), chunks=(1000, 1))
|
||||
dtrain = dxgb.DaskDMatrix(client, X, y)
|
||||
|
||||
result = train_result(client, param, dtrain, 10)
|
||||
assert tm.non_increasing(result["history"]["train"]["rmse"])
|
||||
Reference in New Issue
Block a user