[dask] Honor nthreads from dask worker. (#5414)

This commit is contained in:
Jiaming Yuan
2020-03-16 04:51:24 +08:00
committed by GitHub
parent 21b671aa06
commit 761a5dbdfc
5 changed files with 59 additions and 16 deletions

View File

@@ -3,6 +3,7 @@ import pytest
import xgboost as xgb
import sys
import numpy as np
import json
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@@ -60,7 +61,7 @@ def test_from_dask_dataframe():
def test_from_dask_array():
with LocalCluster(n_workers=5) as cluster:
with LocalCluster(n_workers=5, threads_per_worker=5) as cluster:
with Client(cluster) as client:
X, y = generate_array()
dtrain = DaskDMatrix(client, X, y)
@@ -74,11 +75,15 @@ def test_from_dask_array():
# force prediction to be computed
prediction = prediction.compute()
single_node_predt = result['booster'].predict(
booster = result['booster']
single_node_predt = booster.predict(
xgb.DMatrix(X.compute())
)
np.testing.assert_allclose(prediction, single_node_predt)
config = json.loads(booster.save_config())
assert int(config['learner']['generic_param']['nthread']) == 5
def test_dask_regressor():
with LocalCluster(n_workers=5) as cluster: