Clean up MGPU C++ tests (#9430)
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
|
||||
namespace xgboost {
|
||||
TEST(Plugin, ExampleObjective) {
|
||||
xgboost::Context ctx = MakeCUDACtx(GPUIDX);
|
||||
xgboost::Context ctx = MakeCUDACtx(GetGPUId());
|
||||
auto* obj = xgboost::ObjFunction::Create("mylogistic", &ctx);
|
||||
ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"logloss"});
|
||||
delete obj;
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "../../../src/collective/communicator-inl.cuh"
|
||||
#include "../../../src/collective/device_communicator_adapter.cuh"
|
||||
#include "./helpers.h"
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
@@ -26,10 +27,12 @@ namespace {
|
||||
void VerifyAllReduceSum() {
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
auto const device = GetGPUId();
|
||||
int count = 3;
|
||||
common::SetDevice(device);
|
||||
thrust::device_vector<double> buffer(count, 0);
|
||||
thrust::sequence(buffer.begin(), buffer.end());
|
||||
collective::AllReduce<collective::Operation::kSum>(rank, buffer.data().get(), count);
|
||||
collective::AllReduce<collective::Operation::kSum>(device, buffer.data().get(), count);
|
||||
thrust::host_vector<double> host_buffer = buffer;
|
||||
EXPECT_EQ(host_buffer.size(), count);
|
||||
for (auto i = 0; i < count; i++) {
|
||||
@@ -39,10 +42,6 @@ void VerifyAllReduceSum() {
|
||||
} // anonymous namespace
|
||||
|
||||
TEST_F(FederatedAdapterTest, MGPUAllReduceSum) {
|
||||
auto const n_gpus = common::AllVisibleGPUs();
|
||||
if (n_gpus <= 1) {
|
||||
GTEST_SKIP() << "Skipping MGPUAllReduceSum test with # GPUs = " << n_gpus;
|
||||
}
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllReduceSum);
|
||||
}
|
||||
|
||||
@@ -50,13 +49,15 @@ namespace {
|
||||
void VerifyAllGatherV() {
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
auto const device = GetGPUId();
|
||||
int const count = rank + 2;
|
||||
common::SetDevice(device);
|
||||
thrust::device_vector<char> buffer(count, 0);
|
||||
thrust::sequence(buffer.begin(), buffer.end());
|
||||
std::vector<std::size_t> segments(world_size);
|
||||
dh::caching_device_vector<char> receive_buffer{};
|
||||
|
||||
collective::AllGatherV(rank, buffer.data().get(), count, &segments, &receive_buffer);
|
||||
collective::AllGatherV(device, buffer.data().get(), count, &segments, &receive_buffer);
|
||||
|
||||
EXPECT_EQ(segments[0], 2);
|
||||
EXPECT_EQ(segments[1], 3);
|
||||
@@ -70,11 +71,6 @@ void VerifyAllGatherV() {
|
||||
} // anonymous namespace
|
||||
|
||||
TEST_F(FederatedAdapterTest, MGPUAllGatherV) {
|
||||
auto const n_gpus = common::AllVisibleGPUs();
|
||||
if (n_gpus <= 1) {
|
||||
GTEST_SKIP() << "Skipping MGPUAllGatherV test with # GPUs = " << n_gpus;
|
||||
}
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllGatherV);
|
||||
}
|
||||
|
||||
} // namespace xgboost::collective
|
||||
|
||||
Reference in New Issue
Block a user