Reduce time for some multi-gpu tests (#8288)

* Faster dask tests

* Reuse AllReducer objects in tests.

* Faster boost from prediction tests.

* Use rmm dask fixture.

* Speed up dask demo.

* mypy

* Format with black.

* mypy

* Clang-tidy

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
Rory Mitchell 2022-10-04 12:49:33 +02:00 committed by GitHub
parent ca0547bb65
commit d686bf52a6
8 changed files with 337 additions and 336 deletions

demo/dask/gpu_training.py

@@ -4,13 +4,12 @@ Example of training with Dask on GPU
"""
from dask_cuda import LocalCUDACluster
import dask_cudf
from dask.distributed import Client, wait
from dask.distributed import Client
from dask import array as da
from dask import dataframe as dd
import xgboost as xgb
from xgboost import dask as dxgb
from xgboost.dask import DaskDMatrix
import argparse
def using_dask_matrix(client: Client, X, y):
@@ -51,7 +50,7 @@ def using_quantile_device_dmatrix(client: Client, X, y):
# `DaskDeviceQuantileDMatrix` is used instead of `DaskDMatrix`, be careful
# that it can not be used for anything else other than training.
dtrain = dxgb.DaskDeviceQuantileDMatrix(client, X, y)
dtrain = dxgb.DaskQuantileDMatrix(client, X, y)
output = xgb.dask.train(client,
{'verbosity': 2,
'tree_method': 'gpu_hist'},
@@ -63,12 +62,6 @@ def using_quantile_device_dmatrix(client: Client, X, y):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--ddqdm', choices=[0, 1], type=int, default=1,
help='''Whether should we use `DaskDeviceQuantileDMatrix`''')
args = parser.parse_args()
# `LocalCUDACluster` is used for assigning GPU to XGBoost processes. Here
# `n_workers` represents the number of GPUs since we use one GPU per worker
# process.
@@ -77,12 +70,10 @@ if __name__ == '__main__':
# generate some random data for demonstration
m = 100000
n = 100
X = da.random.random(size=(m, n), chunks=100)
y = da.random.random(size=(m, ), chunks=100)
X = da.random.random(size=(m, n), chunks=10000)
y = da.random.random(size=(m, ), chunks=10000)
if args.ddqdm == 1:
print('Using DaskDeviceQuantileDMatrix')
from_ddqdm = using_quantile_device_dmatrix(client, X, y)
else:
print('Using DMatrix')
from_dmatrix = using_dask_matrix(client, X, y)
print('Using DaskQuantileDMatrix')
from_ddqdm = using_quantile_device_dmatrix(client, X, y)
print('Using DMatrix')
from_dmatrix = using_dask_matrix(client, X, y)
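For reference, a condensed sketch of the demo flow after this change, assuming xgboost >= 1.7 where `DaskQuantileDMatrix` is available. The `main` wrapper and the cupy conversion are illustrative; the demo itself moves data to device through dask_cudf:

    from dask import array as da
    from dask.distributed import Client
    from dask_cuda import LocalCUDACluster
    from xgboost import dask as dxgb

    def main(client: Client) -> None:
        import cupy as cp
        # Larger chunks (10000 rows instead of 100) keep the partition count,
        # and hence the scheduling overhead, low for this small demo dataset.
        X = da.random.random(size=(100_000, 100), chunks=10_000)
        y = da.random.random(size=(100_000,), chunks=10_000)
        X, y = X.map_blocks(cp.asarray), y.map_blocks(cp.asarray)
        dtrain = dxgb.DaskQuantileDMatrix(client, X, y)
        output = dxgb.train(
            client,
            {"verbosity": 2, "tree_method": "gpu_hist"},
            dtrain,
            num_boost_round=4,
            evals=[(dtrain, "train")],
        )
        print(output["history"])

    if __name__ == "__main__":
        # One worker process per GPU.
        with LocalCUDACluster() as cluster, Client(cluster) as client:
            main(client)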

src/common/quantile.cu

@@ -508,7 +508,7 @@ void SketchContainer::AllReduce() {
timer_.Start(__func__);
if (!reducer_) {
reducer_ = std::make_unique<dh::AllReducer>();
reducer_ = std::make_shared<dh::AllReducer>();
reducer_->Init(device_);
}
// Reduce the overhead on syncing.
@@ -518,6 +518,7 @@ void SketchContainer::AllReduce() {
std::min(global_sum_rows, static_cast<size_t>(num_bins_ * kFactor));
this->Prune(intermediate_num_cuts);
auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
CHECK_EQ(d_columns_ptr.size(), num_columns_ + 1);
size_t n = d_columns_ptr.size();

src/common/quantile.cuh

@@ -37,7 +37,7 @@ class SketchContainer {
private:
Monitor timer_;
std::unique_ptr<dh::AllReducer> reducer_;
std::shared_ptr<dh::AllReducer> reducer_;
HostDeviceVector<FeatureType> feature_types_;
bst_row_t num_rows_;
bst_feature_t num_columns_;
@@ -93,35 +93,37 @@ class SketchContainer {
* \param num_columns Total number of columns in dataset.
* \param num_rows Total number of rows in known dataset (typically the rows in current worker).
* \param device GPU ID.
* \param reducer Optional initialised reducer. Useful for speeding up testing.
*/
SketchContainer(HostDeviceVector<FeatureType> const& feature_types,
int32_t max_bin,
bst_feature_t num_columns, bst_row_t num_rows,
int32_t device)
: num_rows_{num_rows},
num_columns_{num_columns}, num_bins_{max_bin}, device_{device} {
CHECK_GE(device, 0);
// Initialize Sketches for this dmatrix
this->columns_ptr_.SetDevice(device_);
this->columns_ptr_.Resize(num_columns + 1);
this->columns_ptr_b_.SetDevice(device_);
this->columns_ptr_b_.Resize(num_columns + 1);
SketchContainer(HostDeviceVector<FeatureType> const &feature_types,
int32_t max_bin, bst_feature_t num_columns,
bst_row_t num_rows, int32_t device,
std::shared_ptr<dh::AllReducer> reducer = nullptr)
: num_rows_{num_rows},
num_columns_{num_columns}, num_bins_{max_bin}, device_{device},
reducer_(std::move(reducer)) {
CHECK_GE(device, 0);
// Initialize Sketches for this dmatrix
this->columns_ptr_.SetDevice(device_);
this->columns_ptr_.Resize(num_columns + 1);
this->columns_ptr_b_.SetDevice(device_);
this->columns_ptr_b_.Resize(num_columns + 1);
this->feature_types_.Resize(feature_types.Size());
this->feature_types_.Copy(feature_types);
// Pull to device.
this->feature_types_.SetDevice(device);
this->feature_types_.ConstDeviceSpan();
this->feature_types_.ConstHostSpan();
this->feature_types_.Resize(feature_types.Size());
this->feature_types_.Copy(feature_types);
// Pull to device.
this->feature_types_.SetDevice(device);
this->feature_types_.ConstDeviceSpan();
this->feature_types_.ConstHostSpan();
auto d_feature_types = feature_types_.ConstDeviceSpan();
has_categorical_ =
!d_feature_types.empty() &&
thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
common::IsCatOp{});
auto d_feature_types = feature_types_.ConstDeviceSpan();
has_categorical_ =
!d_feature_types.empty() &&
thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
common::IsCatOp{});
timer_.Init(__func__);
}
timer_.Init(__func__);
}
/* \brief Return GPU ID for this container. */
int32_t DeviceIdx() const { return device_; }
/* \brief Whether the predictor matrix contains categorical features. */

tests/cpp/common/test_quantile.cu

@@ -349,6 +349,9 @@ TEST(GPUQuantile, AllReduceBasic) {
return;
}
auto reducer = std::make_shared<dh::AllReducer>();
reducer->Init(0);
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
// Set up single node version;
@@ -378,12 +381,12 @@ TEST(GPUQuantile, AllReduceBasic) {
}
sketch_on_single_node.Unique();
TestQuantileElemRank(0, sketch_on_single_node.Data(),
sketch_on_single_node.ColumnsPtr());
sketch_on_single_node.ColumnsPtr(), true);
// Set up distributed version. We rely on using rank as seed to generate
// the exact same copy of data.
auto rank = rabit::GetRank();
SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0);
SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0, reducer);
HostDeviceVector<float> storage;
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
.Device(0)
@@ -402,7 +405,7 @@ TEST(GPUQuantile, AllReduceBasic) {
sketch_on_single_node.Data().size());
TestQuantileElemRank(0, sketch_distributed.Data(),
sketch_distributed.ColumnsPtr());
sketch_distributed.ColumnsPtr(), true);
std::vector<SketchEntry> single_node_data(
sketch_on_single_node.Data().size());
@@ -432,13 +435,15 @@ TEST(GPUQuantile, SameOnAllWorkers) {
} else {
return;
}
auto reducer = std::make_shared<dh::AllReducer>();
reducer->Init(0);
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins,
MetaInfo const &info) {
auto rank = rabit::GetRank();
HostDeviceVector<FeatureType> ft;
SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0);
SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0, reducer);
HostDeviceVector<float> storage;
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
.Device(0)
@@ -450,7 +455,7 @@ TEST(GPUQuantile, SameOnAllWorkers) {
&sketch_distributed);
sketch_distributed.AllReduce();
sketch_distributed.Unique();
TestQuantileElemRank(0, sketch_distributed.Data(), sketch_distributed.ColumnsPtr());
TestQuantileElemRank(0, sketch_distributed.Data(), sketch_distributed.ColumnsPtr(), true);
// Test for all workers having the same sketch.
size_t n_data = sketch_distributed.Data().size();
@@ -467,12 +472,9 @@ TEST(GPUQuantile, SameOnAllWorkers) {
thrust::copy(thrust::device, local_data.data(),
local_data.data() + local_data.size(),
all_workers.begin() + local_data.size() * rank);
dh::AllReducer reducer;
reducer.Init(0);
reducer.AllReduceSum(all_workers.data().get(), all_workers.data().get(),
reducer->AllReduceSum(all_workers.data().get(), all_workers.data().get(),
all_workers.size());
reducer.Synchronize();
reducer->Synchronize();
auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float);
std::vector<float> h_base_line(base_line.size());

tests/cpp/common/test_quantile.h

@@ -37,12 +37,12 @@ inline void InitRabitContext(std::string msg, int32_t n_workers) {
}
template <typename Fn> void RunWithSeedsAndBins(size_t rows, Fn fn) {
std::vector<int32_t> seeds(4);
std::vector<int32_t> seeds(2);
SimpleLCG lcg;
SimpleRealUniformDistribution<float> dist(3, 1000);
std::generate(seeds.begin(), seeds.end(), [&](){ return dist(&lcg); });
std::vector<size_t> bins(8);
std::vector<size_t> bins(2);
for (size_t i = 0; i < bins.size() - 1; ++i) {
bins[i] = i * 35 + 2;
}

tests/python-gpu/conftest.py

@@ -22,8 +22,8 @@ def setup_rmm_pool(request, pytestconfig):
rmm.reinitialize(pool_allocator=True, initial_pool_size=1024*1024*1024,
devices=list(range(get_n_gpus())))
@pytest.fixture(scope='function')
def local_cuda_cluster(request, pytestconfig):
@pytest.fixture(scope='class')
def local_cuda_client(request, pytestconfig):
kwargs = {}
if hasattr(request, 'param'):
kwargs.update(request.param)
@@ -31,13 +31,12 @@ def local_cuda_cluster(request, pytestconfig):
if not has_rmm():
raise ImportError('The --use-rmm-pool option requires the RMM package')
import rmm
from dask_cuda.utils import get_n_gpus
kwargs['rmm_pool_size'] = '2GB'
if tm.no_dask_cuda()['condition']:
raise ImportError('The local_cuda_cluster fixture requires dask_cuda package')
from dask_cuda import LocalCUDACluster
with LocalCUDACluster(**kwargs) as cluster:
yield cluster
from dask.distributed import Client
yield Client(LocalCUDACluster(**kwargs))
def pytest_addoption(parser):
parser.addoption('--use-rmm-pool', action='store_true', default=False, help='Use RMM pool')
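The fixture change above is the heart of the speed-up: the Dask CUDA cluster is created once per test class instead of once per test function. A minimal sketch of the idea, assuming the option handling and the dask_cuda/rmm checks from conftest.py stay as before; the context managers here are one way to make sure the cluster is torn down after the last test in the class:

    import pytest
    from dask.distributed import Client
    from dask_cuda import LocalCUDACluster

    @pytest.fixture(scope="class")
    def local_cuda_client(request):
        kwargs = {}
        if hasattr(request, "param"):
            # Lets tests parametrize the cluster, e.g. {"n_workers": 2}.
            kwargs.update(request.param)
        with LocalCUDACluster(**kwargs) as cluster, Client(cluster) as client:
            yield client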

tests/python-gpu/test_gpu_demos.py

@@ -32,8 +32,5 @@ def test_categorical_demo():
@pytest.mark.mgpu
def test_dask_training():
script = os.path.join(tm.PROJECT_ROOT, 'demo', 'dask', 'gpu_training.py')
cmd = ['python', script, '--ddqdm=1']
subprocess.check_call(cmd)
cmd = ['python', script, '--ddqdm=0']
cmd = ['python', script]
subprocess.check_call(cmd)

tests/python-gpu/test_gpu_with_dask.py

@@ -17,26 +17,26 @@ if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
sys.path.append("tests/python")
import testing as tm # noqa
import testing as tm # noqa
if tm.no_dask_cuda()["condition"]:
pytest.skip(tm.no_dask_cuda()["reason"], allow_module_level=True)
from test_with_dask import run_empty_dmatrix_reg # noqa
from test_with_dask import run_empty_dmatrix_auc # noqa
from test_with_dask import run_auc # noqa
from test_with_dask import run_empty_dmatrix_reg # noqa
from test_with_dask import run_empty_dmatrix_auc # noqa
from test_with_dask import run_auc # noqa
from test_with_dask import run_boost_from_prediction # noqa
from test_with_dask import run_boost_from_prediction_multi_class # noqa
from test_with_dask import run_dask_classifier # noqa
from test_with_dask import run_empty_dmatrix_cls # noqa
from test_with_dask import _get_client_workers # noqa
from test_with_dask import generate_array # noqa
from test_with_dask import kCols as random_cols # noqa
from test_with_dask import suppress # noqa
from test_with_dask import run_tree_stats # noqa
from test_with_dask import run_categorical # noqa
from test_with_dask import make_categorical # noqa
from test_with_dask import run_dask_classifier # noqa
from test_with_dask import run_empty_dmatrix_cls # noqa
from test_with_dask import _get_client_workers # noqa
from test_with_dask import generate_array # noqa
from test_with_dask import kCols as random_cols # noqa
from test_with_dask import suppress # noqa
from test_with_dask import run_tree_stats # noqa
from test_with_dask import run_categorical # noqa
from test_with_dask import make_categorical # noqa
try:
@@ -45,7 +45,7 @@ try:
import xgboost as xgb
from dask.distributed import Client
from dask import array as da
from dask_cuda import LocalCUDACluster
from dask_cuda import LocalCUDACluster, utils
import cudf
except ImportError:
pass
@@ -53,6 +53,7 @@ except ImportError:
def run_with_dask_dataframe(DMatrixT: Type, client: Client) -> None:
import cupy as cp
cp.cuda.runtime.setDevice(0)
X, y, _ = generate_array()
@@ -63,14 +64,16 @@ def run_with_dask_dataframe(DMatrixT: Type, client: Client) -> None:
y = y.map_partitions(cudf.from_pandas)
dtrain = DMatrixT(client, X, y)
out = dxgb.train(client, {'tree_method': 'gpu_hist',
'debug_synchronize': True},
dtrain=dtrain,
evals=[(dtrain, 'X')],
num_boost_round=4)
out = dxgb.train(
client,
{"tree_method": "gpu_hist", "debug_synchronize": True},
dtrain=dtrain,
evals=[(dtrain, "X")],
num_boost_round=4,
)
assert isinstance(out['booster'], dxgb.Booster)
assert len(out['history']['X']['rmse']) == 4
assert isinstance(out["booster"], dxgb.Booster)
assert len(out["history"]["X"]["rmse"]) == 4
predictions = dxgb.predict(client, out, dtrain)
assert isinstance(predictions.compute(), np.ndarray)
@@ -78,27 +81,23 @@ def run_with_dask_dataframe(DMatrixT: Type, client: Client) -> None:
series_predictions = dxgb.inplace_predict(client, out, X)
assert isinstance(series_predictions, dd.Series)
single_node = out['booster'].predict(xgboost.DMatrix(X.compute()))
single_node = out["booster"].predict(xgboost.DMatrix(X.compute()))
cp.testing.assert_allclose(single_node, predictions.compute())
np.testing.assert_allclose(single_node,
series_predictions.compute().to_numpy())
np.testing.assert_allclose(single_node, series_predictions.compute().to_numpy())
predt = dxgb.predict(client, out, X)
assert isinstance(predt, dd.Series)
T = TypeVar('T')
T = TypeVar("T")
def is_df(part: T) -> T:
assert isinstance(part, cudf.DataFrame), part
return part
predt.map_partitions(
is_df,
meta=dd.utils.make_meta({'prediction': 'f4'}))
predt.map_partitions(is_df, meta=dd.utils.make_meta({"prediction": "f4"}))
cp.testing.assert_allclose(
predt.values.compute(), single_node)
cp.testing.assert_allclose(predt.values.compute(), single_node)
# Make sure the output can be integrated back to original dataframe
X["predict"] = predictions
@@ -110,49 +109,35 @@ def run_with_dask_dataframe(DMatrixT: Type, client: Client) -> None:
def run_with_dask_array(DMatrixT: Type, client: Client) -> None:
import cupy as cp
cp.cuda.runtime.setDevice(0)
X, y, _ = generate_array()
X = X.map_blocks(cp.asarray)
y = y.map_blocks(cp.asarray)
dtrain = DMatrixT(client, X, y)
out = dxgb.train(client, {'tree_method': 'gpu_hist',
'debug_synchronize': True},
dtrain=dtrain,
evals=[(dtrain, 'X')],
num_boost_round=2)
out = dxgb.train(
client,
{"tree_method": "gpu_hist", "debug_synchronize": True},
dtrain=dtrain,
evals=[(dtrain, "X")],
num_boost_round=2,
)
from_dmatrix = dxgb.predict(client, out, dtrain).compute()
inplace_predictions = dxgb.inplace_predict(
client, out, X).compute()
single_node = out['booster'].predict(
xgboost.DMatrix(X.compute()))
inplace_predictions = dxgb.inplace_predict(client, out, X).compute()
single_node = out["booster"].predict(xgboost.DMatrix(X.compute()))
np.testing.assert_allclose(single_node, from_dmatrix)
device = cp.cuda.runtime.getDevice()
assert device == inplace_predictions.device.id
single_node = cp.array(single_node)
assert device == single_node.device.id
cp.testing.assert_allclose(
single_node,
inplace_predictions)
@pytest.mark.skipif(**tm.no_dask_cudf())
def test_categorical(local_cuda_cluster: LocalCUDACluster) -> None:
with Client(local_cuda_cluster) as client:
import dask_cudf
X, y = make_categorical(client, 10000, 30, 13)
X = dask_cudf.from_dask_dataframe(X)
X_onehot, _ = make_categorical(client, 10000, 30, 13, True)
X_onehot = dask_cudf.from_dask_dataframe(X_onehot)
run_categorical(client, "gpu_hist", X, X_onehot, y)
cp.testing.assert_allclose(single_node, inplace_predictions)
def to_cp(x: Any, DMatrixT: Type) -> Any:
import cupy
if isinstance(x, np.ndarray) and \
DMatrixT is dxgb.DaskDeviceQuantileDMatrix:
if isinstance(x, np.ndarray) and DMatrixT is dxgb.DaskDeviceQuantileDMatrix:
X = cupy.array(x)
else:
X = x
@@ -213,217 +198,250 @@ def run_gpu_hist(
assert tm.non_increasing(history)
@pytest.mark.skipif(**tm.no_cudf())
def test_boost_from_prediction(local_cuda_cluster: LocalCUDACluster) -> None:
import cudf
from sklearn.datasets import load_breast_cancer, load_digits
with Client(local_cuda_cluster) as client:
X_, y_ = load_breast_cancer(return_X_y=True)
X = dd.from_array(X_, chunksize=100).map_partitions(cudf.from_pandas)
y = dd.from_array(y_, chunksize=100).map_partitions(cudf.from_pandas)
run_boost_from_prediction(X, y, "gpu_hist", client)
def test_tree_stats() -> None:
with LocalCUDACluster(n_workers=1) as cluster:
with Client(cluster) as client:
local = run_tree_stats(client, "gpu_hist")
X_, y_ = load_digits(return_X_y=True)
X = dd.from_array(X_, chunksize=100).map_partitions(cudf.from_pandas)
y = dd.from_array(y_, chunksize=100).map_partitions(cudf.from_pandas)
run_boost_from_prediction_multi_class(X, y, "gpu_hist", client)
with LocalCUDACluster(n_workers=2) as cluster:
with Client(cluster) as client:
distributed = run_tree_stats(client, "gpu_hist")
assert local == distributed
class TestDistributedGPU:
@pytest.mark.skipif(**tm.no_cudf())
def test_boost_from_prediction(self, local_cuda_client: Client) -> None:
import cudf
from sklearn.datasets import load_breast_cancer, load_iris
X_, y_ = load_breast_cancer(return_X_y=True)
X = dd.from_array(X_, chunksize=100).map_partitions(cudf.from_pandas)
y = dd.from_array(y_, chunksize=100).map_partitions(cudf.from_pandas)
run_boost_from_prediction(X, y, "gpu_hist", local_cuda_client)
X_, y_ = load_iris(return_X_y=True)
X = dd.from_array(X_, chunksize=50).map_partitions(cudf.from_pandas)
y = dd.from_array(y_, chunksize=50).map_partitions(cudf.from_pandas)
run_boost_from_prediction_multi_class(X, y, "gpu_hist", local_cuda_client)
@pytest.mark.skipif(**tm.no_dask_cudf())
def test_dask_dataframe(self, local_cuda_cluster: LocalCUDACluster) -> None:
with Client(local_cuda_cluster) as client:
run_with_dask_dataframe(dxgb.DaskDMatrix, client)
run_with_dask_dataframe(dxgb.DaskDeviceQuantileDMatrix, client)
def test_dask_dataframe(self, local_cuda_client: Client) -> None:
run_with_dask_dataframe(dxgb.DaskDMatrix, local_cuda_client)
run_with_dask_dataframe(dxgb.DaskDeviceQuantileDMatrix, local_cuda_client)
@pytest.mark.skipif(**tm.no_dask_cudf())
def test_categorical(self, local_cuda_client: Client) -> None:
import dask_cudf
X, y = make_categorical(local_cuda_client, 10000, 30, 13)
X = dask_cudf.from_dask_dataframe(X)
X_onehot, _ = make_categorical(local_cuda_client, 10000, 30, 13, True)
X_onehot = dask_cudf.from_dask_dataframe(X_onehot)
run_categorical(local_cuda_client, "gpu_hist", X, X_onehot, y)
@given(
params=parameter_strategy,
num_rounds=strategies.integers(1, 20),
dataset=tm.dataset_strategy,
dmatrix_type=strategies.sampled_from(
[dxgb.DaskDMatrix, dxgb.DaskDeviceQuantileDMatrix]
),
)
@settings(
deadline=duration(seconds=120),
max_examples=20,
suppress_health_check=suppress,
print_blob=True,
)
@settings(deadline=duration(seconds=120), suppress_health_check=suppress, print_blob=True)
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.parametrize(
"local_cuda_cluster", [{"n_workers": 2}], indirect=["local_cuda_cluster"]
)
def test_gpu_hist(
self,
params: Dict,
num_rounds: int,
dataset: tm.TestDataset,
local_cuda_cluster: LocalCUDACluster,
dmatrix_type: type,
local_cuda_client: Client,
) -> None:
with Client(local_cuda_cluster) as client:
run_gpu_hist(params, num_rounds, dataset, dxgb.DaskDMatrix, client)
run_gpu_hist(
params, num_rounds, dataset, dxgb.DaskDeviceQuantileDMatrix, client
)
run_gpu_hist(params, num_rounds, dataset, dmatrix_type, local_cuda_client)
@pytest.mark.skipif(**tm.no_cupy())
def test_dask_array(self, local_cuda_cluster: LocalCUDACluster) -> None:
with Client(local_cuda_cluster) as client:
run_with_dask_array(dxgb.DaskDMatrix, client)
run_with_dask_array(dxgb.DaskDeviceQuantileDMatrix, client)
def test_dask_array(self, local_cuda_client: Client) -> None:
run_with_dask_array(dxgb.DaskDMatrix, local_cuda_client)
run_with_dask_array(dxgb.DaskDeviceQuantileDMatrix, local_cuda_client)
@pytest.mark.skipif(**tm.no_cupy())
def test_early_stopping(self, local_cuda_cluster: LocalCUDACluster) -> None:
def test_early_stopping(self, local_cuda_client: Client) -> None:
from sklearn.datasets import load_breast_cancer
with Client(local_cuda_cluster) as client:
X, y = load_breast_cancer(return_X_y=True)
X, y = da.from_array(X), da.from_array(y)
m = dxgb.DaskDMatrix(client, X, y)
X, y = load_breast_cancer(return_X_y=True)
X, y = da.from_array(X), da.from_array(y)
valid = dxgb.DaskDMatrix(client, X, y)
early_stopping_rounds = 5
booster = dxgb.train(client, {'objective': 'binary:logistic',
'eval_metric': 'error',
'tree_method': 'gpu_hist'}, m,
evals=[(valid, 'Valid')],
num_boost_round=1000,
early_stopping_rounds=early_stopping_rounds)[
'booster']
assert hasattr(booster, 'best_score')
dump = booster.get_dump(dump_format='json')
print(booster.best_iteration)
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
m = dxgb.DaskDMatrix(local_cuda_client, X, y)
valid_X = X
valid_y = y
cls = dxgb.DaskXGBClassifier(objective='binary:logistic',
tree_method='gpu_hist',
eval_metric='error',
n_estimators=100)
cls.client = client
cls.fit(X, y, early_stopping_rounds=early_stopping_rounds,
eval_set=[(valid_X, valid_y)])
booster = cls.get_booster()
dump = booster.get_dump(dump_format='json')
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
valid = dxgb.DaskDMatrix(local_cuda_client, X, y)
early_stopping_rounds = 5
booster = dxgb.train(
local_cuda_client,
{
"objective": "binary:logistic",
"eval_metric": "error",
"tree_method": "gpu_hist",
},
m,
evals=[(valid, "Valid")],
num_boost_round=1000,
early_stopping_rounds=early_stopping_rounds,
)["booster"]
assert hasattr(booster, "best_score")
dump = booster.get_dump(dump_format="json")
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
valid_X = X
valid_y = y
cls = dxgb.DaskXGBClassifier(
objective="binary:logistic",
tree_method="gpu_hist",
eval_metric="error",
n_estimators=100,
)
cls.client = local_cuda_client
cls.fit(
X,
y,
early_stopping_rounds=early_stopping_rounds,
eval_set=[(valid_X, valid_y)],
)
booster = cls.get_booster()
dump = booster.get_dump(dump_format="json")
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
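The early-stopping test above boils down to the following standalone sketch; synthetic data replaces the scikit-learn datasets used by the real test, and `check_early_stopping` is an illustrative name, not a suite helper:

    from dask import array as da
    from dask.distributed import Client
    from xgboost import dask as dxgb

    def check_early_stopping(client: Client) -> None:
        rng = da.random.RandomState(1994)
        X = rng.random_sample((10_000, 20), chunks=1_000)
        y = (rng.random_sample(10_000, chunks=1_000) > 0.5).astype("int32")
        dtrain = dxgb.DaskDMatrix(client, X, y)
        output = dxgb.train(
            client,
            {
                "objective": "binary:logistic",
                "eval_metric": "error",
                "tree_method": "gpu_hist",
            },
            dtrain,
            evals=[(dtrain, "Valid")],
            num_boost_round=1000,
            early_stopping_rounds=5,
        )
        booster = output["booster"]
        # When early stopping kicks in, best_score/best_iteration are set on the
        # booster; the real test additionally checks that the number of dumped
        # trees equals best_iteration + early_stopping_rounds + 1.
        assert hasattr(booster, "best_score")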
@pytest.mark.skipif(**tm.no_cudf())
@pytest.mark.parametrize("model", ["boosting"])
def test_dask_classifier(
self, model: str, local_cuda_cluster: LocalCUDACluster
) -> None:
def test_dask_classifier(self, model: str, local_cuda_client: Client) -> None:
import dask_cudf
with Client(local_cuda_cluster) as client:
X_, y_, w_ = generate_array(with_weights=True)
y_ = (y_ * 10).astype(np.int32)
X = dask_cudf.from_dask_dataframe(dd.from_dask_array(X_))
y = dask_cudf.from_dask_dataframe(dd.from_dask_array(y_))
w = dask_cudf.from_dask_dataframe(dd.from_dask_array(w_))
run_dask_classifier(X, y, w, model, "gpu_hist", client, 10)
def test_empty_dmatrix(self, local_cuda_cluster: LocalCUDACluster) -> None:
with Client(local_cuda_cluster) as client:
parameters = {'tree_method': 'gpu_hist', 'debug_synchronize': True}
run_empty_dmatrix_reg(client, parameters)
run_empty_dmatrix_cls(client, parameters)
X_, y_, w_ = generate_array(with_weights=True)
y_ = (y_ * 10).astype(np.int32)
X = dask_cudf.from_dask_dataframe(dd.from_dask_array(X_))
y = dask_cudf.from_dask_dataframe(dd.from_dask_array(y_))
w = dask_cudf.from_dask_dataframe(dd.from_dask_array(w_))
run_dask_classifier(X, y, w, model, "gpu_hist", local_cuda_client, 10)
def test_empty_dmatrix(self, local_cuda_client: Client) -> None:
parameters = {"tree_method": "gpu_hist", "debug_synchronize": True}
run_empty_dmatrix_reg(local_cuda_client, parameters)
run_empty_dmatrix_cls(local_cuda_client, parameters)
@pytest.mark.skipif(**tm.no_dask_cudf())
def test_empty_partition(self, local_cuda_cluster: LocalCUDACluster) -> None:
def test_empty_partition(self, local_cuda_client: Client) -> None:
import dask_cudf
import cudf
import cupy
with Client(local_cuda_cluster) as client:
mult = 100
df = cudf.DataFrame(
{
"a": [1, 2, 3, 4, 5.1] * mult,
"b": [10, 15, 29.3, 30, 31] * mult,
"y": [10, 20, 30, 40., 50] * mult,
}
)
parameters = {"tree_method": "gpu_hist", "debug_synchronize": True}
empty = df.iloc[:0]
ddf = dask_cudf.concat(
[dask_cudf.from_cudf(empty, npartitions=1)]
+ [dask_cudf.from_cudf(df, npartitions=3)]
+ [dask_cudf.from_cudf(df, npartitions=3)]
)
X = ddf[ddf.columns.difference(["y"])]
y = ddf[["y"]]
dtrain = dxgb.DaskDeviceQuantileDMatrix(client, X, y)
bst_empty = xgb.dask.train(
client, parameters, dtrain, evals=[(dtrain, "train")]
)
predt_empty = dxgb.predict(client, bst_empty, X).compute().values
mult = 100
df = cudf.DataFrame(
{
"a": [1, 2, 3, 4, 5.1] * mult,
"b": [10, 15, 29.3, 30, 31] * mult,
"y": [10, 20, 30, 40.0, 50] * mult,
}
)
parameters = {"tree_method": "gpu_hist", "debug_synchronize": True}
ddf = dask_cudf.concat(
[dask_cudf.from_cudf(df, npartitions=3)]
+ [dask_cudf.from_cudf(df, npartitions=3)]
)
X = ddf[ddf.columns.difference(["y"])]
y = ddf[["y"]]
dtrain = dxgb.DaskDeviceQuantileDMatrix(client, X, y)
bst = xgb.dask.train(client, parameters, dtrain, evals=[(dtrain, "train")])
empty = df.iloc[:0]
ddf = dask_cudf.concat(
[dask_cudf.from_cudf(empty, npartitions=1)]
+ [dask_cudf.from_cudf(df, npartitions=3)]
+ [dask_cudf.from_cudf(df, npartitions=3)]
)
X = ddf[ddf.columns.difference(["y"])]
y = ddf[["y"]]
dtrain = dxgb.DaskDeviceQuantileDMatrix(local_cuda_client, X, y)
bst_empty = xgb.dask.train(
local_cuda_client, parameters, dtrain, evals=[(dtrain, "train")]
)
predt_empty = dxgb.predict(local_cuda_client, bst_empty, X).compute().values
predt = dxgb.predict(client, bst, X).compute().values
cupy.testing.assert_allclose(predt, predt_empty)
ddf = dask_cudf.concat(
[dask_cudf.from_cudf(df, npartitions=3)]
+ [dask_cudf.from_cudf(df, npartitions=3)]
)
X = ddf[ddf.columns.difference(["y"])]
y = ddf[["y"]]
dtrain = dxgb.DaskDeviceQuantileDMatrix(local_cuda_client, X, y)
bst = xgb.dask.train(
local_cuda_client, parameters, dtrain, evals=[(dtrain, "train")]
)
predt = dxgb.predict(client, bst, dtrain).compute()
cupy.testing.assert_allclose(predt, predt_empty)
predt = dxgb.predict(local_cuda_client, bst, X).compute().values
cupy.testing.assert_allclose(predt, predt_empty)
predt = dxgb.inplace_predict(client, bst, X).compute().values
cupy.testing.assert_allclose(predt, predt_empty)
predt = dxgb.predict(local_cuda_client, bst, dtrain).compute()
cupy.testing.assert_allclose(predt, predt_empty)
df = df.to_pandas()
empty = df.iloc[:0]
ddf = dd.concat(
[dd.from_pandas(empty, npartitions=1)]
+ [dd.from_pandas(df, npartitions=3)]
+ [dd.from_pandas(df, npartitions=3)]
)
X = ddf[ddf.columns.difference(["y"])]
y = ddf[["y"]]
predt = dxgb.inplace_predict(local_cuda_client, bst, X).compute().values
cupy.testing.assert_allclose(predt, predt_empty)
predt_empty = cupy.asnumpy(predt_empty)
df = df.to_pandas()
empty = df.iloc[:0]
ddf = dd.concat(
[dd.from_pandas(empty, npartitions=1)]
+ [dd.from_pandas(df, npartitions=3)]
+ [dd.from_pandas(df, npartitions=3)]
)
X = ddf[ddf.columns.difference(["y"])]
y = ddf[["y"]]
predt = dxgb.predict(client, bst_empty, X).compute().values
np.testing.assert_allclose(predt, predt_empty)
predt_empty = cupy.asnumpy(predt_empty)
in_predt = dxgb.inplace_predict(client, bst_empty, X).compute().values
np.testing.assert_allclose(predt, in_predt)
predt = dxgb.predict(local_cuda_client, bst_empty, X).compute().values
np.testing.assert_allclose(predt, predt_empty)
def test_empty_dmatrix_auc(self, local_cuda_cluster: LocalCUDACluster) -> None:
with Client(local_cuda_cluster) as client:
n_workers = len(_get_client_workers(client))
run_empty_dmatrix_auc(client, "gpu_hist", n_workers)
in_predt = (
dxgb.inplace_predict(local_cuda_client, bst_empty, X).compute().values
)
np.testing.assert_allclose(predt, in_predt)
def test_auc(self, local_cuda_cluster: LocalCUDACluster) -> None:
with Client(local_cuda_cluster) as client:
run_auc(client, "gpu_hist")
def test_empty_dmatrix_auc(self, local_cuda_client: Client) -> None:
n_workers = len(_get_client_workers(local_cuda_client))
run_empty_dmatrix_auc(local_cuda_client, "gpu_hist", n_workers)
def test_data_initialization(self, local_cuda_cluster: LocalCUDACluster) -> None:
with Client(local_cuda_cluster) as client:
X, y, _ = generate_array()
fw = da.random.random((random_cols, ))
fw = fw - fw.min()
m = dxgb.DaskDMatrix(client, X, y, feature_weights=fw)
def test_auc(self, local_cuda_client: Client) -> None:
run_auc(local_cuda_client, "gpu_hist")
workers = _get_client_workers(client)
rabit_args = client.sync(dxgb._get_rabit_args, len(workers), None, client)
def test_data_initialization(self, local_cuda_client: Client) -> None:
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
with dxgb.RabitContext(rabit_args):
local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref, nthread=7)
fw_rows = local_dtrain.get_float_info("feature_weights").shape[0]
assert fw_rows == local_dtrain.num_col()
X, y, _ = generate_array()
fw = da.random.random((random_cols,))
fw = fw - fw.min()
m = dxgb.DaskDMatrix(local_cuda_client, X, y, feature_weights=fw)
futures = []
for i in range(len(workers)):
futures.append(
client.submit(
worker_fn,
workers[i],
m._create_fn_args(workers[i]),
pure=False,
workers=[workers[i]]
)
workers = _get_client_workers(local_cuda_client)
rabit_args = local_cuda_client.sync(
dxgb._get_rabit_args, len(workers), None, local_cuda_client
)
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
with dxgb.RabitContext(rabit_args):
local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref, nthread=7)
fw_rows = local_dtrain.get_float_info("feature_weights").shape[0]
assert fw_rows == local_dtrain.num_col()
futures = []
for i in range(len(workers)):
futures.append(
local_cuda_client.submit(
worker_fn,
workers[i],
m._create_fn_args(workers[i]),
pure=False,
workers=[workers[i]],
)
client.gather(futures)
)
local_cuda_client.gather(futures)
def test_interface_consistency(self) -> None:
sig = OrderedDict(signature(dxgb.DaskDMatrix).parameters)
@@ -441,7 +459,7 @@ class TestDistributedGPU:
assert ddm_names[i] == ddqdm_names[i]
sig = OrderedDict(signature(xgb.DMatrix).parameters)
del sig["nthread"] # no nthread in dask
del sig["nthread"] # no nthread in dask
dm_names = list(sig.keys())
sig = OrderedDict(signature(xgb.QuantileDMatrix).parameters)
del sig["nthread"]
@@ -470,81 +488,79 @@ class TestDistributedGPU:
for rn, drn in zip(ranker_names, dranker_names):
assert rn == drn
def test_tree_stats(self) -> None:
with LocalCUDACluster(n_workers=1) as cluster:
with Client(cluster) as client:
local = run_tree_stats(client, "gpu_hist")
with LocalCUDACluster(n_workers=2) as cluster:
with Client(cluster) as client:
distributed = run_tree_stats(client, "gpu_hist")
assert local == distributed
def run_quantile(self, name: str, local_cuda_cluster: LocalCUDACluster) -> None:
def run_quantile(self, name: str, local_cuda_client: Client) -> None:
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows")
exe = None
for possible_path in {'./testxgboost', './build/testxgboost',
'../build/testxgboost', '../gpu-build/testxgboost'}:
for possible_path in {
"./testxgboost",
"./build/testxgboost",
"../build/testxgboost",
"../gpu-build/testxgboost",
}:
if os.path.exists(possible_path):
exe = possible_path
assert exe, 'No testxgboost executable found.'
assert exe, "No testxgboost executable found."
test = "--gtest_filter=GPUQuantile." + name
def runit(
worker_addr: str, rabit_args: List[bytes]
) -> subprocess.CompletedProcess:
port_env = ''
port_env = ""
# setup environment for running the c++ part.
for arg in rabit_args:
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
port_env = arg.decode('utf-8')
if arg.decode("utf-8").startswith("DMLC_TRACKER_PORT"):
port_env = arg.decode("utf-8")
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
uri_env = arg.decode("utf-8")
port = port_env.split('=')
port = port_env.split("=")
env = os.environ.copy()
env[port[0]] = port[1]
uri = uri_env.split("=")
env[uri[0]] = uri[1]
return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE)
with Client(local_cuda_cluster) as client:
workers = _get_client_workers(client)
rabit_args = client.sync(dxgb._get_rabit_args, len(workers), None, client)
futures = client.map(runit,
workers,
pure=False,
workers=workers,
rabit_args=rabit_args)
results = client.gather(futures)
for ret in results:
msg = ret.stdout.decode('utf-8')
assert msg.find('1 test from GPUQuantile') != -1, msg
assert ret.returncode == 0, msg
workers = _get_client_workers(local_cuda_client)
rabit_args = local_cuda_client.sync(
dxgb._get_rabit_args, len(workers), None, local_cuda_client
)
futures = local_cuda_client.map(
runit, workers, pure=False, workers=workers, rabit_args=rabit_args
)
results = local_cuda_client.gather(futures)
for ret in results:
msg = ret.stdout.decode("utf-8")
assert msg.find("1 test from GPUQuantile") != -1, msg
assert ret.returncode == 0, msg
@pytest.mark.gtest
def test_quantile_basic(self, local_cuda_cluster: LocalCUDACluster) -> None:
self.run_quantile('AllReduceBasic', local_cuda_cluster)
def test_quantile_basic(self, local_cuda_client: Client) -> None:
self.run_quantile("AllReduceBasic", local_cuda_client)
@pytest.mark.gtest
def test_quantile_same_on_all_workers(
self, local_cuda_cluster: LocalCUDACluster
) -> None:
self.run_quantile('SameOnAllWorkers', local_cuda_cluster)
def test_quantile_same_on_all_workers(self, local_cuda_client: Client) -> None:
self.run_quantile("SameOnAllWorkers", local_cuda_client)
@pytest.mark.skipif(**tm.no_cupy())
def test_with_asyncio(local_cuda_client: Client) -> None:
address = local_cuda_client.scheduler.address
output = asyncio.run(run_from_dask_array_asyncio(address))
assert isinstance(output["booster"], xgboost.Booster)
assert isinstance(output["history"], dict)
async def run_from_dask_array_asyncio(scheduler_address: str) -> dxgb.TrainReturnT:
async with Client(scheduler_address, asynchronous=True) as client:
import cupy as cp
X, y, _ = generate_array()
X = X.map_blocks(cp.array)
y = y.map_blocks(cp.array)
m = await xgboost.dask.DaskDeviceQuantileDMatrix(client, X, y)
output = await xgboost.dask.train(client, {'tree_method': 'gpu_hist'},
dtrain=m)
output = await xgboost.dask.train(client, {"tree_method": "gpu_hist"}, dtrain=m)
with_m = await xgboost.dask.predict(client, output, m)
with_X = await xgboost.dask.predict(client, output, X)
@@ -553,19 +569,12 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> dxgb.TrainReturnT:
assert isinstance(with_X, da.Array)
assert isinstance(inplace, da.Array)
cp.testing.assert_allclose(await client.compute(with_m),
await client.compute(with_X))
cp.testing.assert_allclose(await client.compute(with_m),
await client.compute(inplace))
cp.testing.assert_allclose(
await client.compute(with_m), await client.compute(with_X)
)
cp.testing.assert_allclose(
await client.compute(with_m), await client.compute(inplace)
)
client.shutdown()
return output
@pytest.mark.skipif(**tm.no_cupy())
def test_with_asyncio(local_cuda_cluster: LocalCUDACluster) -> None:
with Client(local_cuda_cluster) as client:
address = client.scheduler.address
output = asyncio.run(run_from_dask_array_asyncio(address))
assert isinstance(output['booster'], xgboost.Booster)
assert isinstance(output['history'], dict)
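Taken together, the refactor of this test file follows the pattern sketched below: every test in the class shares the class-scoped `local_cuda_client` fixture, and the DMatrix implementation becomes a hypothesis strategy instead of being trained twice per example. The class name and the `train_once` helper are illustrative only, and the modern `DaskQuantileDMatrix` name stands in for the `DaskDeviceQuantileDMatrix` used in the suite:

    from dask import array as da
    from dask.distributed import Client
    from hypothesis import given, settings, strategies
    from xgboost import dask as dxgb

    def train_once(client: Client, dmatrix_type: type) -> None:
        import cupy as cp
        # Small synthetic GPU-backed dataset; enough to exercise both DMatrix paths.
        rng = da.random.RandomState(0)
        X = rng.random_sample((1_000, 10), chunks=500).map_blocks(cp.asarray)
        y = rng.random_sample(1_000, chunks=500).map_blocks(cp.asarray)
        dtrain = dmatrix_type(client, X, y)
        out = dxgb.train(
            client, {"tree_method": "gpu_hist"}, dtrain, num_boost_round=2
        )
        assert isinstance(out["booster"], dxgb.Booster)

    class TestDistributedGPUSketch:
        @given(
            dmatrix_type=strategies.sampled_from(
                [dxgb.DaskDMatrix, dxgb.DaskQuantileDMatrix]
            )
        )
        @settings(deadline=None, max_examples=4, print_blob=True)
        def test_gpu_hist(self, dmatrix_type: type, local_cuda_client: Client) -> None:
            # `local_cuda_client` is assumed to come from the class-scoped fixture
            # sketched in conftest.py above, so one CUDA cluster serves all examples.
            train_once(local_cuda_client, dmatrix_type)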