Avoid dask test fixtures. (#5270)

* Fix Travis OSX timeout.

* Fix classifier.
This commit is contained in:
Jiaming Yuan 2020-02-03 12:39:20 +08:00 committed by GitHub
parent 856b81c727
commit ed0216642f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -10,18 +10,20 @@ if sys.platform.startswith("win"):
pytestmark = pytest.mark.skipif(**tm.no_dask()) pytestmark = pytest.mark.skipif(**tm.no_dask())
try: try:
from distributed.utils_test import client, loop, cluster_fixture from distributed import LocalCluster, Client
import dask.dataframe as dd import dask.dataframe as dd
import dask.array as da import dask.array as da
from xgboost.dask import DaskDMatrix from xgboost.dask import DaskDMatrix
except ImportError: except ImportError:
client = None LocalCluster = None
loop = None Client = None
cluster_fixture = None dd = None
pass da = None
DaskDMatrix = None
kRows = 1000 kRows = 1000
kCols = 10 kCols = 10
kWorkers = 5
def generate_array(): def generate_array():
@ -31,7 +33,9 @@ def generate_array():
return X, y return X, y
def test_from_dask_dataframe(client): def test_from_dask_dataframe():
with LocalCluster(n_workers=5) as cluster:
with Client(cluster) as client:
X, y = generate_array() X, y = generate_array()
X = dd.from_dask_array(X) X = dd.from_dask_array(X)
@ -51,11 +55,13 @@ def test_from_dask_dataframe(client):
# evals_result is not supported in dask interface. # evals_result is not supported in dask interface.
xgb.dask.train( xgb.dask.train(
client, {}, dtrain, num_boost_round=2, evals_result={}) client, {}, dtrain, num_boost_round=2, evals_result={})
# force prediction to be computed
prediction = prediction.compute() # force prediction to be computed prediction = prediction.compute()
def test_from_dask_array(client): def test_from_dask_array():
with LocalCluster(n_workers=5) as cluster:
with Client(cluster) as client:
X, y = generate_array() X, y = generate_array()
dtrain = DaskDMatrix(client, X, y) dtrain = DaskDMatrix(client, X, y)
# results is {'booster': Booster, 'history': {...}} # results is {'booster': Booster, 'history': {...}}
@ -65,11 +71,13 @@ def test_from_dask_array(client):
assert prediction.shape[0] == kRows assert prediction.shape[0] == kRows
assert isinstance(prediction, da.Array) assert isinstance(prediction, da.Array)
# force prediction to be computed
prediction = prediction.compute() # force prediction to be computed prediction = prediction.compute()
def test_regressor(client): def test_dask_regressor():
with LocalCluster(n_workers=5) as cluster:
with Client(cluster) as client:
X, y = generate_array() X, y = generate_array()
regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2) regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
regressor.set_params(tree_method='hist') regressor.set_params(tree_method='hist')
@ -89,10 +97,13 @@ def test_regressor(client):
assert len(history['validation_0']['rmse']) == 2 assert len(history['validation_0']['rmse']) == 2
def test_classifier(client): def test_dask_classifier():
with LocalCluster(n_workers=5) as cluster:
with Client(cluster) as client:
X, y = generate_array() X, y = generate_array()
y = (y * 10).astype(np.int32) y = (y * 10).astype(np.int32)
classifier = xgb.dask.DaskXGBClassifier(verbosity=1, n_estimators=2) classifier = xgb.dask.DaskXGBClassifier(
verbosity=1, n_estimators=2)
classifier.client = client classifier.client = client
classifier.fit(X, y, eval_set=[(X, y)]) classifier.fit(X, y, eval_set=[(X, y)])
prediction = classifier.predict(X) prediction = classifier.predict(X)
@ -164,11 +175,15 @@ def run_empty_dmatrix(client, parameters):
# No test for Exact, as empty DMatrix handling is mostly for distributed # No test for Exact, as empty DMatrix handling is mostly for distributed
# environment and Exact doesn't support it. # environment and Exact doesn't support it.
def test_empty_dmatrix_hist(client): def test_empty_dmatrix_hist():
with LocalCluster(n_workers=5) as cluster:
with Client(cluster) as client:
parameters = {'tree_method': 'hist'} parameters = {'tree_method': 'hist'}
run_empty_dmatrix(client, parameters) run_empty_dmatrix(client, parameters)
def test_empty_dmatrix_approx(client): def test_empty_dmatrix_approx():
with LocalCluster(n_workers=5) as cluster:
with Client(cluster) as client:
parameters = {'tree_method': 'approx'} parameters = {'tree_method': 'approx'}
run_empty_dmatrix(client, parameters) run_empty_dmatrix(client, parameters)