Extract dask and spark test into distributed test. (#8395)
- Move test files. - Run spark and dask separately to prevent conflicts. - Gather common code into the testing module.
This commit is contained in:
@@ -7,6 +7,7 @@ import pytest
|
||||
import xgboost as xgb
|
||||
from xgboost import RabitTracker
|
||||
from xgboost import testing as tm
|
||||
from xgboost import collective
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||
@@ -21,12 +22,9 @@ def test_rabit_tracker():
|
||||
|
||||
|
||||
def run_rabit_ops(client, n_workers):
|
||||
from test_with_dask import _get_client_workers
|
||||
from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args
|
||||
|
||||
from xgboost import collective
|
||||
|
||||
workers = _get_client_workers(client)
|
||||
workers = tm.get_client_workers(client)
|
||||
rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
|
||||
assert not collective.is_distributed()
|
||||
n_workers_from_dask = len(workers)
|
||||
@@ -76,7 +74,6 @@ def test_rabit_ops_ipv6():
|
||||
|
||||
def test_rank_assignment() -> None:
|
||||
from distributed import Client, LocalCluster
|
||||
from test_with_dask import _get_client_workers
|
||||
|
||||
def local_test(worker_id):
|
||||
with xgb.dask.CommunicatorContext(**args) as ctx:
|
||||
@@ -89,7 +86,7 @@ def test_rank_assignment() -> None:
|
||||
|
||||
with LocalCluster(n_workers=8) as cluster:
|
||||
with Client(cluster) as client:
|
||||
workers = _get_client_workers(client)
|
||||
workers = tm.get_client_workers(client)
|
||||
args = client.sync(
|
||||
xgb.dask._get_rabit_args,
|
||||
len(workers),
|
||||
|
||||
Reference in New Issue
Block a user