Extract dask and spark test into distributed test. (#8395)

- Move test files.
- Run spark and dask separately to prevent conflicts.
- Gather common code into the testing module.
This commit is contained in:
Jiaming Yuan 2022-10-28 16:24:32 +08:00 committed by GitHub
parent f73520bfff
commit cfd2a9f872
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 405 additions and 337 deletions

View File

@ -125,7 +125,7 @@ jobs:
- name: Test Python package - name: Test Python package
shell: bash -l {0} shell: bash -l {0}
run: | run: |
pytest -s -v ./tests/python pytest -s -v -rxXs --durations=0 ./tests/python
python-tests-on-macos: python-tests-on-macos:
name: Test XGBoost Python package on ${{ matrix.config.os }} name: Test XGBoost Python package on ${{ matrix.config.os }}
@ -177,4 +177,9 @@ jobs:
- name: Test Python package - name: Test Python package
shell: bash -l {0} shell: bash -l {0}
run: | run: |
pytest -s -v ./tests/python pytest -s -v -rxXs --durations=0 ./tests/python
- name: Test Dask Interface
shell: bash -l {0}
run: |
pytest -s -v -rxXs --durations=0 ./tests/test_distributed/test_with_dask

View File

@ -87,7 +87,7 @@ class PartIter(DataIter):
# We must set the device after import cudf, which will change the device id to 0 # We must set the device after import cudf, which will change the device id to 0
# See https://github.com/rapidsai/cudf/issues/11386 # See https://github.com/rapidsai/cudf/issues/11386
cp.cuda.runtime.setDevice(self._device_id) cp.cuda.runtime.setDevice(self._device_id) # pylint: disable=I1101
return cudf.DataFrame(data[self._iter]) return cudf.DataFrame(data[self._iter])
return data[self._iter] return data[self._iter]

View File

@ -102,10 +102,14 @@ def no_sklearn() -> PytestSkip:
def no_dask() -> PytestSkip: def no_dask() -> PytestSkip:
if sys.platform.startswith("win"):
return {"reason": "Unsupported platform.", "condition": True}
return no_mod("dask") return no_mod("dask")
def no_spark() -> PytestSkip: def no_spark() -> PytestSkip:
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
return {"reason": "Unsupported platform.", "condition": True}
return no_mod("pyspark") return no_mod("pyspark")
@ -159,6 +163,10 @@ def no_graphviz() -> PytestSkip:
return no_mod("graphviz") return no_mod("graphviz")
def no_rmm() -> PytestSkip:
return no_mod("rmm")
def no_multiple(*args: Any) -> PytestSkip: def no_multiple(*args: Any) -> PytestSkip:
condition = False condition = False
reason = "" reason = ""
@ -865,6 +873,30 @@ def timeout(sec: int, *args: Any, enable: bool = True, **kwargs: Any) -> Any:
return pytest.mark.timeout(None, *args, **kwargs) return pytest.mark.timeout(None, *args, **kwargs)
def setup_rmm_pool(_: Any, pytestconfig: pytest.Config) -> None:
    """Switch RMM to a pooled allocator when ``--use-rmm-pool`` is given.

    Does nothing unless the ``--use-rmm-pool`` option is set.  Otherwise both
    ``rmm`` and ``dask_cuda`` must be available; RMM is then re-initialised
    with a 1 GiB initial pool on every device index reported by
    ``dask_cuda.utils.get_n_gpus``.

    Parameters
    ----------
    _ :
        Unused; accepted so the helper can be called with a fixture's
        ``request`` argument.
    pytestconfig :
        The pytest configuration object used to look up the option.

    Raises
    ------
    ImportError
        If the option is set but ``rmm`` or ``dask_cuda`` is reported missing.
    """
    if not pytestconfig.getoption("--use-rmm-pool"):
        return

    # Check both optional dependencies before importing anything.
    rmm_missing = no_rmm()["condition"]
    if rmm_missing:
        raise ImportError("The --use-rmm-pool option requires the RMM package")
    dask_cuda_missing = no_dask_cuda()["condition"]
    if dask_cuda_missing:
        raise ImportError(
            "The --use-rmm-pool option requires the dask_cuda package"
        )

    import rmm
    from dask_cuda.utils import get_n_gpus

    one_gib = 1024 * 1024 * 1024
    device_ids = list(range(get_n_gpus()))
    rmm.reinitialize(
        pool_allocator=True,
        initial_pool_size=one_gib,
        devices=device_ids,
    )
def get_client_workers(client: Any) -> List[str]:
    """Get workers from a dask client.

    Returns the keys of the ``"workers"`` entry of ``client.scheduler_info()``
    as a list, in the order the mapping yields them.
    """
    scheduler_info = client.scheduler_info()
    return [*scheduler_info["workers"].keys()]
def demo_dir(path: str) -> str: def demo_dir(path: str) -> str:
"""Look for the demo directory based on the test file name.""" """Look for the demo directory based on the test file name."""
path = normpath(os.path.dirname(path)) path = normpath(os.path.dirname(path))

View File

@ -0,0 +1,49 @@
"""Strategies for updater tests."""
from typing import cast
import pytest
hypothesis = pytest.importorskip("hypothesis")
from hypothesis import strategies # pylint:disable=wrong-import-position
# Search space of booster parameters shared by the updater tests.  Every draw
# is a dict with exactly the keys listed below.
# NOTE(review): the name suggests this targets the `exact` tree method --
# confirm against the tests that consume it.
exact_parameter_strategy = strategies.fixed_dictionaries(
{
"nthread": strategies.integers(1, 4),
"max_depth": strategies.integers(1, 11),
"min_child_weight": strategies.floats(0.5, 2.0),
"alpha": strategies.floats(1e-5, 2.0),
"lambda": strategies.floats(1e-5, 2.0),
"eta": strategies.floats(0.01, 0.5),
"gamma": strategies.floats(1e-5, 2.0),
"seed": strategies.integers(0, 10),
# We cannot enable subsampling as the training loss can increase
# 'subsample': strategies.floats(0.5, 1.0),
"colsample_bytree": strategies.floats(0.5, 1.0),
"colsample_bylevel": strategies.floats(0.5, 1.0),
}
)
# Search space for histogram-style parameters (depth, leaves, bins, growth
# policy and column sampling).  The trailing filter rejects draws that have
# neither a positive `max_depth` nor a positive `max_leaves`, and draws with a
# non-positive `max_depth` unless the policy is "lossguide".
# NOTE(review): `max_depth` is drawn from integers(1, 11) here, so with these
# bounds the filter does not appear to reject anything; it only matters if
# the lower bound is relaxed to 0.
hist_parameter_strategy = strategies.fixed_dictionaries(
{
"max_depth": strategies.integers(1, 11),
"max_leaves": strategies.integers(0, 1024),
"max_bin": strategies.integers(2, 512),
"grow_policy": strategies.sampled_from(["lossguide", "depthwise"]),
"min_child_weight": strategies.floats(0.5, 2.0),
# We cannot enable subsampling as the training loss can increase
# 'subsample': strategies.floats(0.5, 1.0),
"colsample_bytree": strategies.floats(0.5, 1.0),
"colsample_bylevel": strategies.floats(0.5, 1.0),
}
).filter(
# `cast` only informs the type checker; it does not convert the values.
lambda x: (cast(int, x["max_depth"]) > 0 or cast(int, x["max_leaves"]) > 0)
and (cast(int, x["max_depth"]) > 0 or x["grow_policy"] == "lossguide")
)
# Search space for the two categorical-feature parameters; each is drawn
# independently from the integers 1 to 128.
cat_parameter_strategy = strategies.fixed_dictionaries(
{
"max_cat_to_onehot": strategies.integers(1, 128),
"max_cat_threshold": strategies.integers(1, 128),
}
)

View File

@ -0,0 +1,95 @@
"""Testing code shared by other tests."""
# pylint: disable=invalid-name
import collections
import importlib.util
import json
import os
import tempfile
from typing import Any, Callable, Dict, Type
import numpy as np
from xgboost._typing import ArrayLike
import xgboost as xgb
def validate_leaf_output(leaf: np.ndarray, num_parallel_tree: int) -> None:
    """Validate output for predict leaf tests.

    ``leaf`` is indexed as ``[sample, round, class, tree]``.  For every
    combination of the first three indices, the remaining slice must have
    ``num_parallel_tree`` entries and all of them must be equal.

    Raises
    ------
    AssertionError
        If a slice has the wrong length or contains differing leaf indices.
    """
    # Visit (sample, round, class) coordinates in row-major order.
    for coord in np.ndindex(*leaf.shape[:3]):
        forest = leaf[coord]
        assert forest.shape[0] == num_parallel_tree
        # No sampling, all trees within forest are the same
        assert np.all(forest == forest[0])
def validate_data_initialization(
    dmatrix: Type, model: Type[xgb.XGBModel], X: ArrayLike, y: ArrayLike
) -> None:
    """Assert that we don't create duplicated DMatrix.

    ``dmatrix.__init__`` is temporarily replaced by a wrapper that counts how
    often it is called while ``model(n_estimators=1).fit`` runs with an
    evaluation set.  The original ``__init__`` is always put back, even when
    one of the assertions fails, so a failure here cannot leak the patched
    constructor into other tests.

    Parameters
    ----------
    dmatrix :
        The DMatrix class whose constructions are counted.
    model :
        Estimator class; instantiated with ``n_estimators=1``.
    X, y :
        Training data.  ``y`` must provide a ``copy()`` method.

    Raises
    ------
    AssertionError
        If the number of constructed DMatrix objects is not the expected one.
    """
    old_init = dmatrix.__init__
    count = [0]

    def new_init(self: Any, **kwargs: Any) -> None:
        count[0] += 1
        old_init(self, **kwargs)

    dmatrix.__init__ = new_init
    try:
        # The evaluation set is the very same objects as the training data.
        model(n_estimators=1).fit(X, y, eval_set=[(X, y)])
        assert count[0] == 1  # only 1 DMatrix is created.

        count[0] = 0
        y_copy = y.copy()
        model(n_estimators=1).fit(X, y, eval_set=[(X, y_copy)])
        assert count[0] == 2  # a different Python object is considered different
    finally:
        # Undo the monkey patch regardless of the outcome.
        dmatrix.__init__ = old_init
# pylint: disable=too-many-arguments,too-many-locals
def get_feature_weights(
    X: ArrayLike,
    y: ArrayLike,
    fw: np.ndarray,
    parser_path: str,
    tree_method: str,
    model: Type[xgb.XGBModel] = xgb.XGBRegressor,
) -> np.ndarray:
    """Get feature weights using the demo parser.

    An estimator is fitted with ``feature_weights=fw`` and
    ``colsample_bynode=0.5``, saved as JSON and re-read through the parser
    module found at ``parser_path``.  The number of split nodes per split
    index is counted over all trees, and a degree-1 polynomial is fitted to
    ``(split index, count)``.

    Parameters
    ----------
    X, y :
        Training data passed to ``fit``.
    fw :
        Feature weights passed to ``fit``.
    parser_path :
        Path of a Python file exposing a ``Model`` class that accepts the
        loaded JSON model.
    tree_method :
        Forwarded to the estimator constructor.
    model :
        Estimator class to instantiate.

    Returns
    -------
    The coefficients returned by ``np.polyfit(..., deg=1)``.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        colsample_bynode = 0.5
        reg = model(tree_method=tree_method, colsample_bynode=colsample_bynode)
        reg.fit(X, y, feature_weights=fw)
        model_path = os.path.join(tmpdir, "model.json")
        reg.save_model(model_path)
        # Load everything into memory before the temporary directory goes away.
        with open(model_path, "r", encoding="utf-8") as fd:
            raw_model = json.load(fd)

    # Import the parser straight from its file; it is not an installed module.
    spec = importlib.util.spec_from_file_location("JsonParser", parser_path)
    assert spec is not None
    jsonm = importlib.util.module_from_spec(spec)
    assert spec.loader is not None
    spec.loader.exec_module(jsonm)
    parsed_model = jsonm.Model(raw_model)

    # How many non-leaf nodes use each split index, over all trees.
    splits: Dict[int, int] = collections.Counter(
        tree.split_index(n)
        for tree in parsed_model.trees
        for n in range(len(tree.nodes))
        if not tree.is_leaf(n)
    )

    k, v = zip(*sorted(splits.items()))
    return np.polyfit(k, v, deg=1)

View File

@ -10,8 +10,7 @@ facilities.
dependencies for tests, see conda files in `ci_build`. dependencies for tests, see conda files in `ci_build`.
* python-gpu: Similar to python tests, but for GPU. * python-gpu: Similar to python tests, but for GPU.
* travis: CI facilities for Travis. * travis: CI facilities for Travis.
* distributed: Legacy tests for distributed system. Most of the distributed tests are * distributed: Tests for distributed systems.
in Python tests using `dask` and jvm package using `spark`.
* benchmark: Legacy benchmark code. There are a number of benchmark projects for * benchmark: Legacy benchmark code. There are a number of benchmark projects for
XGBoost with much better configurations. XGBoost with much better configurations.

View File

@ -17,3 +17,5 @@ dependencies:
- isort - isort
- pyspark - pyspark
- cloudpickle - cloudpickle
- pytest
- hypothesis

View File

@ -103,7 +103,12 @@ class PyLint:
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser(
description=(
"Run static checkers for XGBoost, see `python_lint.yml' "
"conda env file for a list of dependencies."
)
)
parser.add_argument("--format", type=int, choices=[0, 1], default=1) parser.add_argument("--format", type=int, choices=[0, 1], default=1)
parser.add_argument("--type-check", type=int, choices=[0, 1], default=1) parser.add_argument("--type-check", type=int, choices=[0, 1], default=1)
parser.add_argument("--pylint", type=int, choices=[0, 1], default=1) parser.add_argument("--pylint", type=int, choices=[0, 1], default=1)
@ -125,11 +130,11 @@ if __name__ == "__main__":
# tests # tests
"tests/python/test_config.py", "tests/python/test_config.py",
"tests/python/test_data_iterator.py", "tests/python/test_data_iterator.py",
"tests/python/test_spark/",
"tests/python/test_quantile_dmatrix.py", "tests/python/test_quantile_dmatrix.py",
"tests/python-gpu/test_gpu_spark/",
"tests/python-gpu/test_gpu_data_iterator.py", "tests/python-gpu/test_gpu_data_iterator.py",
"tests/ci_build/lint_python.py", "tests/ci_build/lint_python.py",
"tests/test_distributed/test_with_spark/",
"tests/test_distributed/test_gpu_with_spark/",
# demo # demo
"demo/guide-python/cat_in_the_dat.py", "demo/guide-python/cat_in_the_dat.py",
"demo/guide-python/categorical.py", "demo/guide-python/categorical.py",
@ -146,11 +151,11 @@ if __name__ == "__main__":
"demo/guide-python/external_memory.py", "demo/guide-python/external_memory.py",
"demo/guide-python/cat_in_the_dat.py", "demo/guide-python/cat_in_the_dat.py",
"tests/python/test_data_iterator.py", "tests/python/test_data_iterator.py",
"tests/python/test_spark/test_data.py",
"tests/python-gpu/test_gpu_with_dask/test_gpu_with_dask.py",
"tests/python-gpu/test_gpu_data_iterator.py", "tests/python-gpu/test_gpu_data_iterator.py",
"tests/python-gpu/test_gpu_spark/test_data.py",
"tests/ci_build/lint_python.py", "tests/ci_build/lint_python.py",
"tests/test_distributed/test_with_spark/test_data.py",
"tests/test_distributed/test_gpu_with_spark/test_data.py",
"tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py",
] ]
): ):
sys.exit(-1) sys.exit(-1)

View File

@ -68,6 +68,8 @@ case "$suite" in
install_xgboost install_xgboost
setup_pyspark_envs setup_pyspark_envs
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/python-gpu pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/python-gpu
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/test_distributed/test_gpu_with_dask
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/test_distributed/test_gpu_with_spark
unset_pyspark_envs unset_pyspark_envs
uninstall_xgboost uninstall_xgboost
set +x set +x
@ -80,6 +82,8 @@ case "$suite" in
export RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE=1 export RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE=1
setup_pyspark_envs setup_pyspark_envs
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/python pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/python
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/test_distributed/test_with_dask
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/test_distributed/test_with_spark
unset_pyspark_envs unset_pyspark_envs
uninstall_xgboost uninstall_xgboost
set +x set +x

View File

@ -1,43 +1,21 @@
import pytest import pytest
from xgboost import testing as tm # noqa from xgboost import testing as tm
def has_rmm(): def has_rmm():
try: return tm.no_rmm()["condition"]
import rmm
return True
except ImportError:
return False
@pytest.fixture(scope='session', autouse=True)
@pytest.fixture(scope="session", autouse=True)
def setup_rmm_pool(request, pytestconfig): def setup_rmm_pool(request, pytestconfig):
if pytestconfig.getoption('--use-rmm-pool'): tm.setup_rmm_pool(request, pytestconfig)
if not has_rmm():
raise ImportError('The --use-rmm-pool option requires the RMM package')
import rmm
from dask_cuda.utils import get_n_gpus
rmm.reinitialize(pool_allocator=True, initial_pool_size=1024*1024*1024,
devices=list(range(get_n_gpus())))
@pytest.fixture(scope='class')
def local_cuda_client(request, pytestconfig):
kwargs = {}
if hasattr(request, 'param'):
kwargs.update(request.param)
if pytestconfig.getoption('--use-rmm-pool'):
if not has_rmm():
raise ImportError('The --use-rmm-pool option requires the RMM package')
import rmm
kwargs['rmm_pool_size'] = '2GB'
if tm.no_dask_cuda()['condition']:
raise ImportError('The local_cuda_cluster fixture requires dask_cuda package')
from dask.distributed import Client
from dask_cuda import LocalCUDACluster
yield Client(LocalCUDACluster(**kwargs))
def pytest_addoption(parser): def pytest_addoption(parser: pytest.Parser) -> None:
parser.addoption('--use-rmm-pool', action='store_true', default=False, help='Use RMM pool') parser.addoption(
"--use-rmm-pool", action="store_true", default=False, help="Use RMM pool"
)
def pytest_collection_modifyitems(config, items): def pytest_collection_modifyitems(config, items):
@ -53,13 +31,3 @@ def pytest_collection_modifyitems(config, items):
for item in items: for item in items:
if any(item.nodeid.startswith(x) for x in blocklist): if any(item.nodeid.startswith(x) for x in blocklist):
item.add_marker(skip_mark) item.add_marker(skip_mark)
# mark dask tests as `mgpu`.
mgpu_mark = pytest.mark.mgpu
for item in items:
if item.nodeid.startswith(
"python-gpu/test_gpu_with_dask/test_gpu_with_dask.py"
) or item.nodeid.startswith(
"python-gpu/test_gpu_spark/test_gpu_spark.py"
):
item.add_marker(mgpu_mark)

View File

@ -1,23 +0,0 @@
import sys
import pytest
from xgboost import testing as tm
if tm.no_spark()["condition"]:
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
sys.path.append("tests/python")
from test_spark.test_data import run_dmatrix_ctor
@pytest.mark.skipif(**tm.no_cudf())
@pytest.mark.parametrize(
    "is_feature_cols,is_qdm",
    [(True, True), (True, False), (False, True), (False, False)],
)
def test_dmatrix_ctor(is_feature_cols: bool, is_qdm: bool) -> None:
    """Run the shared ``run_dmatrix_ctor`` check with ``on_gpu=True``.

    Parametrized over all four combinations of ``is_feature_cols`` and
    ``is_qdm``; skipped according to ``tm.no_cudf()``.
    """
    run_dmatrix_ctor(is_feature_cols, is_qdm, on_gpu=True)

View File

@ -7,26 +7,16 @@ from hypothesis import assume, given, note, settings, strategies
import xgboost as xgb import xgboost as xgb
from xgboost import testing as tm from xgboost import testing as tm
from xgboost.testing.params import (
hist_parameter_strategy,
cat_parameter_strategy,
)
sys.path.append("tests/python") sys.path.append("tests/python")
import test_updaters as test_up import test_updaters as test_up
pytestmark = tm.timeout(30) pytestmark = tm.timeout(30)
parameter_strategy = strategies.fixed_dictionaries({
'max_depth': strategies.integers(0, 11),
'max_leaves': strategies.integers(0, 256),
'max_bin': strategies.integers(2, 1024),
'grow_policy': strategies.sampled_from(['lossguide', 'depthwise']),
'min_child_weight': strategies.floats(0.5, 2.0),
'seed': strategies.integers(0, 10),
# We cannot enable subsampling as the training loss can increase
# 'subsample': strategies.floats(0.5, 1.0),
'colsample_bytree': strategies.floats(0.5, 1.0),
'colsample_bylevel': strategies.floats(0.5, 1.0),
}).filter(lambda x: (x['max_depth'] > 0 or x['max_leaves'] > 0) and (
x['max_depth'] > 0 or x['grow_policy'] == 'lossguide'))
def train_result(param, dmat: xgb.DMatrix, num_rounds: int) -> dict: def train_result(param, dmat: xgb.DMatrix, num_rounds: int) -> dict:
result: xgb.callback.TrainingCallback.EvalsLog = {} result: xgb.callback.TrainingCallback.EvalsLog = {}
@ -47,7 +37,7 @@ def train_result(param, dmat: xgb.DMatrix, num_rounds: int) -> dict:
class TestGPUUpdaters: class TestGPUUpdaters:
cputest = test_up.TestTreeMethod() cputest = test_up.TestTreeMethod()
@given(parameter_strategy, strategies.integers(1, 20), tm.dataset_strategy) @given(hist_parameter_strategy, strategies.integers(1, 20), tm.dataset_strategy)
@settings(deadline=None, max_examples=50, print_blob=True) @settings(deadline=None, max_examples=50, print_blob=True)
def test_gpu_hist(self, param, num_rounds, dataset): def test_gpu_hist(self, param, num_rounds, dataset):
param["tree_method"] = "gpu_hist" param["tree_method"] = "gpu_hist"
@ -82,9 +72,8 @@ class TestGPUUpdaters:
@given( @given(
tm.categorical_dataset_strategy, tm.categorical_dataset_strategy,
test_up.exact_parameter_strategy, hist_parameter_strategy,
test_up.hist_parameter_strategy, cat_parameter_strategy,
test_up.cat_parameter_strategy,
strategies.integers(4, 32), strategies.integers(4, 32),
) )
@settings(deadline=None, max_examples=20, print_blob=True) @settings(deadline=None, max_examples=20, print_blob=True)
@ -92,12 +81,10 @@ class TestGPUUpdaters:
def test_categorical( def test_categorical(
self, self,
dataset: tm.TestDataset, dataset: tm.TestDataset,
exact_parameters: Dict[str, Any],
hist_parameters: Dict[str, Any], hist_parameters: Dict[str, Any],
cat_parameters: Dict[str, Any], cat_parameters: Dict[str, Any],
n_rounds: int, n_rounds: int,
) -> None: ) -> None:
cat_parameters.update(exact_parameters)
cat_parameters.update(hist_parameters) cat_parameters.update(hist_parameters)
cat_parameters["tree_method"] = "gpu_hist" cat_parameters["tree_method"] = "gpu_hist"
@ -105,8 +92,8 @@ class TestGPUUpdaters:
tm.non_increasing(results["train"]["rmse"]) tm.non_increasing(results["train"]["rmse"])
@given( @given(
test_up.hist_parameter_strategy, hist_parameter_strategy,
test_up.cat_parameter_strategy, cat_parameter_strategy,
) )
@settings(deadline=None, max_examples=10, print_blob=True) @settings(deadline=None, max_examples=10, print_blob=True)
def test_categorical_ames_housing( def test_categorical_ames_housing(
@ -149,8 +136,11 @@ class TestGPUUpdaters:
self.cputest.run_invalid_category("gpu_hist") self.cputest.run_invalid_category("gpu_hist")
@pytest.mark.skipif(**tm.no_cupy()) @pytest.mark.skipif(**tm.no_cupy())
@given(parameter_strategy, strategies.integers(1, 20), @given(
tm.dataset_strategy) hist_parameter_strategy,
strategies.integers(1, 20),
tm.dataset_strategy
)
@settings(deadline=None, max_examples=20, print_blob=True) @settings(deadline=None, max_examples=20, print_blob=True)
def test_gpu_hist_device_dmatrix(self, param, num_rounds, dataset): def test_gpu_hist_device_dmatrix(self, param, num_rounds, dataset):
# We cannot handle empty dataset yet # We cannot handle empty dataset yet
@ -161,8 +151,11 @@ class TestGPUUpdaters:
note(result) note(result)
assert tm.non_increasing(result['train'][dataset.metric], tolerance=1e-3) assert tm.non_increasing(result['train'][dataset.metric], tolerance=1e-3)
@given(parameter_strategy, strategies.integers(1, 3), @given(
tm.dataset_strategy) hist_parameter_strategy,
strategies.integers(1, 3),
tm.dataset_strategy
)
@settings(deadline=None, max_examples=10, print_blob=True) @settings(deadline=None, max_examples=10, print_blob=True)
def test_external_memory(self, param, num_rounds, dataset): def test_external_memory(self, param, num_rounds, dataset):
if dataset.name.endswith("-l1"): if dataset.name.endswith("-l1"):

View File

@ -5,6 +5,7 @@ import numpy as np
import pandas as pd import pandas as pd
import pytest import pytest
from scipy import sparse from scipy import sparse
from xgboost.testing.shared import validate_leaf_output
import xgboost as xgb import xgboost as xgb
from xgboost import testing as tm from xgboost import testing as tm
@ -26,16 +27,6 @@ def run_threaded_predict(X, rows, predict_func):
assert f.result() assert f.result()
def verify_leaf_output(leaf: np.ndarray, num_parallel_tree: int):
for i in range(leaf.shape[0]): # n_samples
for j in range(leaf.shape[1]): # n_rounds
for k in range(leaf.shape[2]): # n_classes
tree_group = leaf[i, j, k, :]
assert tree_group.shape[0] == num_parallel_tree
# No sampling, all trees within forest are the same
assert np.all(tree_group == tree_group[0])
def run_predict_leaf(predictor): def run_predict_leaf(predictor):
rows = 100 rows = 100
cols = 4 cols = 4
@ -67,7 +58,7 @@ def run_predict_leaf(predictor):
assert leaf.shape[2] == classes assert leaf.shape[2] == classes
assert leaf.shape[3] == num_parallel_tree assert leaf.shape[3] == num_parallel_tree
verify_leaf_output(leaf, num_parallel_tree) validate_leaf_output(leaf, num_parallel_tree)
ntree_limit = 2 ntree_limit = 2
sliced = booster.predict( sliced = booster.predict(

View File

@ -7,6 +7,7 @@ import pytest
import xgboost as xgb import xgboost as xgb
from xgboost import RabitTracker from xgboost import RabitTracker
from xgboost import testing as tm from xgboost import testing as tm
from xgboost import collective
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True) pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@ -21,12 +22,9 @@ def test_rabit_tracker():
def run_rabit_ops(client, n_workers): def run_rabit_ops(client, n_workers):
from test_with_dask import _get_client_workers
from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args
from xgboost import collective workers = tm.get_client_workers(client)
workers = _get_client_workers(client)
rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client) rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
assert not collective.is_distributed() assert not collective.is_distributed()
n_workers_from_dask = len(workers) n_workers_from_dask = len(workers)
@ -76,7 +74,6 @@ def test_rabit_ops_ipv6():
def test_rank_assignment() -> None: def test_rank_assignment() -> None:
from distributed import Client, LocalCluster from distributed import Client, LocalCluster
from test_with_dask import _get_client_workers
def local_test(worker_id): def local_test(worker_id):
with xgb.dask.CommunicatorContext(**args) as ctx: with xgb.dask.CommunicatorContext(**args) as ctx:
@ -89,7 +86,7 @@ def test_rank_assignment() -> None:
with LocalCluster(n_workers=8) as cluster: with LocalCluster(n_workers=8) as cluster:
with Client(cluster) as client: with Client(cluster) as client:
workers = _get_client_workers(client) workers = tm.get_client_workers(client)
args = client.sync( args = client.sync(
xgb.dask._get_rabit_args, xgb.dask._get_rabit_args,
len(workers), len(workers),

View File

@ -8,36 +8,10 @@ from hypothesis import given, note, settings, strategies
import xgboost as xgb import xgboost as xgb
from xgboost import testing as tm from xgboost import testing as tm
from xgboost.testing.params import (
exact_parameter_strategy = strategies.fixed_dictionaries({ exact_parameter_strategy,
'nthread': strategies.integers(1, 4), hist_parameter_strategy,
'max_depth': strategies.integers(1, 11), cat_parameter_strategy,
'min_child_weight': strategies.floats(0.5, 2.0),
'alpha': strategies.floats(1e-5, 2.0),
'lambda': strategies.floats(1e-5, 2.0),
'eta': strategies.floats(0.01, 0.5),
'gamma': strategies.floats(1e-5, 2.0),
'seed': strategies.integers(0, 10),
# We cannot enable subsampling as the training loss can increase
# 'subsample': strategies.floats(0.5, 1.0),
'colsample_bytree': strategies.floats(0.5, 1.0),
'colsample_bylevel': strategies.floats(0.5, 1.0),
})
hist_parameter_strategy = strategies.fixed_dictionaries({
'max_depth': strategies.integers(1, 11),
'max_leaves': strategies.integers(0, 1024),
'max_bin': strategies.integers(2, 512),
'grow_policy': strategies.sampled_from(['lossguide', 'depthwise']),
}).filter(lambda x: (x['max_depth'] > 0 or x['max_leaves'] > 0) and (
x['max_depth'] > 0 or x['grow_policy'] == 'lossguide'))
cat_parameter_strategy = strategies.fixed_dictionaries(
{
"max_cat_to_onehot": strategies.integers(1, 128),
"max_cat_threshold": strategies.integers(1, 128),
}
) )

View File

@ -1,5 +1,3 @@
import collections
import importlib.util
import json import json
import os import os
import random import random
@ -9,6 +7,7 @@ from typing import Callable, Optional
import numpy as np import numpy as np
import pytest import pytest
from sklearn.utils.estimator_checks import parametrize_with_checks from sklearn.utils.estimator_checks import parametrize_with_checks
from xgboost.testing.shared import get_feature_weights, validate_data_initialization
import xgboost as xgb import xgboost as xgb
from xgboost import testing as tm from xgboost import testing as tm
@ -1031,45 +1030,6 @@ def test_pandas_input():
np.testing.assert_allclose(np.array(clf_isotonic.classes_), np.array([0, 1])) np.testing.assert_allclose(np.array(clf_isotonic.classes_), np.array([0, 1]))
def run_feature_weights(X, y, fw, tree_method, model=xgb.XGBRegressor):
with tempfile.TemporaryDirectory() as tmpdir:
colsample_bynode = 0.5
reg = model(tree_method=tree_method, colsample_bynode=colsample_bynode)
reg.fit(X, y, feature_weights=fw)
model_path = os.path.join(tmpdir, 'model.json')
reg.save_model(model_path)
with open(model_path) as fd:
model = json.load(fd)
parser_path = os.path.join(
tm.demo_dir(__file__), "json-model", "json_parser.py"
)
spec = importlib.util.spec_from_file_location("JsonParser",
parser_path)
foo = importlib.util.module_from_spec(spec)
spec.loader.exec_module(foo)
model = foo.Model(model)
splits = {}
total_nodes = 0
for tree in model.trees:
n_nodes = len(tree.nodes)
total_nodes += n_nodes
for n in range(n_nodes):
if tree.is_leaf(n):
continue
if splits.get(tree.split_index(n), None) is None:
splits[tree.split_index(n)] = 1
else:
splits[tree.split_index(n)] += 1
od = collections.OrderedDict(sorted(splits.items()))
tuples = [(k, v) for k, v in od.items()]
k, v = list(zip(*tuples))
w = np.polyfit(k, v, deg=1)
return w
@pytest.mark.parametrize("tree_method", ["approx", "hist"]) @pytest.mark.parametrize("tree_method", ["approx", "hist"])
def test_feature_weights(tree_method): def test_feature_weights(tree_method):
kRows = 512 kRows = 512
@ -1080,12 +1040,18 @@ def test_feature_weights(tree_method):
fw = np.ones(shape=(kCols,)) fw = np.ones(shape=(kCols,))
for i in range(kCols): for i in range(kCols):
fw[i] *= float(i) fw[i] *= float(i)
poly_increasing = run_feature_weights(X, y, fw, tree_method, xgb.XGBRegressor)
parser_path = os.path.join(tm.demo_dir(__file__), "json-model", "json_parser.py")
poly_increasing = get_feature_weights(
X, y, fw, parser_path, tree_method, xgb.XGBRegressor
)
fw = np.ones(shape=(kCols,)) fw = np.ones(shape=(kCols,))
for i in range(kCols): for i in range(kCols):
fw[i] *= float(kCols - i) fw[i] *= float(kCols - i)
poly_decreasing = run_feature_weights(X, y, fw, tree_method, xgb.XGBRegressor) poly_decreasing = get_feature_weights(
X, y, fw, parser_path, tree_method, xgb.XGBRegressor
)
# Approximated test, this is dependent on the implementation of random # Approximated test, this is dependent on the implementation of random
# number generator in std library. # number generator in std library.
@ -1219,33 +1185,10 @@ def test_multilabel_classification() -> None:
assert predt.dtype == np.int64 assert predt.dtype == np.int64
def run_data_initialization(DMatrix, model, X, y):
"""Assert that we don't create duplicated DMatrix."""
old_init = DMatrix.__init__
count = [0]
def new_init(self, **kwargs):
count[0] += 1
return old_init(self, **kwargs)
DMatrix.__init__ = new_init
model(n_estimators=1).fit(X, y, eval_set=[(X, y)])
assert count[0] == 1
count[0] = 0 # only 1 DMatrix is created.
y_copy = y.copy()
model(n_estimators=1).fit(X, y, eval_set=[(X, y_copy)])
assert count[0] == 2 # a different Python object is considered different
DMatrix.__init__ = old_init
def test_data_initialization(): def test_data_initialization():
from sklearn.datasets import load_digits from sklearn.datasets import load_digits
X, y = load_digits(return_X_y=True) X, y = load_digits(return_X_y=True)
run_data_initialization(xgb.DMatrix, xgb.XGBClassifier, X, y) validate_data_initialization(xgb.DMatrix, xgb.XGBClassifier, X, y)
@parametrize_with_checks([xgb.XGBRegressor()]) @parametrize_with_checks([xgb.XGBRegressor()])

View File

@ -3,9 +3,10 @@ import multiprocessing
import sys import sys
import time import time
import xgboost as xgb
import xgboost.federated import xgboost.federated
import xgboost as xgb
SERVER_KEY = 'server-key.pem' SERVER_KEY = 'server-key.pem'
SERVER_CERT = 'server-cert.pem' SERVER_CERT = 'server-cert.pem'
CLIENT_KEY = 'client-key.pem' CLIENT_KEY = 'client-key.pem'
@ -58,7 +59,7 @@ def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu:
xgb.collective.communicator_print("Finished training\n") xgb.collective.communicator_print("Finished training\n")
def run_test(with_ssl: bool = True, with_gpu: bool = False) -> None: def run_federated(with_ssl: bool = True, with_gpu: bool = False) -> None:
port = 9091 port = 9091
world_size = int(sys.argv[1]) world_size = int(sys.argv[1])
@ -80,7 +81,7 @@ def run_test(with_ssl: bool = True, with_gpu: bool = False) -> None:
if __name__ == '__main__': if __name__ == '__main__':
run_test(with_ssl=True, with_gpu=False) run_federated(with_ssl=True, with_gpu=False)
run_test(with_ssl=False, with_gpu=False) run_federated(with_ssl=False, with_gpu=False)
run_test(with_ssl=True, with_gpu=True) run_federated(with_ssl=True, with_gpu=True)
run_test(with_ssl=False, with_gpu=True) run_federated(with_ssl=False, with_gpu=True)

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,42 @@
from typing import Generator, Sequence
import pytest
from xgboost import testing as tm
@pytest.fixture(scope="session", autouse=True)
def setup_rmm_pool(request, pytestconfig: pytest.Config) -> None:
    """Session-scoped, automatically used fixture that delegates to ``tm.setup_rmm_pool``."""
    tm.setup_rmm_pool(request, pytestconfig)
@pytest.fixture(scope="class")
def local_cuda_client(request, pytestconfig: pytest.Config) -> Generator:
    """Yield a dask ``Client`` connected to a fresh ``LocalCUDACluster``.

    Keyword arguments for the cluster are taken from ``request.param`` when
    the fixture is parametrized.  With ``--use-rmm-pool`` the cluster is also
    given ``rmm_pool_size="2GB"``.  The client and the cluster are both closed
    once the requesting scope is finished, so workers do not outlive the
    tests that use them.

    Raises
    ------
    ImportError
        If ``--use-rmm-pool`` is set but RMM is reported missing, or if
        ``dask_cuda`` is reported missing.
    """
    kwargs = {}
    if hasattr(request, "param"):
        kwargs.update(request.param)
    if pytestconfig.getoption("--use-rmm-pool"):
        if tm.no_rmm()["condition"]:
            raise ImportError("The --use-rmm-pool option requires the RMM package")
        # Kept from the original; presumably a check that the package really
        # imports before the pool size is requested.
        import rmm

        kwargs["rmm_pool_size"] = "2GB"
    if tm.no_dask_cuda()["condition"]:
        raise ImportError("The local_cuda_client fixture requires dask_cuda package")
    from dask.distributed import Client
    from dask_cuda import LocalCUDACluster

    # Context managers guarantee teardown after the yield.
    with LocalCUDACluster(**kwargs) as cluster:
        with Client(cluster) as client:
            yield client
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register the boolean ``--use-rmm-pool`` command line flag (off by default)."""
    flag_settings = {
        "action": "store_true",
        "default": False,
        "help": "Use RMM pool",
    }
    parser.addoption("--use-rmm-pool", **flag_settings)
def pytest_collection_modifyitems(config: pytest.Config, items: Sequence) -> None:
    """Tag every collected test item with the ``mgpu`` marker.

    ``config`` is part of the hook signature and is not used.
    """
    multi_gpu = pytest.mark.mgpu
    for collected in items:
        collected.add_marker(multi_gpu)

View File

@ -2,7 +2,6 @@
import asyncio import asyncio
import os import os
import subprocess import subprocess
import sys
from collections import OrderedDict from collections import OrderedDict
from inspect import signature from inspect import signature
from typing import Any, Dict, Type, TypeVar, Union from typing import Any, Dict, Type, TypeVar, Union
@ -11,43 +10,39 @@ import numpy as np
import pytest import pytest
from hypothesis import given, note, settings, strategies from hypothesis import given, note, settings, strategies
from hypothesis._settings import duration from hypothesis._settings import duration
from test_gpu_updaters import parameter_strategy from xgboost.testing.params import hist_parameter_strategy
import xgboost import xgboost as xgb
from xgboost import testing as tm from xgboost import testing as tm
if sys.platform.startswith("win"): pytestmark = [
pytest.skip("Skipping dask tests on Windows", allow_module_level=True) pytest.mark.skipif(**tm.no_dask()),
pytest.mark.skipif(**tm.no_dask_cuda()),
]
sys.path.append("tests/python") from ..test_with_dask.test_with_dask import generate_array
from ..test_with_dask.test_with_dask import kCols as random_cols
if tm.no_dask_cuda()["condition"]: from ..test_with_dask.test_with_dask import (
pytest.skip(tm.no_dask_cuda()["reason"], allow_module_level=True) make_categorical,
run_auc,
run_boost_from_prediction,
from test_with_dask import _get_client_workers # noqa run_boost_from_prediction_multi_class,
from test_with_dask import generate_array # noqa run_categorical,
from test_with_dask import make_categorical # noqa run_dask_classifier,
from test_with_dask import run_auc # noqa run_empty_dmatrix_auc,
from test_with_dask import run_boost_from_prediction # noqa run_empty_dmatrix_cls,
from test_with_dask import run_boost_from_prediction_multi_class # noqa run_empty_dmatrix_reg,
from test_with_dask import run_categorical # noqa run_tree_stats,
from test_with_dask import run_dask_classifier # noqa suppress,
from test_with_dask import run_empty_dmatrix_auc # noqa )
from test_with_dask import run_empty_dmatrix_cls # noqa
from test_with_dask import run_empty_dmatrix_reg # noqa
from test_with_dask import run_tree_stats # noqa
from test_with_dask import suppress # noqa
from test_with_dask import kCols as random_cols # noqa
try: try:
import cudf import cudf
import dask.dataframe as dd import dask.dataframe as dd
from dask import array as da from dask import array as da
from dask.distributed import Client from dask.distributed import Client
from dask_cuda import LocalCUDACluster, utils from dask_cuda import LocalCUDACluster
import xgboost as xgb
from xgboost import dask as dxgb from xgboost import dask as dxgb
except ImportError: except ImportError:
pass pass
@ -57,10 +52,10 @@ def run_with_dask_dataframe(DMatrixT: Type, client: Client) -> None:
import cupy as cp import cupy as cp
cp.cuda.runtime.setDevice(0) cp.cuda.runtime.setDevice(0)
X, y, _ = generate_array() _X, _y, _ = generate_array()
X = dd.from_dask_array(X) X = dd.from_dask_array(_X)
y = dd.from_dask_array(y) y = dd.from_dask_array(_y)
X = X.map_partitions(cudf.from_pandas) X = X.map_partitions(cudf.from_pandas)
y = y.map_partitions(cudf.from_pandas) y = y.map_partitions(cudf.from_pandas)
@ -83,7 +78,7 @@ def run_with_dask_dataframe(DMatrixT: Type, client: Client) -> None:
series_predictions = dxgb.inplace_predict(client, out, X) series_predictions = dxgb.inplace_predict(client, out, X)
assert isinstance(series_predictions, dd.Series) assert isinstance(series_predictions, dd.Series)
single_node = out["booster"].predict(xgboost.DMatrix(X.compute())) single_node = out["booster"].predict(xgb.DMatrix(X.compute()))
cp.testing.assert_allclose(single_node, predictions.compute()) cp.testing.assert_allclose(single_node, predictions.compute())
np.testing.assert_allclose(single_node, series_predictions.compute().to_numpy()) np.testing.assert_allclose(single_node, series_predictions.compute().to_numpy())
@ -127,7 +122,7 @@ def run_with_dask_array(DMatrixT: Type, client: Client) -> None:
) )
from_dmatrix = dxgb.predict(client, out, dtrain).compute() from_dmatrix = dxgb.predict(client, out, dtrain).compute()
inplace_predictions = dxgb.inplace_predict(client, out, X).compute() inplace_predictions = dxgb.inplace_predict(client, out, X).compute()
single_node = out["booster"].predict(xgboost.DMatrix(X.compute())) single_node = out["booster"].predict(xgb.DMatrix(X.compute()))
np.testing.assert_allclose(single_node, from_dmatrix) np.testing.assert_allclose(single_node, from_dmatrix)
device = cp.cuda.runtime.getDevice() device = cp.cuda.runtime.getDevice()
assert device == inplace_predictions.device.id assert device == inplace_predictions.device.id
@ -242,7 +237,7 @@ class TestDistributedGPU:
run_categorical(local_cuda_client, "gpu_hist", X, X_onehot, y) run_categorical(local_cuda_client, "gpu_hist", X, X_onehot, y)
@given( @given(
params=parameter_strategy, params=hist_parameter_strategy,
num_rounds=strategies.integers(1, 20), num_rounds=strategies.integers(1, 20),
dataset=tm.dataset_strategy, dataset=tm.dataset_strategy,
dmatrix_type=strategies.sampled_from( dmatrix_type=strategies.sampled_from(
@ -405,7 +400,7 @@ class TestDistributedGPU:
np.testing.assert_allclose(predt, in_predt) np.testing.assert_allclose(predt, in_predt)
def test_empty_dmatrix_auc(self, local_cuda_client: Client) -> None: def test_empty_dmatrix_auc(self, local_cuda_client: Client) -> None:
n_workers = len(_get_client_workers(local_cuda_client)) n_workers = len(tm.get_client_workers(local_cuda_client))
run_empty_dmatrix_auc(local_cuda_client, "gpu_hist", n_workers) run_empty_dmatrix_auc(local_cuda_client, "gpu_hist", n_workers)
def test_auc(self, local_cuda_client: Client) -> None: def test_auc(self, local_cuda_client: Client) -> None:
@ -418,7 +413,7 @@ class TestDistributedGPU:
fw = fw - fw.min() fw = fw - fw.min()
m = dxgb.DaskDMatrix(local_cuda_client, X, y, feature_weights=fw) m = dxgb.DaskDMatrix(local_cuda_client, X, y, feature_weights=fw)
workers = _get_client_workers(local_cuda_client) workers = tm.get_client_workers(local_cuda_client)
rabit_args = local_cuda_client.sync( rabit_args = local_cuda_client.sync(
dxgb._get_rabit_args, len(workers), None, local_cuda_client dxgb._get_rabit_args, len(workers), None, local_cuda_client
) )
@ -488,9 +483,6 @@ class TestDistributedGPU:
assert rn == drn assert rn == drn
def run_quantile(self, name: str, local_cuda_client: Client) -> None: def run_quantile(self, name: str, local_cuda_client: Client) -> None:
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows")
exe = None exe = None
for possible_path in { for possible_path in {
"./testxgboost", "./testxgboost",
@ -506,14 +498,13 @@ class TestDistributedGPU:
def runit( def runit(
worker_addr: str, rabit_args: Dict[str, Union[int, str]] worker_addr: str, rabit_args: Dict[str, Union[int, str]]
) -> subprocess.CompletedProcess: ) -> subprocess.CompletedProcess:
port_env = ""
# setup environment for running the c++ part. # setup environment for running the c++ part.
env = os.environ.copy() env = os.environ.copy()
env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT']) env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"]) env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"])
return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE) return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE)
workers = _get_client_workers(local_cuda_client) workers = tm.get_client_workers(local_cuda_client)
rabit_args = local_cuda_client.sync( rabit_args = local_cuda_client.sync(
dxgb._get_rabit_args, len(workers), None, local_cuda_client dxgb._get_rabit_args, len(workers), None, local_cuda_client
) )
@ -539,7 +530,7 @@ class TestDistributedGPU:
def test_with_asyncio(local_cuda_client: Client) -> None: def test_with_asyncio(local_cuda_client: Client) -> None:
address = local_cuda_client.scheduler.address address = local_cuda_client.scheduler.address
output = asyncio.run(run_from_dask_array_asyncio(address)) output = asyncio.run(run_from_dask_array_asyncio(address))
assert isinstance(output["booster"], xgboost.Booster) assert isinstance(output["booster"], xgb.Booster)
assert isinstance(output["history"], dict) assert isinstance(output["history"], dict)
@ -551,12 +542,12 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> dxgb.TrainRetur
X = X.map_blocks(cp.array) X = X.map_blocks(cp.array)
y = y.map_blocks(cp.array) y = y.map_blocks(cp.array)
m = await xgboost.dask.DaskDeviceQuantileDMatrix(client, X, y) m = await xgb.dask.DaskDeviceQuantileDMatrix(client, X, y)
output = await xgboost.dask.train(client, {"tree_method": "gpu_hist"}, dtrain=m) output = await xgb.dask.train(client, {"tree_method": "gpu_hist"}, dtrain=m)
with_m = await xgboost.dask.predict(client, output, m) with_m = await xgb.dask.predict(client, output, m)
with_X = await xgboost.dask.predict(client, output, X) with_X = await xgb.dask.predict(client, output, X)
inplace = await xgboost.dask.inplace_predict(client, output, X) inplace = await xgb.dask.inplace_predict(client, output, X)
assert isinstance(with_m, da.Array) assert isinstance(with_m, da.Array)
assert isinstance(with_X, da.Array) assert isinstance(with_X, da.Array)
assert isinstance(inplace, da.Array) assert isinstance(inplace, da.Array)

View File

@ -0,0 +1,10 @@
from typing import Sequence
import pytest
def pytest_collection_modifyitems(config: pytest.Config, items: Sequence) -> None:
    """Attach the ``mgpu`` marker to every collected test item.

    ``config`` is accepted because the hook signature declares it; it is not
    used here.
    """
    # Build the marker once and reuse it for all items.
    multi_gpu = pytest.mark.mgpu
    for collected in items:
        collected.add_marker(multi_gpu)

View File

@ -0,0 +1,16 @@
import pytest
from xgboost import testing as tm
# Module-level marker: skip every test in this module when ``tm.no_spark()``
# reports its skip condition (the returned mapping supplies the ``skipif``
# keyword arguments).
pytestmark = pytest.mark.skipif(**tm.no_spark())
from ..test_with_spark.test_data import run_dmatrix_ctor
@pytest.mark.skipif(**tm.no_cudf())
@pytest.mark.parametrize(
    "is_feature_cols,is_qdm",
    [(True, True), (True, False), (False, True), (False, False)],
)
def test_dmatrix_ctor(is_feature_cols: bool, is_qdm: bool) -> None:
    """Run the shared ``run_dmatrix_ctor`` check with ``on_gpu=True``.

    Parametrized over all four combinations of ``is_feature_cols`` and
    ``is_qdm``; skipped when ``tm.no_cudf()`` reports its skip condition.
    """
    run_dmatrix_ctor(is_feature_cols, is_qdm, on_gpu=True)

View File

@ -1,24 +1,20 @@
import json import json
import logging import logging
import subprocess import subprocess
import sys
import pytest import pytest
import sklearn import sklearn
from xgboost import testing as tm from xgboost import testing as tm
if tm.no_spark()["condition"]: pytestmark = pytest.mark.skipif(**tm.no_spark())
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
if sys.platform.startswith("win"):
pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
from pyspark.ml.linalg import Vectors from pyspark.ml.linalg import Vectors
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.sql import SparkSession from pyspark.sql import SparkSession
from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
gpu_discovery_script_path = "tests/python-gpu/test_gpu_spark/discover_gpu.sh" gpu_discovery_script_path = "tests/test_distributed/test_gpu_with_spark/discover_gpu.sh"
def get_devices(): def get_devices():

View File

@ -0,0 +1 @@

View File

@ -5,7 +5,6 @@ import os
import pickle import pickle
import socket import socket
import subprocess import subprocess
import sys
import tempfile import tempfile
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import partial from functools import partial
@ -13,7 +12,7 @@ from itertools import starmap
from math import ceil from math import ceil
from operator import attrgetter, getitem from operator import attrgetter, getitem
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type, Union from typing import Any, Dict, Optional, Tuple, Type, Union, Generator
import hypothesis import hypothesis
import numpy as np import numpy as np
@ -22,18 +21,18 @@ import scipy
import sklearn import sklearn
from hypothesis import HealthCheck, given, note, settings from hypothesis import HealthCheck, given, note, settings
from sklearn.datasets import make_classification, make_regression from sklearn.datasets import make_classification, make_regression
from test_predict import verify_leaf_output
from test_updaters import exact_parameter_strategy, hist_parameter_strategy
from test_with_sklearn import run_data_initialization, run_feature_weights
from xgboost.data import _is_cudf_df from xgboost.data import _is_cudf_df
from xgboost.testing.params import hist_parameter_strategy
from xgboost.testing.shared import (
get_feature_weights,
validate_data_initialization,
validate_leaf_output,
)
import xgboost as xgb import xgboost as xgb
from xgboost import testing as tm from xgboost import testing as tm
if sys.platform.startswith("win"): pytestmark = [tm.timeout(30), pytest.mark.skipif(**tm.no_dask())]
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
if tm.no_dask()['condition']:
pytest.skip(msg=tm.no_dask()['reason'], allow_module_level=True)
import dask import dask
import dask.array as da import dask.array as da
@ -44,7 +43,6 @@ from xgboost.dask import DaskDMatrix
dask.config.set({"distributed.scheduler.allowed-failures": False}) dask.config.set({"distributed.scheduler.allowed-failures": False})
pytestmark = tm.timeout(30)
if hasattr(HealthCheck, 'function_scoped_fixture'): if hasattr(HealthCheck, 'function_scoped_fixture'):
suppress = [HealthCheck.function_scoped_fixture] suppress = [HealthCheck.function_scoped_fixture]
@ -53,7 +51,7 @@ else:
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def cluster(): def cluster() -> Generator:
with LocalCluster( with LocalCluster(
n_workers=2, threads_per_worker=2, dashboard_address=":0" n_workers=2, threads_per_worker=2, dashboard_address=":0"
) as dask_cluster: ) as dask_cluster:
@ -61,7 +59,7 @@ def cluster():
@pytest.fixture @pytest.fixture
def client(cluster): def client(cluster: "LocalCluster") -> Generator:
with Client(cluster) as dask_client: with Client(cluster) as dask_client:
yield dask_client yield dask_client
@ -71,11 +69,6 @@ kCols = 10
kWorkers = 5 kWorkers = 5
def _get_client_workers(client: "Client") -> List[str]:
    """Return the keys of the ``'workers'`` entry in ``client.scheduler_info()``."""
    return list(client.scheduler_info()['workers'].keys())
def make_categorical( def make_categorical(
client: Client, client: Client,
n_samples: int, n_samples: int,
@ -83,7 +76,7 @@ def make_categorical(
n_categories: int, n_categories: int,
onehot: bool = False, onehot: bool = False,
) -> Tuple[dd.DataFrame, dd.Series]: ) -> Tuple[dd.DataFrame, dd.Series]:
workers = _get_client_workers(client) workers = tm.get_client_workers(client)
n_workers = len(workers) n_workers = len(workers)
dfs = [] dfs = []
@ -121,9 +114,7 @@ def make_categorical(
def generate_array( def generate_array(
with_weights: bool = False, with_weights: bool = False,
) -> Tuple[ ) -> Tuple[da.Array, da.Array, Optional[da.Array]]:
xgb.dask._DataT, xgb.dask._DaskCollection, Optional[xgb.dask._DaskCollection]
]:
chunk_size = 20 chunk_size = 20
rng = da.random.RandomState(1994) rng = da.random.RandomState(1994)
X = rng.random_sample((kRows, kCols), chunks=(chunk_size, -1)) X = rng.random_sample((kRows, kCols), chunks=(chunk_size, -1))
@ -134,7 +125,7 @@ def generate_array(
return X, y, None return X, y, None
def deterministic_persist_per_worker(df, client): def deterministic_persist_per_worker(df: dd.DataFrame, client: "Client") -> dd.DataFrame:
# Got this script from https://github.com/dmlc/xgboost/issues/7927 # Got this script from https://github.com/dmlc/xgboost/issues/7927
# Query workers # Query workers
n_workers = len(client.cluster.workers) n_workers = len(client.cluster.workers)
@ -1232,7 +1223,7 @@ def test_dask_predict_leaf(booster: str, client: "Client") -> None:
leaf_from_apply = cls.apply(X).reshape(leaf.shape).compute() leaf_from_apply = cls.apply(X).reshape(leaf.shape).compute()
np.testing.assert_allclose(leaf_from_apply, leaf) np.testing.assert_allclose(leaf_from_apply, leaf)
verify_leaf_output(leaf, num_parallel_tree) validate_leaf_output(leaf, num_parallel_tree)
def test_dask_iteration_range(client: "Client"): def test_dask_iteration_range(client: "Client"):
@ -1287,7 +1278,7 @@ class TestWithDask:
assert Xy.num_col() == 4 assert Xy.num_col() == 4
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
workers = _get_client_workers(client) workers = tm.get_client_workers(client)
rabit_args = client.sync( rabit_args = client.sync(
xgb.dask._get_rabit_args, len(workers), None, client xgb.dask._get_rabit_args, len(workers), None, client
) )
@ -1403,10 +1394,10 @@ class TestWithDask:
note(history) note(history)
history = history['train'][dataset.metric] history = history['train'][dataset.metric]
def is_stump(): def is_stump() -> bool:
return params["max_depth"] == 1 or params["max_leaves"] == 1 return params["max_depth"] == 1 or params["max_leaves"] == 1
def minimum_bin(): def minimum_bin() -> bool:
return "max_bin" in params and params["max_bin"] == 2 return "max_bin" in params and params["max_bin"] == 2
# See note on `ObjFunction::UpdateTreeLeaf`. # See note on `ObjFunction::UpdateTreeLeaf`.
@ -1466,9 +1457,10 @@ class TestWithDask:
quantile_hist["Valid"]["rmse"], dmatrix_hist["Valid"]["rmse"] quantile_hist["Valid"]["rmse"], dmatrix_hist["Valid"]["rmse"]
) )
@given(params=exact_parameter_strategy, @given(params=hist_parameter_strategy, dataset=tm.dataset_strategy)
dataset=tm.dataset_strategy) @settings(
@settings(deadline=None, max_examples=10, suppress_health_check=suppress, print_blob=True) deadline=None, max_examples=10, suppress_health_check=suppress, print_blob=True
)
def test_approx( def test_approx(
self, client: "Client", params: Dict, dataset: tm.TestDataset self, client: "Client", params: Dict, dataset: tm.TestDataset
) -> None: ) -> None:
@ -1476,9 +1468,6 @@ class TestWithDask:
self.run_updater_test(client, params, num_rounds, dataset, 'approx') self.run_updater_test(client, params, num_rounds, dataset, 'approx')
def run_quantile(self, name: str) -> None: def run_quantile(self, name: str) -> None:
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows")
exe: Optional[str] = None exe: Optional[str] = None
for possible_path in {'./testxgboost', './build/testxgboost', for possible_path in {'./testxgboost', './build/testxgboost',
'../build/cpubuild/testxgboost', '../build/cpubuild/testxgboost',
@ -1493,7 +1482,6 @@ class TestWithDask:
def runit( def runit(
worker_addr: str, rabit_args: Dict[str, Union[int, str]] worker_addr: str, rabit_args: Dict[str, Union[int, str]]
) -> subprocess.CompletedProcess: ) -> subprocess.CompletedProcess:
port_env = ''
# setup environment for running the c++ part. # setup environment for running the c++ part.
env = os.environ.copy() env = os.environ.copy()
env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT']) env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
@ -1502,7 +1490,7 @@ class TestWithDask:
with LocalCluster(n_workers=4, dashboard_address=":0") as cluster: with LocalCluster(n_workers=4, dashboard_address=":0") as cluster:
with Client(cluster) as client: with Client(cluster) as client:
workers = _get_client_workers(client) workers = tm.get_client_workers(client)
rabit_args = client.sync( rabit_args = client.sync(
xgb.dask._get_rabit_args, len(workers), None, client xgb.dask._get_rabit_args, len(workers), None, client
) )
@ -1565,7 +1553,7 @@ class TestWithDask:
with LocalCluster(n_workers=2, dashboard_address=":0") as cluster: with LocalCluster(n_workers=2, dashboard_address=":0") as cluster:
with Client(cluster) as client: with Client(cluster) as client:
workers = _get_client_workers(client) workers = tm.get_client_workers(client)
rabit_args = client.sync( rabit_args = client.sync(
xgb.dask._get_rabit_args, len(workers), None, client xgb.dask._get_rabit_args, len(workers), None, client
) )
@ -1580,7 +1568,7 @@ class TestWithDask:
def test_n_workers(self) -> None: def test_n_workers(self) -> None:
with LocalCluster(n_workers=2, dashboard_address=":0") as cluster: with LocalCluster(n_workers=2, dashboard_address=":0") as cluster:
with Client(cluster) as client: with Client(cluster) as client:
workers = _get_client_workers(client) workers = tm.get_client_workers(client)
from sklearn.datasets import load_breast_cancer from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True) X, y = load_breast_cancer(return_X_y=True)
@ -1609,16 +1597,17 @@ class TestWithDask:
for i in range(kCols): for i in range(kCols):
fw[i] *= float(i) fw[i] *= float(i)
fw = da.from_array(fw) fw = da.from_array(fw)
poly_increasing = run_feature_weights( parser = os.path.join(tm.demo_dir(__file__), "json-model", "json_parser.py")
X, y, fw, "approx", model=xgb.dask.DaskXGBRegressor poly_increasing = get_feature_weights(
X, y, fw, parser, "approx", model=xgb.dask.DaskXGBRegressor
) )
fw = np.ones(shape=(kCols,)) fw = np.ones(shape=(kCols,))
for i in range(kCols): for i in range(kCols):
fw[i] *= float(kCols - i) fw[i] *= float(kCols - i)
fw = da.from_array(fw) fw = da.from_array(fw)
poly_decreasing = run_feature_weights( poly_decreasing = get_feature_weights(
X, y, fw, "approx", model=xgb.dask.DaskXGBRegressor X, y, fw, parser, "approx", model=xgb.dask.DaskXGBRegressor
) )
# Approxmated test, this is dependent on the implementation of random # Approxmated test, this is dependent on the implementation of random
@ -1675,7 +1664,7 @@ class TestWithDask:
X, y, _ = generate_array() X, y, _ = generate_array()
n_partitions = X.npartitions n_partitions = X.npartitions
m = xgb.dask.DaskDMatrix(client, X, y) m = xgb.dask.DaskDMatrix(client, X, y)
workers = _get_client_workers(client) workers = tm.get_client_workers(client)
rabit_args = client.sync( rabit_args = client.sync(
xgb.dask._get_rabit_args, len(workers), None, client xgb.dask._get_rabit_args, len(workers), None, client
) )
@ -1717,7 +1706,9 @@ class TestWithDask:
from sklearn.datasets import load_digits from sklearn.datasets import load_digits
X, y = load_digits(return_X_y=True) X, y = load_digits(return_X_y=True)
X, y = dd.from_array(X, chunksize=32), dd.from_array(y, chunksize=32) X, y = dd.from_array(X, chunksize=32), dd.from_array(y, chunksize=32)
run_data_initialization(xgb.dask.DaskDMatrix, xgb.dask.DaskXGBClassifier, X, y) validate_data_initialization(
xgb.dask.DaskDMatrix, xgb.dask.DaskXGBClassifier, X, y
)
def run_shap(self, X: Any, y: Any, params: Dict[str, Any], client: "Client") -> None: def run_shap(self, X: Any, y: Any, params: Dict[str, Any], client: "Client") -> None:
rows = X.shape[0] rows = X.shape[0]
@ -1884,7 +1875,7 @@ def test_parallel_submits(client: "Client") -> None:
from sklearn.datasets import load_digits from sklearn.datasets import load_digits
futures = [] futures = []
workers = _get_client_workers(client) workers = tm.get_client_workers(client)
n_submits = len(workers) n_submits = len(workers)
for i in range(n_submits): for i in range(n_submits):
X_, y_ = load_digits(return_X_y=True) X_, y_ = load_digits(return_X_y=True)
@ -1970,7 +1961,7 @@ def test_parallel_submit_multi_clients() -> None:
with LocalCluster(n_workers=4, dashboard_address=":0") as cluster: with LocalCluster(n_workers=4, dashboard_address=":0") as cluster:
with Client(cluster) as client: with Client(cluster) as client:
workers = _get_client_workers(client) workers = tm.get_client_workers(client)
n_submits = len(workers) n_submits = len(workers)
assert n_submits == 4 assert n_submits == 4

View File

@ -1,4 +1,3 @@
import sys
from typing import List from typing import List
import numpy as np import numpy as np
@ -7,10 +6,7 @@ import pytest
from xgboost import testing as tm from xgboost import testing as tm
if tm.no_spark()["condition"]: pytestmark = [pytest.mark.skipif(**tm.no_spark())]
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
from xgboost.spark.data import ( from xgboost.spark.data import (
_read_csr_matrix_from_unwrapped_spark_vec, _read_csr_matrix_from_unwrapped_spark_vec,

View File

@ -1,7 +1,6 @@
import glob import glob
import logging import logging
import random import random
import sys
import uuid import uuid
import numpy as np import numpy as np
@ -10,10 +9,7 @@ import pytest
import xgboost as xgb import xgboost as xgb
from xgboost import testing as tm from xgboost import testing as tm
if tm.no_spark()["condition"]: pytestmark = [tm.timeout(60), pytest.mark.skipif(**tm.no_spark())]
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
from pyspark.ml import Pipeline, PipelineModel from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.evaluation import BinaryClassificationEvaluator from pyspark.ml.evaluation import BinaryClassificationEvaluator
@ -37,8 +33,6 @@ from .utils import SparkTestCase
logging.getLogger("py4j").setLevel(logging.INFO) logging.getLogger("py4j").setLevel(logging.INFO)
pytestmark = tm.timeout(60)
class XgboostLocalTest(SparkTestCase): class XgboostLocalTest(SparkTestCase):
def setUp(self): def setUp(self):

View File

@ -9,10 +9,7 @@ import pytest
from xgboost import testing as tm from xgboost import testing as tm
if tm.no_spark()["condition"]: pytestmark = pytest.mark.skipif(**tm.no_spark())
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
from pyspark.ml.linalg import Vectors from pyspark.ml.linalg import Vectors
from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor

View File

@ -10,10 +10,8 @@ from six import StringIO
from xgboost import testing as tm from xgboost import testing as tm
if tm.no_spark()["condition"]: pytestmark = [pytest.mark.skipif(**tm.no_spark())]
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
from pyspark.sql import SparkSession, SQLContext from pyspark.sql import SparkSession, SQLContext
from xgboost.spark.utils import _get_default_params_from_func from xgboost.spark.utils import _get_default_params_from_func