Remove MGPU cpp tests. (#8276)

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan 2022-09-27 21:18:23 +08:00 committed by GitHub
parent fcab51aa82
commit 6d1452074a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 52 additions and 104 deletions

View File

@ -1,11 +1,17 @@
/*! /*!
* Copyright 2018 XGBoost contributors * Copyright 2018-2022 XGBoost contributors
*/ */
#include "common.h" #include "common.h"
namespace xgboost { namespace xgboost {
namespace common { namespace common {
/*!
 * \brief Select the CUDA device with ordinal `device` through cudaSetDevice.
 *
 * A negative ordinal leaves the current device untouched.  The status returned
 * by cudaSetDevice is passed to dh::safe_cuda.
 */
void SetDevice(std::int32_t device) {
  if (device < 0) {
    return;
  }
  dh::safe_cuda(cudaSetDevice(device));
}
int AllVisibleGPUs() { int AllVisibleGPUs() {
int n_visgpus = 0; int n_visgpus = 0;
try { try {

View File

@ -246,6 +246,16 @@ inline void AssertOneAPISupport() {
#endif // XGBOOST_USE_ONEAPI #endif // XGBOOST_USE_ONEAPI
} }
/*!
 * \brief Select the device identified by `device`.
 *
 * Only the fallback below is defined in this header; when XGBOOST_USE_CUDA is
 * defined the definition has to come from another translation unit.
 */
void SetDevice(std::int32_t device);

#if !defined(XGBOOST_USE_CUDA)
/*!
 * \brief Fallback used when XGBOOST_USE_CUDA is not defined.
 *
 * A negative ordinal is accepted silently; any other value calls
 * AssertGPUSupport().
 */
inline void SetDevice(std::int32_t device) {
  if (device < 0) {
    return;
  }
  AssertGPUSupport();
}
#endif  // !defined(XGBOOST_USE_CUDA)
template <typename Idx, typename Container, template <typename Idx, typename Container,
typename V = typename Container::value_type, typename V = typename Container::value_type,
typename Comp = std::less<V>> typename Comp = std::less<V>>

View File

@ -327,6 +327,8 @@ void GenericParameter::ConfigureGpuId(bool require_gpu) {
// Just set it to CPU, don't think about it. // Just set it to CPU, don't think about it.
this->UpdateAllowUnknown(Args{{"gpu_id", std::to_string(kCpuId)}}); this->UpdateAllowUnknown(Args{{"gpu_id", std::to_string(kCpuId)}});
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
common::SetDevice(this->gpu_id);
} }
int32_t GenericParameter::Threads() const { int32_t GenericParameter::Threads() const {

View File

@ -78,7 +78,7 @@ steps:
command: "tests/buildkite/test-cpp-gpu.sh" command: "tests/buildkite/test-cpp-gpu.sh"
key: test-cpp-gpu key: test-cpp-gpu
agents: agents:
queue: linux-amd64-mgpu queue: linux-amd64-gpu
- label: ":console: Run integration tests with JVM packages" - label: ":console: Run integration tests with JVM packages"
command: "tests/buildkite/test-integration-jvm-packages.sh" command: "tests/buildkite/test-integration-jvm-packages.sh"
key: test-integration-jvm-packages key: test-integration-jvm-packages

View File

@ -11,13 +11,14 @@
namespace xgboost { namespace xgboost {
namespace common { namespace common {
namespace {
void SetDevice(int device) { void SetDeviceForTest(int device) {
int n_devices; int n_devices;
dh::safe_cuda(cudaGetDeviceCount(&n_devices)); dh::safe_cuda(cudaGetDeviceCount(&n_devices));
device %= n_devices; device %= n_devices;
dh::safe_cuda(cudaSetDevice(device)); dh::safe_cuda(cudaSetDevice(device));
} }
} // namespace
struct HostDeviceVectorSetDeviceHandler { struct HostDeviceVectorSetDeviceHandler {
template <typename Functor> template <typename Functor>
@ -57,7 +58,7 @@ void InitHostDeviceVector(size_t n, int device, HostDeviceVector<int> *v) {
void PlusOne(HostDeviceVector<int> *v) { void PlusOne(HostDeviceVector<int> *v) {
int device = v->DeviceIdx(); int device = v->DeviceIdx();
SetDevice(device); SetDeviceForTest(device);
thrust::transform(dh::tcbegin(*v), dh::tcend(*v), dh::tbegin(*v), thrust::transform(dh::tcbegin(*v), dh::tcend(*v), dh::tbegin(*v),
[=]__device__(unsigned int a){ return a + 1; }); [=]__device__(unsigned int a){ return a + 1; });
ASSERT_TRUE(v->DeviceCanWrite()); ASSERT_TRUE(v->DeviceCanWrite());
@ -68,7 +69,7 @@ void CheckDevice(HostDeviceVector<int>* v,
unsigned int first, unsigned int first,
GPUAccess access) { GPUAccess access) {
ASSERT_EQ(v->Size(), size); ASSERT_EQ(v->Size(), size);
SetDevice(v->DeviceIdx()); SetDeviceForTest(v->DeviceIdx());
ASSERT_TRUE(thrust::equal(dh::tcbegin(*v), dh::tcend(*v), ASSERT_TRUE(thrust::equal(dh::tcbegin(*v), dh::tcend(*v),
thrust::make_counting_iterator(first))); thrust::make_counting_iterator(first)));
@ -182,16 +183,5 @@ TEST(HostDeviceVector, Empty) {
ASSERT_FALSE(another.Empty()); ASSERT_FALSE(another.Empty());
ASSERT_TRUE(vec.Empty()); ASSERT_TRUE(vec.Empty());
} }
// Runs TestHostDeviceVector with 1001 elements on device ordinal 1.  The test
// returns early with a warning when fewer than two GPUs are visible.
TEST(HostDeviceVector, MGPU_Basic) { // NOLINT
if (AllVisibleGPUs() < 2) {
LOG(WARNING) << "Not testing in multi-gpu environment.";
return;
}
size_t n = 1001;
// Ordinal 1 -- presumably the second GPU, hence the two-GPU guard above.
int device = 1;
TestHostDeviceVector(n, device);
}
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost

View File

@ -1,35 +0,0 @@
/*!
* Copyright 2018-2022 by XGBoost Contributors
* \brief This converts all tests from CPU to GPU.
*/
#include "test_transform_range.cc"
namespace xgboost {
namespace common {
// Evaluates TestTransformRange through Transform<> with device ordinal 1 and
// expects the output to match h_sol, which holds the same 0, 1, 2, ... sequence
// as the input.  The test returns early with a warning when fewer than two GPUs
// are visible.
TEST(Transform, MGPU_SpecifiedGpuId) { // NOLINT
if (AllVisibleGPUs() < 2) {
LOG(WARNING) << "Not testing in multi-gpu environment.";
return;
}
// Use the GPU with ordinal 1; the guard above requires at least two visible
// GPUs before this ordinal is used.
auto device = 1;
auto const size {256};
// Input, output and expected-result buffers on the host.
std::vector<bst_float> h_in(size);
std::vector<bst_float> h_out(size);
std::iota(h_in.begin(), h_in.end(), 0);
std::vector<bst_float> h_sol(size);
std::iota(h_sol.begin(), h_sol.end(), 0);
// Both vectors are constructed with the same device ordinal as the transform.
const HostDeviceVector<bst_float> in_vec {h_in, device};
HostDeviceVector<bst_float> out_vec {h_out, device};
ASSERT_NO_THROW(Transform<>::Init(TestTransformRange<bst_float>{}, Range{0, size},
common::OmpGetNumThreads(0), device)
.Eval(&out_vec, &in_vec));
// Read the result back on the host and compare it element-wise with h_sol.
std::vector<bst_float> res = out_vec.HostVector();
ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin()));
}
} // namespace common
} // namespace xgboost

View File

@ -84,29 +84,3 @@ TEST(Metric, DeclareUnifiedTest(MultiClassLogLoss)) {
TestMultiClassLogLoss(GPUIDX); TestMultiClassLogLoss(GPUIDX);
xgboost::CheckDeterministicMetricMultiClass(xgboost::StringView{"mlogloss"}, GPUIDX); xgboost::CheckDeterministicMetricMultiClass(xgboost::StringView{"mlogloss"}, GPUIDX);
} }
#if defined(__CUDACC__)
namespace xgboost {
namespace common {
// Runs the multi-class error and multi-class log-loss checks with device
// ordinals 0 and 1.  The test returns early with a warning when fewer than two
// GPUs are visible.
TEST(Metric, MGPU_MultiClassError) {
  if (AllVisibleGPUs() < 2) {
    LOG(WARNING) << "Not testing in multi-gpu environment.";
    return;
  }

  TestMultiClassError(0);
  TestMultiClassError(1);

  TestMultiClassLogLoss(0);
  TestMultiClassLogLoss(1);
}
} // namespace common
} // namespace xgboost
#endif // defined(__CUDACC__)

View File

@ -172,7 +172,7 @@ TEST(CpuPredictor, InplacePredict) {
std::string arr_str; std::string arr_str;
Json::Dump(array_interface, &arr_str); Json::Dump(array_interface, &arr_str);
x->SetArrayData(arr_str.data()); x->SetArrayData(arr_str.data());
TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1); TestInplacePrediction(x, "cpu_predictor", kRows, kCols, Context::kCpuId);
} }
{ {
@ -189,7 +189,7 @@ TEST(CpuPredictor, InplacePredict) {
Json::Dump(col_interface, &col_str); Json::Dump(col_interface, &col_str);
std::shared_ptr<data::DMatrixProxy> x{new data::DMatrixProxy}; std::shared_ptr<data::DMatrixProxy> x{new data::DMatrixProxy};
x->SetCSRData(rptr_str.data(), col_str.data(), data_str.data(), kCols, true); x->SetCSRData(rptr_str.data(), col_str.data(), data_str.data(), kCols, true);
TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1); TestInplacePrediction(x, "cpu_predictor", kRows, kCols, Context::kCpuId);
} }
} }

View File

@ -140,26 +140,10 @@ TEST(GPUPredictor, InplacePredictCuDF) {
TestInplacePrediction(p_fmat, "gpu_predictor", kRows, kCols, 0); TestInplacePrediction(p_fmat, "gpu_predictor", kRows, kCols, 0);
} }
// In-place prediction with data generated for device ordinal 1: predicting with
// device 1 is expected to succeed, predicting with device 0 is expected to
// throw dmlc::Error.  Skipped with a warning when at most one GPU is visible.
TEST(GPUPredictor, MGPU_InplacePredict) { // NOLINT
int32_t n_gpus = xgboost::common::AllVisibleGPUs();
if (n_gpus <= 1) {
LOG(WARNING) << "GPUPredictor.MGPU_InplacePredict is skipped.";
return;
}
size_t constexpr kRows{128}, kCols{64};
RandomDataGenerator gen(kRows, kCols, 0.5);
// Presumably places the generated data on device 1 -- confirm against
// RandomDataGenerator::Device.
gen.Device(1);
HostDeviceVector<float> data;
std::string interface_str = gen.GenerateArrayInterface(&data);
// Hand the generated array interface string to a DMatrixProxy.
std::shared_ptr<DMatrix> p_fmat{new data::DMatrixProxy};
dynamic_cast<data::DMatrixProxy*>(p_fmat.get())->SetCUDAArray(interface_str.c_str());
TestInplacePrediction(p_fmat, "gpu_predictor", kRows, kCols, 1);
// Same proxy, but device ordinal 0: an error is the expected outcome.
EXPECT_THROW(TestInplacePrediction(p_fmat, "gpu_predictor", kRows, kCols, 0), dmlc::Error);
}
TEST(GpuPredictor, LesserFeatures) { TEST(GpuPredictor, LesserFeatures) {
TestPredictionWithLesserFeatures("gpu_predictor"); TestPredictionWithLesserFeatures("gpu_predictor");
} }
// Very basic test of empty model // Very basic test of empty model
TEST(GPUPredictor, ShapStump) { TEST(GPUPredictor, ShapStump) {
cudaSetDevice(0); cudaSetDevice(0);

View File

@ -148,10 +148,9 @@ class TestGPUPredict:
from_dmatrix = booster.predict(dtrain) from_dmatrix = booster.predict(dtrain)
cp.testing.assert_allclose(from_inplace, from_dmatrix) cp.testing.assert_allclose(from_inplace, from_dmatrix)
@pytest.mark.skipif(**tm.no_cupy()) def run_inplace_predict_cupy(self, device: int) -> None:
def test_inplace_predict_cupy(self):
import cupy as cp import cupy as cp
cp.cuda.runtime.setDevice(0) cp.cuda.runtime.setDevice(device)
rows = 1000 rows = 1000
cols = 10 cols = 10
missing = 11 # set to integer for testing missing = 11 # set to integer for testing
@ -166,15 +165,17 @@ class TestGPUPredict:
dtrain = xgb.DMatrix(X, y) dtrain = xgb.DMatrix(X, y)
booster = xgb.train({'tree_method': 'gpu_hist'}, dtrain, num_boost_round=10) booster = xgb.train(
{'tree_method': 'gpu_hist', "gpu_id": device}, dtrain, num_boost_round=10
)
test = xgb.DMatrix(X[:10, ...], missing=missing) test = xgb.DMatrix(X[:10, ...], missing=missing)
predt_from_array = booster.inplace_predict(X[:10, ...], missing=missing) predt_from_array = booster.inplace_predict(X[:10, ...], missing=missing)
predt_from_dmatrix = booster.predict(test) predt_from_dmatrix = booster.predict(test)
cp.testing.assert_allclose(predt_from_array, predt_from_dmatrix) cp.testing.assert_allclose(predt_from_array, predt_from_dmatrix)
def predict_dense(x): def predict_dense(x):
cp.cuda.runtime.setDevice(device)
inplace_predt = booster.inplace_predict(x) inplace_predt = booster.inplace_predict(x)
d = xgb.DMatrix(x) d = xgb.DMatrix(x)
copied_predt = cp.array(booster.predict(d)) copied_predt = cp.array(booster.predict(d))
@ -183,7 +184,8 @@ class TestGPUPredict:
# Don't do this on Windows, see issue #5793 # Don't do this on Windows, see issue #5793
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
pytest.skip( pytest.skip(
'Multi-threaded in-place prediction with cuPy is not working on Windows') 'Multi-threaded in-place prediction with cuPy is not working on Windows'
)
for i in range(10): for i in range(10):
run_threaded_predict(X, rows, predict_dense) run_threaded_predict(X, rows, predict_dense)
@ -196,13 +198,28 @@ class TestGPUPredict:
missing_idx = [i for i in range(0, X.shape[1], 16)] missing_idx = [i for i in range(0, X.shape[1], 16)]
X[:, missing_idx] = missing X[:, missing_idx] = missing
reg = xgb.XGBRegressor(tree_method="gpu_hist", n_estimators=8, missing=missing) reg = xgb.XGBRegressor(
tree_method="gpu_hist", n_estimators=8, missing=missing, gpu_id=device
)
reg.fit(X, y) reg.fit(X, y)
gpu_predt = reg.predict(X) gpu_predt = reg.predict(X)
reg.set_params(predictor="cpu_predictor") reg.set_params(predictor="cpu_predictor")
cpu_predt = reg.predict(X) cpu_predt = reg.predict(X)
np.testing.assert_allclose(gpu_predt, cpu_predt, atol=1e-6) np.testing.assert_allclose(gpu_predt, cpu_predt, atol=1e-6)
cp.cuda.runtime.setDevice(0)
# Runs the shared cupy in-place prediction check with device 0.
@pytest.mark.skipif(**tm.no_cupy())
def test_inplace_predict_cupy(self):
self.run_inplace_predict_cupy(0)
# Runs the shared cupy in-place prediction check once for every device
# ordinal that cupy reports, from 0 up to getDeviceCount() - 1.
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.mgpu
def test_inplace_predict_cupy_specified_device(self):
import cupy as cp
n_devices = cp.cuda.runtime.getDeviceCount()
for d in range(n_devices):
self.run_inplace_predict_cupy(d)
@pytest.mark.skipif(**tm.no_cupy()) @pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.skipif(**tm.no_cudf()) @pytest.mark.skipif(**tm.no_cudf())