import testing as tm import pytest import xgboost as xgb import sys import numpy as np import json import asyncio from sklearn.datasets import make_classification import os import subprocess from hypothesis import given, strategies, settings, note from test_updaters import hist_parameter_strategy, exact_parameter_strategy if sys.platform.startswith("win"): pytest.skip("Skipping dask tests on Windows", allow_module_level=True) pytestmark = pytest.mark.skipif(**tm.no_dask()) try: from distributed import LocalCluster, Client from distributed.utils_test import client, loop, cluster_fixture import dask.dataframe as dd import dask.array as da from xgboost.dask import DaskDMatrix except ImportError: LocalCluster = None Client = None client = None loop = None cluster_fixture = None dd = None da = None DaskDMatrix = None kRows = 1000 kCols = 10 kWorkers = 5 def generate_array(): partition_size = 20 X = da.random.random((kRows, kCols), partition_size) y = da.random.random(kRows, partition_size) return X, y def test_from_dask_dataframe(): with LocalCluster(n_workers=kWorkers) as cluster: with Client(cluster) as client: X, y = generate_array() X = dd.from_dask_array(X) y = dd.from_dask_array(y) dtrain = DaskDMatrix(client, X, y) booster = xgb.dask.train( client, {}, dtrain, num_boost_round=2)['booster'] prediction = xgb.dask.predict(client, model=booster, data=dtrain) assert prediction.ndim == 1 assert isinstance(prediction, da.Array) assert prediction.shape[0] == kRows with pytest.raises(ValueError): # evals_result is not supported in dask interface. xgb.dask.train( client, {}, dtrain, num_boost_round=2, evals_result={}) # force prediction to be computed from_dmatrix = prediction.compute() prediction = xgb.dask.predict(client, model=booster, data=X) from_df = prediction.compute() assert isinstance(prediction, dd.Series) assert np.all(prediction.compute().values == from_dmatrix) assert np.all(from_dmatrix == from_df.to_numpy()) series_predictions = xgb.dask.inplace_predict(client, booster, X) assert isinstance(series_predictions, dd.Series) np.testing.assert_allclose(series_predictions.compute().values, from_dmatrix) def test_from_dask_array(): with LocalCluster(n_workers=kWorkers, threads_per_worker=5) as cluster: with Client(cluster) as client: X, y = generate_array() dtrain = DaskDMatrix(client, X, y) # results is {'booster': Booster, 'history': {...}} result = xgb.dask.train(client, {}, dtrain) prediction = xgb.dask.predict(client, result, dtrain) assert prediction.shape[0] == kRows assert isinstance(prediction, da.Array) # force prediction to be computed prediction = prediction.compute() booster = result['booster'] single_node_predt = booster.predict( xgb.DMatrix(X.compute()) ) np.testing.assert_allclose(prediction, single_node_predt) config = json.loads(booster.save_config()) assert int(config['learner']['generic_param']['nthread']) == 5 from_arr = xgb.dask.predict( client, model=booster, data=X) assert isinstance(from_arr, da.Array) assert np.all(single_node_predt == from_arr.compute()) def test_dask_predict_shape_infer(): with LocalCluster(n_workers=kWorkers) as cluster: with Client(cluster) as client: X, y = make_classification(n_samples=1000, n_informative=5, n_classes=3) X_ = dd.from_array(X, chunksize=100) y_ = dd.from_array(y, chunksize=100) dtrain = xgb.dask.DaskDMatrix(client, data=X_, label=y_) model = xgb.dask.train( client, {"objective": "multi:softprob", "num_class": 3}, dtrain=dtrain ) preds = xgb.dask.predict(client, model, dtrain) assert preds.shape[0] == preds.compute().shape[0] assert preds.shape[1] == preds.compute().shape[1] def test_dask_missing_value_reg(): with LocalCluster(n_workers=kWorkers) as cluster: with Client(cluster) as client: X_0 = np.ones((20 // 2, kCols)) X_1 = np.zeros((20 // 2, kCols)) X = np.concatenate([X_0, X_1], axis=0) np.random.shuffle(X) X = da.from_array(X) X = X.rechunk(20, 1) y = da.random.randint(0, 3, size=20) y.rechunk(20) regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2, missing=0.0) regressor.client = client regressor.set_params(tree_method='hist') regressor.fit(X, y, eval_set=[(X, y)]) dd_predt = regressor.predict(X).compute() np_X = X.compute() np_predt = regressor.get_booster().predict( xgb.DMatrix(np_X, missing=0.0)) np.testing.assert_allclose(np_predt, dd_predt) def test_dask_missing_value_cls(): with LocalCluster() as cluster: with Client(cluster) as client: X_0 = np.ones((kRows // 2, kCols)) X_1 = np.zeros((kRows // 2, kCols)) X = np.concatenate([X_0, X_1], axis=0) np.random.shuffle(X) X = da.from_array(X) X = X.rechunk(20, None) y = da.random.randint(0, 3, size=kRows) y = y.rechunk(20, 1) cls = xgb.dask.DaskXGBClassifier(verbosity=1, n_estimators=2, tree_method='hist', missing=0.0) cls.client = client cls.fit(X, y, eval_set=[(X, y)]) dd_pred_proba = cls.predict_proba(X).compute() np_X = X.compute() np_pred_proba = cls.get_booster().predict( xgb.DMatrix(np_X, missing=0.0)) np.testing.assert_allclose(np_pred_proba, dd_pred_proba) cls = xgb.dask.DaskXGBClassifier() assert hasattr(cls, 'missing') def test_dask_regressor(): with LocalCluster(n_workers=kWorkers) as cluster: with Client(cluster) as client: X, y = generate_array() regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2) regressor.set_params(tree_method='hist') regressor.client = client regressor.fit(X, y, eval_set=[(X, y)]) prediction = regressor.predict(X) assert prediction.ndim == 1 assert prediction.shape[0] == kRows history = regressor.evals_result() assert isinstance(prediction, da.Array) assert isinstance(history, dict) assert list(history['validation_0'].keys())[0] == 'rmse' assert len(history['validation_0']['rmse']) == 2 def test_dask_classifier(): with LocalCluster(n_workers=kWorkers) as cluster: with Client(cluster) as client: X, y = generate_array() y = (y * 10).astype(np.int32) classifier = xgb.dask.DaskXGBClassifier( verbosity=1, n_estimators=2) classifier.client = client classifier.fit(X, y, eval_set=[(X, y)]) prediction = classifier.predict(X) assert prediction.ndim == 1 assert prediction.shape[0] == kRows history = classifier.evals_result() assert isinstance(prediction, da.Array) assert isinstance(history, dict) assert list(history.keys())[0] == 'validation_0' assert list(history['validation_0'].keys())[0] == 'merror' assert len(list(history['validation_0'])) == 1 assert len(history['validation_0']['merror']) == 2 # Test .predict_proba() probas = classifier.predict_proba(X) assert classifier.n_classes_ == 10 assert probas.ndim == 2 assert probas.shape[0] == kRows assert probas.shape[1] == 10 cls_booster = classifier.get_booster() single_node_proba = cls_booster.inplace_predict(X.compute()) np.testing.assert_allclose(single_node_proba, probas.compute()) # Test with dataframe. X_d = dd.from_dask_array(X) y_d = dd.from_dask_array(y) classifier.fit(X_d, y_d) assert classifier.n_classes_ == 10 prediction = classifier.predict(X_d) assert prediction.ndim == 1 assert prediction.shape[0] == kRows @pytest.mark.skipif(**tm.no_sklearn()) def test_sklearn_grid_search(): from sklearn.model_selection import GridSearchCV with LocalCluster(n_workers=kWorkers) as cluster: with Client(cluster) as client: X, y = generate_array() reg = xgb.dask.DaskXGBRegressor(learning_rate=0.1, tree_method='hist') reg.client = client model = GridSearchCV(reg, {'max_depth': [2, 4], 'n_estimators': [5, 10]}, cv=2, verbose=1, iid=True) model.fit(X, y) # Expect unique results for each parameter value This confirms # sklearn is able to successfully update the parameter means = model.cv_results_['mean_test_score'] assert len(means) == len(set(means)) def run_empty_dmatrix_reg(client, parameters): def _check_outputs(out, predictions): assert isinstance(out['booster'], xgb.dask.Booster) assert len(out['history']['validation']['rmse']) == 2 assert isinstance(predictions, np.ndarray) assert predictions.shape[0] == 1 kRows, kCols = 1, 97 X = dd.from_array(np.random.randn(kRows, kCols)) y = dd.from_array(np.random.rand(kRows)) dtrain = xgb.dask.DaskDMatrix(client, X, y) out = xgb.dask.train(client, parameters, dtrain=dtrain, evals=[(dtrain, 'validation')], num_boost_round=2) predictions = xgb.dask.predict(client=client, model=out, data=dtrain).compute() _check_outputs(out, predictions) # train has more rows than evals valid = dtrain kRows += 1 X = dd.from_array(np.random.randn(kRows, kCols)) y = dd.from_array(np.random.rand(kRows)) dtrain = xgb.dask.DaskDMatrix(client, X, y) out = xgb.dask.train(client, parameters, dtrain=dtrain, evals=[(valid, 'validation')], num_boost_round=2) predictions = xgb.dask.predict(client=client, model=out, data=valid).compute() _check_outputs(out, predictions) def run_empty_dmatrix_cls(client, parameters): n_classes = 4 def _check_outputs(out, predictions): assert isinstance(out['booster'], xgb.dask.Booster) assert len(out['history']['validation']['merror']) == 2 assert isinstance(predictions, np.ndarray) assert predictions.shape[1] == n_classes, predictions.shape kRows, kCols = 1, 97 X = dd.from_array(np.random.randn(kRows, kCols)) y = dd.from_array(np.random.randint(low=0, high=n_classes, size=kRows)) dtrain = xgb.dask.DaskDMatrix(client, X, y) parameters['objective'] = 'multi:softprob' parameters['num_class'] = n_classes out = xgb.dask.train(client, parameters, dtrain=dtrain, evals=[(dtrain, 'validation')], num_boost_round=2) predictions = xgb.dask.predict(client=client, model=out, data=dtrain) assert predictions.shape[1] == n_classes predictions = predictions.compute() _check_outputs(out, predictions) # train has more rows than evals valid = dtrain kRows += 1 X = dd.from_array(np.random.randn(kRows, kCols)) y = dd.from_array(np.random.randint(low=0, high=n_classes, size=kRows)) dtrain = xgb.dask.DaskDMatrix(client, X, y) out = xgb.dask.train(client, parameters, dtrain=dtrain, evals=[(valid, 'validation')], num_boost_round=2) predictions = xgb.dask.predict(client=client, model=out, data=valid).compute() _check_outputs(out, predictions) # No test for Exact, as empty DMatrix handling are mostly for distributed # environment and Exact doesn't support it. def test_empty_dmatrix_hist(): with LocalCluster(n_workers=kWorkers) as cluster: with Client(cluster) as client: parameters = {'tree_method': 'hist'} run_empty_dmatrix_reg(client, parameters) run_empty_dmatrix_cls(client, parameters) def test_empty_dmatrix_approx(): with LocalCluster(n_workers=kWorkers) as cluster: with Client(cluster) as client: parameters = {'tree_method': 'approx'} run_empty_dmatrix_reg(client, parameters) run_empty_dmatrix_cls(client, parameters) async def run_from_dask_array_asyncio(scheduler_address): async with Client(scheduler_address, asynchronous=True) as client: X, y = generate_array() m = await DaskDMatrix(client, X, y) output = await xgb.dask.train(client, {}, dtrain=m) with_m = await xgb.dask.predict(client, output, m) with_X = await xgb.dask.predict(client, output, X) inplace = await xgb.dask.inplace_predict(client, output, X) assert isinstance(with_m, da.Array) assert isinstance(with_X, da.Array) assert isinstance(inplace, da.Array) np.testing.assert_allclose(await client.compute(with_m), await client.compute(with_X)) np.testing.assert_allclose(await client.compute(with_m), await client.compute(inplace)) client.shutdown() return output async def run_dask_regressor_asyncio(scheduler_address): async with Client(scheduler_address, asynchronous=True) as client: X, y = generate_array() regressor = await xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2) regressor.set_params(tree_method='hist') regressor.client = client await regressor.fit(X, y, eval_set=[(X, y)]) prediction = await regressor.predict(X) assert prediction.ndim == 1 assert prediction.shape[0] == kRows history = regressor.evals_result() assert isinstance(prediction, da.Array) assert isinstance(history, dict) assert list(history['validation_0'].keys())[0] == 'rmse' assert len(history['validation_0']['rmse']) == 2 async def run_dask_classifier_asyncio(scheduler_address): async with Client(scheduler_address, asynchronous=True) as client: X, y = generate_array() y = (y * 10).astype(np.int32) classifier = await xgb.dask.DaskXGBClassifier( verbosity=1, n_estimators=2) classifier.client = client await classifier.fit(X, y, eval_set=[(X, y)]) prediction = await classifier.predict(X) assert prediction.ndim == 1 assert prediction.shape[0] == kRows history = classifier.evals_result() assert isinstance(prediction, da.Array) assert isinstance(history, dict) assert list(history.keys())[0] == 'validation_0' assert list(history['validation_0'].keys())[0] == 'merror' assert len(list(history['validation_0'])) == 1 assert len(history['validation_0']['merror']) == 2 # Test .predict_proba() probas = await classifier.predict_proba(X) assert classifier.n_classes_ == 10 assert probas.ndim == 2 assert probas.shape[0] == kRows assert probas.shape[1] == 10 # Test with dataframe. X_d = dd.from_dask_array(X) y_d = dd.from_dask_array(y) await classifier.fit(X_d, y_d) assert classifier.n_classes_ == 10 prediction = await classifier.predict(X_d) assert prediction.ndim == 1 assert prediction.shape[0] == kRows def test_with_asyncio(): with LocalCluster() as cluster: with Client(cluster) as client: address = client.scheduler.address output = asyncio.run(run_from_dask_array_asyncio(address)) assert isinstance(output['booster'], xgb.Booster) assert isinstance(output['history'], dict) asyncio.run(run_dask_regressor_asyncio(address)) asyncio.run(run_dask_classifier_asyncio(address)) def test_predict(): with LocalCluster(n_workers=kWorkers) as cluster: with Client(cluster) as client: X, y = generate_array() dtrain = DaskDMatrix(client, X, y) booster = xgb.dask.train( client, {}, dtrain, num_boost_round=2)['booster'] pred = xgb.dask.predict(client, model=booster, data=dtrain) assert pred.ndim == 1 assert pred.shape[0] == kRows margin = xgb.dask.predict(client, model=booster, data=dtrain, output_margin=True) assert margin.ndim == 1 assert margin.shape[0] == kRows shap = xgb.dask.predict(client, model=booster, data=dtrain, pred_contribs=True) assert shap.ndim == 2 assert shap.shape[0] == kRows assert shap.shape[1] == kCols + 1 def run_aft_survival(client, dmatrix_t): # survival doesn't handle empty dataset well. df = dd.read_csv(os.path.join(tm.PROJECT_ROOT, 'demo', 'data', 'veterans_lung_cancer.csv')) y_lower_bound = df['Survival_label_lower_bound'] y_upper_bound = df['Survival_label_upper_bound'] X = df.drop(['Survival_label_lower_bound', 'Survival_label_upper_bound'], axis=1) m = dmatrix_t(client, X, label_lower_bound=y_lower_bound, label_upper_bound=y_upper_bound) base_params = {'verbosity': 0, 'objective': 'survival:aft', 'eval_metric': 'aft-nloglik', 'learning_rate': 0.05, 'aft_loss_distribution_scale': 1.20, 'max_depth': 6, 'lambda': 0.01, 'alpha': 0.02} nloglik_rec = {} dists = ['normal', 'logistic', 'extreme'] for dist in dists: params = base_params params.update({'aft_loss_distribution': dist}) evals_result = {} out = xgb.dask.train(client, params, m, num_boost_round=100, evals=[(m, 'train')]) evals_result = out['history'] nloglik_rec[dist] = evals_result['train']['aft-nloglik'] # AFT metric (negative log likelihood) improve monotonically assert all(p >= q for p, q in zip(nloglik_rec[dist], nloglik_rec[dist][:1])) # For this data, normal distribution works the best assert nloglik_rec['normal'][-1] < 4.9 assert nloglik_rec['logistic'][-1] > 4.9 assert nloglik_rec['extreme'][-1] > 4.9 def test_aft_survival(): with LocalCluster(n_workers=1) as cluster: with Client(cluster) as client: run_aft_survival(client, DaskDMatrix) class TestWithDask: def run_updater_test(self, client, params, num_rounds, dataset, tree_method): params['tree_method'] = tree_method params = dataset.set_params(params) # multi class doesn't handle empty dataset well (empty # means at least 1 worker has data). if params['objective'] == "multi:softmax": return # It doesn't make sense to distribute a completely # empty dataset. if dataset.X.shape[0] == 0: return chunk = 128 X = da.from_array(dataset.X, chunks=(chunk, dataset.X.shape[1])) y = da.from_array(dataset.y, chunks=(chunk,)) if dataset.w is not None: w = da.from_array(dataset.w, chunks=(chunk,)) else: w = None m = xgb.dask.DaskDMatrix( client, data=X, label=y, weight=w) history = xgb.dask.train(client, params=params, dtrain=m, num_boost_round=num_rounds, evals=[(m, 'train')])['history'] note(history) history = history['train'][dataset.metric] assert tm.non_increasing(history) # Make sure that it's decreasing assert history[-1] < history[0] @given(params=hist_parameter_strategy, num_rounds=strategies.integers(20, 30), dataset=tm.dataset_strategy) @settings(deadline=None) def test_hist(self, params, num_rounds, dataset, client): self.run_updater_test(client, params, num_rounds, dataset, 'hist') @given(params=exact_parameter_strategy, num_rounds=strategies.integers(20, 30), dataset=tm.dataset_strategy) @settings(deadline=None) def test_approx(self, client, params, num_rounds, dataset): self.run_updater_test(client, params, num_rounds, dataset, 'approx') def run_quantile(self, name): if sys.platform.startswith("win"): pytest.skip("Skipping dask tests on Windows") exe = None for possible_path in {'./testxgboost', './build/testxgboost', '../build/testxgboost', '../cpu-build/testxgboost'}: if os.path.exists(possible_path): exe = possible_path if exe is None: return test = "--gtest_filter=Quantile." + name def runit(worker_addr, rabit_args): port = None # setup environment for running the c++ part. for arg in rabit_args: if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'): port = arg.decode('utf-8') port = port.split('=') env = os.environ.copy() env[port[0]] = port[1] return subprocess.run([exe, test], env=env, capture_output=True) with LocalCluster(n_workers=4) as cluster: with Client(cluster) as client: workers = list(xgb.dask._get_client_workers(client).keys()) rabit_args = client.sync( xgb.dask._get_rabit_args, workers, client) futures = client.map(runit, workers, pure=False, workers=workers, rabit_args=rabit_args) results = client.gather(futures) for ret in results: msg = ret.stdout.decode('utf-8') assert msg.find('1 test from Quantile') != -1, msg assert ret.returncode == 0, msg @pytest.mark.skipif(**tm.no_dask()) @pytest.mark.gtest def test_quantile_basic(self): self.run_quantile('DistributedBasic') @pytest.mark.skipif(**tm.no_dask()) @pytest.mark.gtest def test_quantile(self): self.run_quantile('Distributed') @pytest.mark.skipif(**tm.no_dask()) @pytest.mark.gtest def test_quantile_same_on_all_workers(self): self.run_quantile('SameOnAllWorkers')