Rewrite Dask interface. (#4819)
This commit is contained in:
@@ -1,93 +1,96 @@
|
||||
import testing as tm
|
||||
import pytest
|
||||
import xgboost as xgb
|
||||
import numpy as np
|
||||
import sys
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||
|
||||
try:
|
||||
from distributed.utils_test import client, loop, cluster_fixture
|
||||
import dask.dataframe as dd
|
||||
import dask.array as da
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
pytestmark = pytest.mark.skipif(**tm.no_dask())
|
||||
|
||||
|
||||
def run_train():
|
||||
# Contains one label equal to rank
|
||||
dmat = xgb.DMatrix(np.array([[0]]), label=[xgb.rabit.get_rank()])
|
||||
bst = xgb.train({"eta": 1.0, "lambda": 0.0}, dmat, 1)
|
||||
pred = bst.predict(dmat)
|
||||
expected_result = np.average(range(xgb.rabit.get_world_size()))
|
||||
assert all(p == expected_result for p in pred)
|
||||
|
||||
|
||||
def test_train(client):
|
||||
# Train two workers, the first has label 0, the second has label 1
|
||||
# If they build the model together the output should be 0.5
|
||||
xgb.dask.run(client, run_train)
|
||||
# Run again to check we can have multiple sessions
|
||||
xgb.dask.run(client, run_train)
|
||||
|
||||
|
||||
def run_create_dmatrix(X, y, weights):
|
||||
dmat = xgb.dask.create_worker_dmatrix(X, y, weight=weights)
|
||||
# Expect this worker to get two partitions and concatenate them
|
||||
assert dmat.num_row() == 50
|
||||
|
||||
|
||||
def test_dask_dataframe(client):
|
||||
n = 10
|
||||
m = 100
|
||||
partition_size = 25
|
||||
X = dd.from_array(np.random.random((m, n)), partition_size)
|
||||
y = dd.from_array(np.random.random(m), partition_size)
|
||||
weights = dd.from_array(np.random.random(m), partition_size)
|
||||
xgb.dask.run(client, run_create_dmatrix, X, y, weights)
|
||||
|
||||
|
||||
def test_dask_array(client):
|
||||
n = 10
|
||||
m = 100
|
||||
partition_size = 25
|
||||
X = da.random.random((m, n), partition_size)
|
||||
y = da.random.random(m, partition_size)
|
||||
weights = da.random.random(m, partition_size)
|
||||
xgb.dask.run(client, run_create_dmatrix, X, y, weights)
|
||||
|
||||
|
||||
def run_get_local_data(X, y):
|
||||
X_local = xgb.dask.get_local_data(X)
|
||||
y_local = xgb.dask.get_local_data(y)
|
||||
assert (X_local.shape == (50, 10))
|
||||
assert (y_local.shape == (50,))
|
||||
|
||||
|
||||
def test_get_local_data(client):
|
||||
n = 10
|
||||
m = 100
|
||||
partition_size = 25
|
||||
X = da.random.random((m, n), partition_size)
|
||||
y = da.random.random(m, partition_size)
|
||||
xgb.dask.run(client, run_get_local_data, X, y)
|
||||
|
||||
|
||||
def run_sklearn():
|
||||
# Contains one label equal to rank
|
||||
X = np.array([[0]])
|
||||
y = [xgb.rabit.get_rank()]
|
||||
model = xgb.XGBRegressor(learning_rate=1.0)
|
||||
model.fit(X, y)
|
||||
pred = model.predict(X)
|
||||
expected_result = np.average(range(xgb.rabit.get_world_size()))
|
||||
assert all(p == expected_result for p in pred)
|
||||
return pred
|
||||
|
||||
|
||||
def test_sklearn(client):
|
||||
result = xgb.dask.run(client, run_sklearn)
|
||||
print(result)
|
||||
import testing as tm
|
||||
import pytest
|
||||
import xgboost as xgb
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||
|
||||
try:
|
||||
from distributed.utils_test import client, loop, cluster_fixture
|
||||
import dask.dataframe as dd
|
||||
import dask.array as da
|
||||
from xgboost.dask import DaskDMatrix
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
pytestmark = pytest.mark.skipif(**tm.no_dask())
|
||||
|
||||
kRows = 1000
|
||||
|
||||
|
||||
def generate_array():
|
||||
n = 10
|
||||
partition_size = 20
|
||||
X = da.random.random((kRows, n), partition_size)
|
||||
y = da.random.random(kRows, partition_size)
|
||||
return X, y
|
||||
|
||||
|
||||
def test_from_dask_dataframe(client):
|
||||
X, y = generate_array()
|
||||
|
||||
X = dd.from_dask_array(X)
|
||||
y = dd.from_dask_array(y)
|
||||
|
||||
dtrain = DaskDMatrix(client, X, y)
|
||||
booster = xgb.dask.train(
|
||||
client, {}, dtrain, num_boost_round=2)['booster']
|
||||
|
||||
prediction = xgb.dask.predict(client, model=booster, data=dtrain)
|
||||
|
||||
assert isinstance(prediction, da.Array)
|
||||
assert prediction.shape[0] == kRows, prediction
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# evals_result is not supported in dask interface.
|
||||
xgb.dask.train(
|
||||
client, {}, dtrain, num_boost_round=2, evals_result={})
|
||||
|
||||
|
||||
def test_from_dask_array(client):
|
||||
X, y = generate_array()
|
||||
dtrain = DaskDMatrix(client, X, y)
|
||||
# results is {'booster': Booster, 'history': {...}}
|
||||
result = xgb.dask.train(client, {}, dtrain)
|
||||
|
||||
prediction = xgb.dask.predict(client, result, dtrain)
|
||||
|
||||
assert isinstance(prediction, da.Array)
|
||||
|
||||
|
||||
def test_regressor(client):
|
||||
X, y = generate_array()
|
||||
regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
|
||||
regressor.set_params(tree_method='hist')
|
||||
regressor.client = client
|
||||
regressor.fit(X, y, eval_set=[(X, y)])
|
||||
prediction = regressor.predict(X)
|
||||
|
||||
history = regressor.evals_result()
|
||||
|
||||
assert isinstance(prediction, da.Array)
|
||||
assert isinstance(history, dict)
|
||||
|
||||
assert list(history['validation_0'].keys())[0] == 'rmse'
|
||||
assert len(history['validation_0']['rmse']) == 2
|
||||
|
||||
|
||||
def test_classifier(client):
|
||||
X, y = generate_array()
|
||||
y = (y * 10).astype(np.int32)
|
||||
classifier = xgb.dask.DaskXGBClassifier(verbosity=1, n_estimators=2)
|
||||
classifier.client = client
|
||||
classifier.fit(X, y, eval_set=[(X, y)])
|
||||
prediction = classifier.predict(X)
|
||||
|
||||
history = classifier.evals_result()
|
||||
|
||||
assert isinstance(prediction, da.Array)
|
||||
assert isinstance(history, dict)
|
||||
|
||||
assert list(history.keys())[0] == 'validation_0'
|
||||
assert list(history['validation_0'].keys())[0] == 'merror'
|
||||
assert len(list(history['validation_0'])) == 1
|
||||
assert len(history['validation_0']['merror']) == 2
|
||||
|
||||
Reference in New Issue
Block a user