[dask] Honor nthreads from dask worker. (#5414)
This commit is contained in:
@@ -3,6 +3,7 @@ import pytest
|
||||
import xgboost as xgb
|
||||
import sys
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||
@@ -60,7 +61,7 @@ def test_from_dask_dataframe():
|
||||
|
||||
|
||||
def test_from_dask_array():
|
||||
with LocalCluster(n_workers=5) as cluster:
|
||||
with LocalCluster(n_workers=5, threads_per_worker=5) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y = generate_array()
|
||||
dtrain = DaskDMatrix(client, X, y)
|
||||
@@ -74,11 +75,15 @@ def test_from_dask_array():
|
||||
# force prediction to be computed
|
||||
prediction = prediction.compute()
|
||||
|
||||
single_node_predt = result['booster'].predict(
|
||||
booster = result['booster']
|
||||
single_node_predt = booster.predict(
|
||||
xgb.DMatrix(X.compute())
|
||||
)
|
||||
np.testing.assert_allclose(prediction, single_node_predt)
|
||||
|
||||
config = json.loads(booster.save_config())
|
||||
assert int(config['learner']['generic_param']['nthread']) == 5
|
||||
|
||||
|
||||
def test_dask_regressor():
|
||||
with LocalCluster(n_workers=5) as cluster:
|
||||
|
||||
Reference in New Issue
Block a user