Clean up MGPU C++ tests (#9430)

This commit is contained in:
Rong Ou
2023-08-01 23:31:18 -07:00
committed by GitHub
parent a9da2e244a
commit c2b85ab68a
28 changed files with 200 additions and 194 deletions

View File

@@ -5,7 +5,7 @@
namespace xgboost {
TEST(Plugin, ExampleObjective) {
xgboost::Context ctx = MakeCUDACtx(GPUIDX);
xgboost::Context ctx = MakeCUDACtx(GetGPUId());
auto* obj = xgboost::ObjFunction::Create("mylogistic", &ctx);
ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"logloss"});
delete obj;

View File

@@ -12,6 +12,7 @@
#include "../../../src/collective/communicator-inl.cuh"
#include "../../../src/collective/device_communicator_adapter.cuh"
#include "./helpers.h"
#include "../helpers.h"
namespace xgboost::collective {
@@ -26,10 +27,12 @@ namespace {
void VerifyAllReduceSum() {
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
auto const device = GetGPUId();
int count = 3;
common::SetDevice(device);
thrust::device_vector<double> buffer(count, 0);
thrust::sequence(buffer.begin(), buffer.end());
collective::AllReduce<collective::Operation::kSum>(rank, buffer.data().get(), count);
collective::AllReduce<collective::Operation::kSum>(device, buffer.data().get(), count);
thrust::host_vector<double> host_buffer = buffer;
EXPECT_EQ(host_buffer.size(), count);
for (auto i = 0; i < count; i++) {
@@ -39,10 +42,6 @@ void VerifyAllReduceSum() {
} // anonymous namespace
// Test case on the FederatedAdapterTest fixture: hands VerifyAllReduceSum to
// RunWithFederatedCommunicator together with kWorldSize and the fixture
// server's address.
// NOTE(review): presumably the callback is executed once per worker with a
// federated communicator set up — confirm against RunWithFederatedCommunicator.
TEST_F(FederatedAdapterTest, MGPUAllReduceSum) {
// Number of GPUs as reported by common::AllVisibleGPUs().
auto const n_gpus = common::AllVisibleGPUs();
// With one GPU or fewer the test is marked as skipped instead of being run.
if (n_gpus <= 1) {
GTEST_SKIP() << "Skipping MGPUAllReduceSum test with # GPUs = " << n_gpus;
}
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllReduceSum);
}
@@ -50,13 +49,15 @@ namespace {
void VerifyAllGatherV() {
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
auto const device = GetGPUId();
int const count = rank + 2;
common::SetDevice(device);
thrust::device_vector<char> buffer(count, 0);
thrust::sequence(buffer.begin(), buffer.end());
std::vector<std::size_t> segments(world_size);
dh::caching_device_vector<char> receive_buffer{};
collective::AllGatherV(rank, buffer.data().get(), count, &segments, &receive_buffer);
collective::AllGatherV(device, buffer.data().get(), count, &segments, &receive_buffer);
EXPECT_EQ(segments[0], 2);
EXPECT_EQ(segments[1], 3);
@@ -70,11 +71,6 @@ void VerifyAllGatherV() {
} // anonymous namespace
// Test case on the FederatedAdapterTest fixture: hands VerifyAllGatherV to
// RunWithFederatedCommunicator together with kWorldSize and the fixture
// server's address.
// NOTE(review): presumably the callback is executed once per worker with a
// federated communicator set up — confirm against RunWithFederatedCommunicator.
TEST_F(FederatedAdapterTest, MGPUAllGatherV) {
// Number of GPUs as reported by common::AllVisibleGPUs().
auto const n_gpus = common::AllVisibleGPUs();
// With one GPU or fewer the test is marked as skipped instead of being run.
if (n_gpus <= 1) {
GTEST_SKIP() << "Skipping MGPUAllGatherV test with # GPUs = " << n_gpus;
}
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllGatherV);
}
} // namespace xgboost::collective