[EM] Add basic distributed GPU tests. (#10861)

- Split the Hist and Approx tests in the unit tests.
- Add basic GPU tests for distributed training.
Jiaming Yuan 2024-10-01 01:28:43 +08:00 committed by GitHub
parent 92f1c48a22
commit 9ecb7583e9
4 changed files with 90 additions and 136 deletions

View File

@@ -1,14 +1,16 @@
"""Tests for dask shared by different test modules."""
from typing import Literal
from typing import List, Literal, cast
import numpy as np
import pandas as pd
from dask import array as da
from dask import dataframe as dd
from distributed import Client
from distributed import Client, get_worker
import xgboost as xgb
import xgboost.testing as tm
from xgboost.compat import concat
from xgboost.testing.updater import get_basescore
@@ -91,3 +93,76 @@ def check_uneven_nan(
            dd.from_pandas(X, npartitions=n_workers),
            dd.from_pandas(y, npartitions=n_workers),
        )


def check_external_memory(  # pylint: disable=too-many-locals
    worker_id: int,
    n_workers: int,
    device: str,
    comm_args: dict,
    is_qdm: bool,
) -> None:
    """Basic checks for distributed external memory."""
    n_samples_per_batch = 32
    n_features = 4
    n_batches = 16
    use_cupy = device != "cpu"

    n_threads = get_worker().state.nthreads
    with xgb.collective.CommunicatorContext(dmlc_communicator="rabit", **comm_args):
        it = tm.IteratorForTest(
            *tm.make_batches(
                n_samples_per_batch,
                n_features,
                n_batches,
                use_cupy=use_cupy,
                random_state=worker_id,
            ),
            cache="cache",
        )
        if is_qdm:
            Xy: xgb.DMatrix = xgb.ExtMemQuantileDMatrix(it, nthread=n_threads)
        else:
            Xy = xgb.DMatrix(it, nthread=n_threads)

        results: xgb.callback.TrainingCallback.EvalsLog = {}
        xgb.train(
            {"tree_method": "hist", "nthread": n_threads, "device": device},
            Xy,
            evals=[(Xy, "Train")],
            num_boost_round=32,
            evals_result=results,
        )
        assert tm.non_increasing(cast(List[float], results["Train"]["rmse"]))

    lx, ly, lw = [], [], []
    for i in range(n_workers):
        x, y, w = tm.make_batches(
            n_samples_per_batch,
            n_features,
            n_batches,
            use_cupy=use_cupy,
            random_state=i,
        )
        lx.extend(x)
        ly.extend(y)
        lw.extend(w)

    X = concat(lx)
    yconcat = concat(ly)
    wconcat = concat(lw)
    if is_qdm:
        Xy = xgb.QuantileDMatrix(X, yconcat, weight=wconcat, nthread=n_threads)
    else:
        Xy = xgb.DMatrix(X, yconcat, weight=wconcat, nthread=n_threads)

    results_local: xgb.callback.TrainingCallback.EvalsLog = {}
    xgb.train(
        {"tree_method": "hist", "nthread": n_threads, "device": device},
        Xy,
        evals=[(Xy, "Train")],
        num_boost_round=32,
        evals_result=results_local,
    )
    np.testing.assert_allclose(
        results["Train"]["rmse"], results_local["Train"]["rmse"], rtol=1e-4
    )
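
(Editor's note, not part of the diff: tm.IteratorForTest implements XGBoost's external-memory DataIter protocol: next() pushes one batch at a time through the callback XGBoost supplies, and reset() rewinds for the next pass; the iterator is then handed to DMatrix. Below is a minimal standalone sketch of that public API; the BatchIter class, the random data, and every parameter value are illustrative assumptions, not code from this commit.)

import os
import tempfile

import numpy as np
import xgboost as xgb


class BatchIter(xgb.DataIter):
    """Feed pre-generated (X, y) batches to XGBoost one at a time."""

    def __init__(self, batches, cache_dir):
        self._batches = batches
        self._i = 0
        # cache_prefix tells XGBoost where to spill the external-memory cache.
        super().__init__(cache_prefix=os.path.join(cache_dir, "cache"))

    def next(self, input_data):
        if self._i == len(self._batches):
            return False  # no batches left, stop this pass
        X, y = self._batches[self._i]
        input_data(data=X, label=y)  # hand the current batch to XGBoost
        self._i += 1
        return True

    def reset(self):
        self._i = 0  # called before each new pass over the data


rng = np.random.default_rng(0)
batches = [(rng.normal(size=(32, 4)), rng.normal(size=32)) for _ in range(4)]
with tempfile.TemporaryDirectory() as tmp:
    Xy = xgb.DMatrix(BatchIter(batches, tmp), nthread=2)
    xgb.train({"tree_method": "hist"}, Xy, num_boost_round=4)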

View File

@@ -318,55 +318,4 @@ TEST_F(MGPUHistTest, HistColumnSplit) {
  this->DoTest([&] { VerifyHistColumnSplit(kRows, kCols, expected_tree); }, true);
  this->DoTest([&] { VerifyHistColumnSplit(kRows, kCols, expected_tree); }, false);
}

namespace {
RegTree GetApproxTree(Context const* ctx, DMatrix* dmat) {
  ObjInfo task{ObjInfo::kRegression};
  std::unique_ptr<TreeUpdater> approx_maker{TreeUpdater::Create("grow_gpu_approx", ctx, &task)};
  approx_maker->Configure(Args{});

  TrainParam param;
  param.UpdateAllowUnknown(Args{});

  linalg::Matrix<GradientPair> gpair({dmat->Info().num_row_}, ctx->Device());
  gpair.Data()->Copy(GenerateRandomGradients(dmat->Info().num_row_));

  std::vector<HostDeviceVector<bst_node_t>> position(1);
  RegTree tree;
  approx_maker->Update(&param, &gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
                       {&tree});
  return tree;
}

void VerifyApproxColumnSplit(bst_idx_t rows, bst_feature_t cols, RegTree const& expected_tree) {
  auto ctx = MakeCUDACtx(DistGpuIdx());

  auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true);
  auto const world_size = collective::GetWorldSize();
  auto const rank = collective::GetRank();
  std::unique_ptr<DMatrix> sliced{Xy->SliceCol(world_size, rank)};

  RegTree tree = GetApproxTree(&ctx, sliced.get());

  Json json{Object{}};
  tree.SaveModel(&json);
  Json expected_json{Object{}};
  expected_tree.SaveModel(&expected_json);
  ASSERT_EQ(json, expected_json);
}
}  // anonymous namespace

class MGPUApproxTest : public collective::BaseMGPUTest {};

TEST_F(MGPUApproxTest, GPUApproxColumnSplit) {
  auto constexpr kRows = 32;
  auto constexpr kCols = 16;

  Context ctx(MakeCUDACtx(0));
  auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
  RegTree expected_tree = GetApproxTree(&ctx, dmat.get());

  this->DoTest([&] { VerifyApproxColumnSplit(kRows, kCols, expected_tree); }, true);
  this->DoTest([&] { VerifyApproxColumnSplit(kRows, kCols, expected_tree); }, false);
}
}  // namespace xgboost::tree
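
(Editor's note, not part of the diff: the GPU Approx column-split test above trains a reference tree on the full data, then has each worker train on its column slice under the collective and asserts that the serialized trees are equal. A rough single-process Python analogue of that compare-serialized-models assertion follows; everything below is illustrative, assumes deterministic hist training, and is not code from this repository.)

import json

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X, y = rng.normal(size=(32, 16)), rng.normal(size=32)

# Train twice with identical data and settings; single-threaded hist should be deterministic.
params = {"tree_method": "hist", "max_depth": 3, "nthread": 1, "seed": 0}
a = xgb.train(params, xgb.DMatrix(X, y), num_boost_round=1)
b = xgb.train(params, xgb.DMatrix(X, y), num_boost_round=1)

# Compare whole serialized models, analogous to ASSERT_EQ on the saved Json trees.
assert json.loads(a.save_raw("json")) == json.loads(b.save_raw("json"))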

View File

@@ -1,77 +1,18 @@
from typing import List, cast
"""Copyright 2024, XGBoost contributors"""
import numpy as np
from distributed import Client, Scheduler, Worker, get_worker
import pytest
from distributed import Client, Scheduler, Worker
from distributed.utils_test import gen_cluster
import xgboost as xgb
from xgboost import testing as tm
from xgboost.compat import concat

def run_external_memory(worker_id: int, n_workers: int, comm_args: dict) -> None:
    n_samples_per_batch = 32
    n_features = 4
    n_batches = 16
    use_cupy = False

    n_threads = get_worker().state.nthreads
    with xgb.collective.CommunicatorContext(dmlc_communicator="rabit", **comm_args):
        it = tm.IteratorForTest(
            *tm.make_batches(
                n_samples_per_batch,
                n_features,
                n_batches,
                use_cupy,
                random_state=worker_id,
            ),
            cache="cache",
        )
        Xy = xgb.DMatrix(it, nthread=n_threads)
        results: xgb.callback.TrainingCallback.EvalsLog = {}
        booster = xgb.train(
            {"tree_method": "hist", "nthread": n_threads},
            Xy,
            evals=[(Xy, "Train")],
            num_boost_round=32,
            evals_result=results,
        )
        assert tm.non_increasing(cast(List[float], results["Train"]["rmse"]))

    lx, ly, lw = [], [], []
    for i in range(n_workers):
        x, y, w = tm.make_batches(
            n_samples_per_batch,
            n_features,
            n_batches,
            use_cupy,
            random_state=i,
        )
        lx.extend(x)
        ly.extend(y)
        lw.extend(w)

    X = concat(lx)
    yconcat = concat(ly)
    wconcat = concat(lw)
    Xy = xgb.DMatrix(X, yconcat, weight=wconcat, nthread=n_threads)

    results_local: xgb.callback.TrainingCallback.EvalsLog = {}
    booster = xgb.train(
        {"tree_method": "hist", "nthread": n_threads},
        Xy,
        evals=[(Xy, "Train")],
        num_boost_round=32,
        evals_result=results_local,
    )
    np.testing.assert_allclose(
        results["Train"]["rmse"], results_local["Train"]["rmse"], rtol=1e-4
    )

from xgboost.testing.dask import check_external_memory
@pytest.mark.parametrize("is_qdm", [True, False])
@gen_cluster(client=True)
async def test_external_memory(
    client: Client, s: Scheduler, a: Worker, b: Worker
    client: Client, s: Scheduler, a: Worker, b: Worker, is_qdm: bool
) -> None:
    workers = tm.get_client_workers(client)
    args = await client.sync(
@@ -83,6 +24,11 @@ async def test_external_memory(
    n_workers = len(workers)
    futs = client.map(
        run_external_memory, range(n_workers), n_workers=n_workers, comm_args=args
        check_external_memory,
        range(n_workers),
        n_workers=n_workers,
        device="cpu",
        comm_args=args,
        is_qdm=is_qdm,
    )
    await client.gather(futs)
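
(Editor's note, not part of the diff: stacking @pytest.mark.parametrize on top of @gen_cluster(client=True) works because gen_cluster supplies the cluster objects first and forwards the remaining arguments, so the parametrized value lands after client, s, a, b in the signature, exactly as in the new test above. A toy sketch of the same stacking, independent of XGBoost; the test name and body are made up.)

import pytest
from distributed import Client, Scheduler, Worker
from distributed.utils_test import gen_cluster


@pytest.mark.parametrize("flag", [True, False])
@gen_cluster(client=True)
async def test_flag_roundtrip(
    client: Client, s: Scheduler, a: Worker, b: Worker, flag: bool
) -> None:
    # The parametrized value is forwarded unchanged after the injected cluster objects.
    assert await client.submit(bool, flag) is flag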

View File

@@ -7,24 +7,9 @@ import pickle
import socket
import tempfile
from concurrent.futures import ThreadPoolExecutor
from copy import copy
from functools import partial
from itertools import starmap
from math import ceil
from operator import attrgetter, getitem
from pathlib import Path
from typing import (
    Any,
    Dict,
    Generator,
    List,
    Literal,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
)
from typing import Any, Dict, Generator, Literal, Optional, Tuple, Type, Union
import hypothesis
import numpy as np
@@ -37,7 +22,6 @@ from sklearn.datasets import make_classification, make_regression
import xgboost as xgb
from xgboost import dask as dxgb
from xgboost import testing as tm
from xgboost.data import _is_cudf_df
from xgboost.testing.params import hist_cache_strategy, hist_parameter_strategy
from xgboost.testing.shared import (
    get_feature_weights,