From 00264eb72bfe1d5dfc8df00563eaa1c08e6dc15d Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Sat, 6 Jul 2024 01:15:20 +0800
Subject: [PATCH] [EM] Basic distributed test for external memory. (#10492)

---
 python-package/xgboost/testing/__init__.py    |  7 +-
 tests/ci_build/lint_python.py                 |  1 +
 .../test_with_dask/test_external_memory.py    | 88 +++++++++++++++++++
 3 files changed, 93 insertions(+), 3 deletions(-)
 create mode 100644 tests/test_distributed/test_with_dask/test_external_memory.py

diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py
index 482da68c9..e0096c89c 100644
--- a/python-package/xgboost/testing/__init__.py
+++ b/python-package/xgboost/testing/__init__.py
@@ -248,13 +248,14 @@ class IteratorForTest(xgb.core.DataIter):
         return X, y, w
 
 
-def make_batches(
+def make_batches(  # pylint: disable=too-many-arguments,too-many-locals
     n_samples_per_batch: int,
     n_features: int,
     n_batches: int,
     use_cupy: bool = False,
     *,
     vary_size: bool = False,
+    random_state: int = 1994,
 ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
     X = []
     y = []
@@ -262,9 +263,9 @@ def make_batches(
     if use_cupy:
         import cupy
 
-        rng = cupy.random.RandomState(1994)
+        rng = cupy.random.RandomState(random_state)
     else:
-        rng = np.random.RandomState(1994)
+        rng = np.random.RandomState(random_state)
     for i in range(n_batches):
         n_samples = n_samples_per_batch + i * 10 if vary_size else n_samples_per_batch
         _X = rng.randn(n_samples, n_features)
diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py
index 079996de6..f8bbbc284 100644
--- a/tests/ci_build/lint_python.py
+++ b/tests/ci_build/lint_python.py
@@ -98,6 +98,7 @@ class LintersPaths:
         "tests/python/test_model_io.py",
         "tests/test_distributed/test_federated/",
         "tests/test_distributed/test_gpu_federated/",
+        "tests/test_distributed/test_with_dask/test_external_memory.py",
         "tests/test_distributed/test_with_spark/test_data.py",
         "tests/test_distributed/test_gpu_with_spark/test_data.py",
         "tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py",
diff --git a/tests/test_distributed/test_with_dask/test_external_memory.py b/tests/test_distributed/test_with_dask/test_external_memory.py
new file mode 100644
index 000000000..cf475d90f
--- /dev/null
+++ b/tests/test_distributed/test_with_dask/test_external_memory.py
@@ -0,0 +1,88 @@
+from typing import List, cast
+
+import numpy as np
+from distributed import Client, Scheduler, Worker, get_worker
+from distributed.utils_test import gen_cluster
+
+import xgboost as xgb
+from xgboost import testing as tm
+from xgboost.compat import concat
+
+
+def run_external_memory(worker_id: int, n_workers: int, comm_args: dict) -> None:
+    n_samples_per_batch = 32
+    n_features = 4
+    n_batches = 16
+    use_cupy = False
+
+    n_threads = get_worker().state.nthreads
+    with xgb.collective.CommunicatorContext(dmlc_communicator="rabit", **comm_args):
+        it = tm.IteratorForTest(
+            *tm.make_batches(
+                n_samples_per_batch,
+                n_features,
+                n_batches,
+                use_cupy,
+                random_state=worker_id,
+            ),
+            cache="cache",
+        )
+        Xy = xgb.DMatrix(it, nthread=n_threads)
+        results: xgb.callback.TrainingCallback.EvalsLog = {}
+        booster = xgb.train(
+            {"tree_method": "hist", "nthread": n_threads},
+            Xy,
+            evals=[(Xy, "Train")],
+            num_boost_round=32,
+            evals_result=results,
+        )
+        assert tm.non_increasing(cast(List[float], results["Train"]["rmse"]))
+
+    lx, ly, lw = [], [], []
+    for i in range(n_workers):
+        x, y, w = tm.make_batches(
+            n_samples_per_batch,
+            n_features,
+            n_batches,
+            use_cupy,
+            random_state=i,
+        )
+        lx.extend(x)
+        ly.extend(y)
+        lw.extend(w)
+
+    X = concat(lx)
+    yconcat = concat(ly)
+    wconcat = concat(lw)
+    Xy = xgb.DMatrix(X, yconcat, wconcat, nthread=n_threads)
+
+    results_local: xgb.callback.TrainingCallback.EvalsLog = {}
+    booster = xgb.train(
+        {"tree_method": "hist", "nthread": n_threads},
+        Xy,
+        evals=[(Xy, "Train")],
+        num_boost_round=32,
+        evals_result=results_local,
+    )
+    np.testing.assert_allclose(
+        results["Train"]["rmse"], results_local["Train"]["rmse"], rtol=1e-4
+    )
+
+
+@gen_cluster(client=True)
+async def test_external_memory(
+    client: Client, s: Scheduler, a: Worker, b: Worker
+) -> None:
+    workers = tm.get_client_workers(client)
+    args = await client.sync(
+        xgb.dask._get_rabit_args,
+        len(workers),
+        None,
+        client,
+    )
+    n_workers = len(workers)
+
+    futs = client.map(
+        run_external_memory, range(n_workers), n_workers=n_workers, comm_args=args
+    )
+    await client.gather(futs)