[EM] Add basic distributed GPU tests. (#10861)
- Split Hist and Approx tests in unittests.
- Basic GPU tests for distributed.
parent 92f1c48a22
commit 9ecb7583e9
@@ -1,14 +1,16 @@
 """Tests for dask shared by different test modules."""

-from typing import Literal
+from typing import List, Literal, cast

 import numpy as np
 import pandas as pd
 from dask import array as da
 from dask import dataframe as dd
-from distributed import Client
+from distributed import Client, get_worker

 import xgboost as xgb
+import xgboost.testing as tm
+from xgboost.compat import concat
 from xgboost.testing.updater import get_basescore

@@ -91,3 +93,76 @@ def check_uneven_nan(
         dd.from_pandas(X, npartitions=n_workers),
         dd.from_pandas(y, npartitions=n_workers),
     )
+
+
+def check_external_memory(  # pylint: disable=too-many-locals
+    worker_id: int,
+    n_workers: int,
+    device: str,
+    comm_args: dict,
+    is_qdm: bool,
+) -> None:
+    """Basic checks for distributed external memory."""
+    n_samples_per_batch = 32
+    n_features = 4
+    n_batches = 16
+    use_cupy = device != "cpu"
+
+    n_threads = get_worker().state.nthreads
+    with xgb.collective.CommunicatorContext(dmlc_communicator="rabit", **comm_args):
+        it = tm.IteratorForTest(
+            *tm.make_batches(
+                n_samples_per_batch,
+                n_features,
+                n_batches,
+                use_cupy=use_cupy,
+                random_state=worker_id,
+            ),
+            cache="cache",
+        )
+        if is_qdm:
+            Xy: xgb.DMatrix = xgb.ExtMemQuantileDMatrix(it, nthread=n_threads)
+        else:
+            Xy = xgb.DMatrix(it, nthread=n_threads)
+        results: xgb.callback.TrainingCallback.EvalsLog = {}
+        xgb.train(
+            {"tree_method": "hist", "nthread": n_threads, "device": device},
+            Xy,
+            evals=[(Xy, "Train")],
+            num_boost_round=32,
+            evals_result=results,
+        )
+        assert tm.non_increasing(cast(List[float], results["Train"]["rmse"]))
+
+    lx, ly, lw = [], [], []
+    for i in range(n_workers):
+        x, y, w = tm.make_batches(
+            n_samples_per_batch,
+            n_features,
+            n_batches,
+            use_cupy=use_cupy,
+            random_state=i,
+        )
+        lx.extend(x)
+        ly.extend(y)
+        lw.extend(w)
+
+    X = concat(lx)
+    yconcat = concat(ly)
+    wconcat = concat(lw)
+    if is_qdm:
+        Xy = xgb.QuantileDMatrix(X, yconcat, weight=wconcat, nthread=n_threads)
+    else:
+        Xy = xgb.DMatrix(X, yconcat, weight=wconcat, nthread=n_threads)
+
+    results_local: xgb.callback.TrainingCallback.EvalsLog = {}
+    xgb.train(
+        {"tree_method": "hist", "nthread": n_threads, "device": device},
+        Xy,
+        evals=[(Xy, "Train")],
+        num_boost_round=32,
+        evals_result=results_local,
+    )
+    np.testing.assert_allclose(
+        results["Train"]["rmse"], results_local["Train"]["rmse"], rtol=1e-4
+    )
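The new check_external_memory helper runs one training job per worker through XGBoost's external-memory interface, with tm.IteratorForTest paging synthetic batches through an on-disk cache, then reproduces the same training single-node on the concatenated batches and compares the RMSE curves. The iterator half of this is the public xgb.DataIter protocol; below is a minimal single-process sketch of that protocol (illustrative names, not code from this commit):

    import numpy as np
    import xgboost as xgb

    class BatchIter(xgb.DataIter):
        """Feed pre-built (X, y) batches through the external-memory protocol."""

        def __init__(self, batches):
            self._batches = batches
            self._it = 0
            super().__init__(cache_prefix="cache")  # spill pages to ./cache*

        def next(self, input_data) -> bool:
            if self._it == len(self._batches):
                return False  # end of one pass over the data
            X, y = self._batches[self._it]
            input_data(data=X, label=y)  # hand the current batch to XGBoost
            self._it += 1
            return True

        def reset(self) -> None:
            self._it = 0  # rewind; called before each new pass

    rng = np.random.default_rng(0)
    batches = [(rng.normal(size=(32, 4)), rng.normal(size=32)) for _ in range(16)]
    Xy = xgb.DMatrix(BatchIter(batches))  # pages the batches through the cache
    booster = xgb.train({"tree_method": "hist"}, Xy, num_boost_round=4)

The is_qdm flag in the helper switches between a plain DMatrix and ExtMemQuantileDMatrix, which builds pre-quantized pages for the hist method instead of caching raw batches.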
@@ -318,55 +318,4 @@ TEST_F(MGPUHistTest, HistColumnSplit) {
   this->DoTest([&] { VerifyHistColumnSplit(kRows, kCols, expected_tree); }, true);
   this->DoTest([&] { VerifyHistColumnSplit(kRows, kCols, expected_tree); }, false);
 }
-
-namespace {
-RegTree GetApproxTree(Context const* ctx, DMatrix* dmat) {
-  ObjInfo task{ObjInfo::kRegression};
-  std::unique_ptr<TreeUpdater> approx_maker{TreeUpdater::Create("grow_gpu_approx", ctx, &task)};
-  approx_maker->Configure(Args{});
-
-  TrainParam param;
-  param.UpdateAllowUnknown(Args{});
-
-  linalg::Matrix<GradientPair> gpair({dmat->Info().num_row_}, ctx->Device());
-  gpair.Data()->Copy(GenerateRandomGradients(dmat->Info().num_row_));
-
-  std::vector<HostDeviceVector<bst_node_t>> position(1);
-  RegTree tree;
-  approx_maker->Update(&param, &gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
-                       {&tree});
-  return tree;
-}
-
-void VerifyApproxColumnSplit(bst_idx_t rows, bst_feature_t cols, RegTree const& expected_tree) {
-  auto ctx = MakeCUDACtx(DistGpuIdx());
-
-  auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true);
-  auto const world_size = collective::GetWorldSize();
-  auto const rank = collective::GetRank();
-  std::unique_ptr<DMatrix> sliced{Xy->SliceCol(world_size, rank)};
-
-  RegTree tree = GetApproxTree(&ctx, sliced.get());
-
-  Json json{Object{}};
-  tree.SaveModel(&json);
-  Json expected_json{Object{}};
-  expected_tree.SaveModel(&expected_json);
-  ASSERT_EQ(json, expected_json);
-}
-}  // anonymous namespace
-
-class MGPUApproxTest : public collective::BaseMGPUTest {};
-
-TEST_F(MGPUApproxTest, GPUApproxColumnSplit) {
-  auto constexpr kRows = 32;
-  auto constexpr kCols = 16;
-
-  Context ctx(MakeCUDACtx(0));
-  auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
-  RegTree expected_tree = GetApproxTree(&ctx, dmat.get());
-
-  this->DoTest([&] { VerifyApproxColumnSplit(kRows, kCols, expected_tree); }, true);
-  this->DoTest([&] { VerifyApproxColumnSplit(kRows, kCols, expected_tree); }, false);
-}
 }  // namespace xgboost::tree
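The block deleted here is the approx half of the split named in the commit message: the GetApproxTree and VerifyApproxColumnSplit fixtures and the MGPUApproxTest case move out of the hist unit test into their own test. For orientation, grow_gpu_approx is the updater selected from Python when tree_method="approx" is paired with a CUDA device; a minimal sketch assuming a GPU is available (not code from this commit):

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X, y = rng.normal(size=(32, 16)), rng.normal(size=32)

    booster = xgb.train(
        # tree_method="approx" with device="cuda" routes to the grow_gpu_approx updater.
        {"tree_method": "approx", "device": "cuda"},
        xgb.DMatrix(X, y),
        num_boost_round=4,
    )
    print(booster.get_dump()[0])  # text dump of the first tree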
@@ -1,77 +1,18 @@
-from typing import List, cast
+"""Copyright 2024, XGBoost contributors"""

-import numpy as np
-from distributed import Client, Scheduler, Worker, get_worker
+import pytest
+from distributed import Client, Scheduler, Worker
 from distributed.utils_test import gen_cluster

 import xgboost as xgb
 from xgboost import testing as tm
-from xgboost.compat import concat
+from xgboost.testing.dask import check_external_memory


-def run_external_memory(worker_id: int, n_workers: int, comm_args: dict) -> None:
-    n_samples_per_batch = 32
-    n_features = 4
-    n_batches = 16
-    use_cupy = False
-
-    n_threads = get_worker().state.nthreads
-    with xgb.collective.CommunicatorContext(dmlc_communicator="rabit", **comm_args):
-        it = tm.IteratorForTest(
-            *tm.make_batches(
-                n_samples_per_batch,
-                n_features,
-                n_batches,
-                use_cupy,
-                random_state=worker_id,
-            ),
-            cache="cache",
-        )
-        Xy = xgb.DMatrix(it, nthread=n_threads)
-        results: xgb.callback.TrainingCallback.EvalsLog = {}
-        booster = xgb.train(
-            {"tree_method": "hist", "nthread": n_threads},
-            Xy,
-            evals=[(Xy, "Train")],
-            num_boost_round=32,
-            evals_result=results,
-        )
-        assert tm.non_increasing(cast(List[float], results["Train"]["rmse"]))
-
-    lx, ly, lw = [], [], []
-    for i in range(n_workers):
-        x, y, w = tm.make_batches(
-            n_samples_per_batch,
-            n_features,
-            n_batches,
-            use_cupy,
-            random_state=i,
-        )
-        lx.extend(x)
-        ly.extend(y)
-        lw.extend(w)
-
-    X = concat(lx)
-    yconcat = concat(ly)
-    wconcat = concat(lw)
-    Xy = xgb.DMatrix(X, yconcat, weight=wconcat, nthread=n_threads)
-
-    results_local: xgb.callback.TrainingCallback.EvalsLog = {}
-    booster = xgb.train(
-        {"tree_method": "hist", "nthread": n_threads},
-        Xy,
-        evals=[(Xy, "Train")],
-        num_boost_round=32,
-        evals_result=results_local,
-    )
-    np.testing.assert_allclose(
-        results["Train"]["rmse"], results_local["Train"]["rmse"], rtol=1e-4
-    )
-
-
+@pytest.mark.parametrize("is_qdm", [True, False])
 @gen_cluster(client=True)
 async def test_external_memory(
-    client: Client, s: Scheduler, a: Worker, b: Worker
+    client: Client, s: Scheduler, a: Worker, b: Worker, is_qdm: bool
 ) -> None:
     workers = tm.get_client_workers(client)
     args = await client.sync(

@@ -83,6 +24,11 @@ async def test_external_memory(
     n_workers = len(workers)

     futs = client.map(
-        run_external_memory, range(n_workers), n_workers=n_workers, comm_args=args
+        check_external_memory,
+        range(n_workers),
+        n_workers=n_workers,
+        device="cpu",
+        comm_args=args,
+        is_qdm=is_qdm,
     )
     await client.gather(futs)
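The refactored test drives the shared check with device="cpu"; the distributed GPU tests this commit adds (not visible in this excerpt) would exercise the same helper with a CUDA device. Outside of @gen_cluster, the comm_args dict each worker passes to CommunicatorContext can be produced by starting a tracker by hand. A sketch assuming two local GPU workers; the dask_cuda cluster and the exact tracker wiring are assumptions, not part of this diff:

    from dask_cuda import LocalCUDACluster  # assumption: dask-cuda is installed
    from distributed import Client

    from xgboost.testing.dask import check_external_memory
    from xgboost.tracker import RabitTracker

    with LocalCUDACluster(n_workers=2) as cluster, Client(cluster) as client:
        n_workers = len(client.scheduler_info()["workers"])

        # Start a RABIT tracker and hand its connection args to every worker.
        tracker = RabitTracker(host_ip="127.0.0.1", n_workers=n_workers)
        tracker.start()
        comm_args = tracker.worker_args()

        futs = client.map(
            check_external_memory,
            range(n_workers),
            n_workers=n_workers,
            device="cuda",  # the GPU variant of the check above
            comm_args=comm_args,
            is_qdm=True,
        )
        client.gather(futs)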
@@ -7,24 +7,9 @@ import pickle
 import socket
 import tempfile
 from concurrent.futures import ThreadPoolExecutor
-from copy import copy
 from functools import partial
-from itertools import starmap
-from math import ceil
-from operator import attrgetter, getitem
 from pathlib import Path
-from typing import (
-    Any,
-    Dict,
-    Generator,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-    Type,
-    TypeVar,
-    Union,
-)
+from typing import Any, Dict, Generator, Literal, Optional, Tuple, Type, Union

 import hypothesis
 import numpy as np

@@ -37,7 +22,6 @@ from sklearn.datasets import make_classification, make_regression
 import xgboost as xgb
 from xgboost import dask as dxgb
 from xgboost import testing as tm
-from xgboost.data import _is_cudf_df
 from xgboost.testing.params import hist_cache_strategy, hist_parameter_strategy
 from xgboost.testing.shared import (
     get_feature_weights,