[EM] Add basic distributed GPU tests. (#10861)

- Split Hist and Approx tests in unittests.
- Basic GPU tests for distributed.
This commit is contained in:
Jiaming Yuan
2024-10-01 01:28:43 +08:00
committed by GitHub
parent 92f1c48a22
commit 9ecb7583e9
4 changed files with 90 additions and 136 deletions

View File

@@ -1,14 +1,16 @@
"""Tests for dask shared by different test modules."""
from typing import Literal
from typing import List, Literal, cast
import numpy as np
import pandas as pd
from dask import array as da
from dask import dataframe as dd
from distributed import Client
from distributed import Client, get_worker
import xgboost as xgb
import xgboost.testing as tm
from xgboost.compat import concat
from xgboost.testing.updater import get_basescore
@@ -91,3 +93,76 @@ def check_uneven_nan(
dd.from_pandas(X, npartitions=n_workers),
dd.from_pandas(y, npartitions=n_workers),
)
def check_external_memory( # pylint: disable=too-many-locals
worker_id: int,
n_workers: int,
device: str,
comm_args: dict,
is_qdm: bool,
) -> None:
"""Basic checks for distributed external memory."""
n_samples_per_batch = 32
n_features = 4
n_batches = 16
use_cupy = device != "cpu"
n_threads = get_worker().state.nthreads
with xgb.collective.CommunicatorContext(dmlc_communicator="rabit", **comm_args):
it = tm.IteratorForTest(
*tm.make_batches(
n_samples_per_batch,
n_features,
n_batches,
use_cupy=use_cupy,
random_state=worker_id,
),
cache="cache",
)
if is_qdm:
Xy: xgb.DMatrix = xgb.ExtMemQuantileDMatrix(it, nthread=n_threads)
else:
Xy = xgb.DMatrix(it, nthread=n_threads)
results: xgb.callback.TrainingCallback.EvalsLog = {}
xgb.train(
{"tree_method": "hist", "nthread": n_threads, "device": device},
Xy,
evals=[(Xy, "Train")],
num_boost_round=32,
evals_result=results,
)
assert tm.non_increasing(cast(List[float], results["Train"]["rmse"]))
lx, ly, lw = [], [], []
for i in range(n_workers):
x, y, w = tm.make_batches(
n_samples_per_batch,
n_features,
n_batches,
use_cupy=use_cupy,
random_state=i,
)
lx.extend(x)
ly.extend(y)
lw.extend(w)
X = concat(lx)
yconcat = concat(ly)
wconcat = concat(lw)
if is_qdm:
Xy = xgb.QuantileDMatrix(X, yconcat, weight=wconcat, nthread=n_threads)
else:
Xy = xgb.DMatrix(X, yconcat, weight=wconcat, nthread=n_threads)
results_local: xgb.callback.TrainingCallback.EvalsLog = {}
xgb.train(
{"tree_method": "hist", "nthread": n_threads, "device": device},
Xy,
evals=[(Xy, "Train")],
num_boost_round=32,
evals_result=results_local,
)
np.testing.assert_allclose(
results["Train"]["rmse"], results_local["Train"]["rmse"], rtol=1e-4
)