Revamp the rabit implementation. (#10112)
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features: - Federated learning for both CPU and GPU. - NCCL. - More data types. - A unified interface for all the underlying implementations. - Improved timeout handling for both tracker and workers. - Exhausted tests with metrics (fixed a couple of bugs along the way). - A reusable tracker for Python and JVM packages.
This commit is contained in:
@@ -108,6 +108,32 @@ TEST_F(FederatedCollTestGPU, Allreduce) {
|
||||
});
|
||||
}
|
||||
|
||||
TEST(FederatedCollGPUGlobal, Allreduce) {
|
||||
std::int32_t n_workers = common::AllVisibleGPUs();
|
||||
TestFederatedGlobal(n_workers, [&] {
|
||||
auto r = collective::GetRank();
|
||||
auto world = collective::GetWorldSize();
|
||||
CHECK_EQ(n_workers, world);
|
||||
|
||||
dh::device_vector<std::uint32_t> values(3, r);
|
||||
auto ctx = MakeCUDACtx(r);
|
||||
auto rc = collective::Allreduce(
|
||||
&ctx, linalg::MakeVec(values.data().get(), values.size(), DeviceOrd::CUDA(r)),
|
||||
Op::kBitwiseOR);
|
||||
SafeColl(rc);
|
||||
|
||||
std::vector<std::uint32_t> expected(values.size(), 0);
|
||||
for (std::int32_t rank = 0; rank < world; ++rank) {
|
||||
for (std::size_t i = 0; i < expected.size(); ++i) {
|
||||
expected[i] |= rank;
|
||||
}
|
||||
}
|
||||
for (std::size_t i = 0; i < expected.size(); ++i) {
|
||||
CHECK_EQ(expected[i], values[i]);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(FederatedCollTestGPU, Broadcast) {
|
||||
std::int32_t n_workers = common::AllVisibleGPUs();
|
||||
TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
|
||||
|
||||
@@ -11,12 +11,24 @@
|
||||
|
||||
#include "../../../../plugin/federated/federated_tracker.h"
|
||||
#include "../../../../src/collective/comm_group.h"
|
||||
#include "../../../../src/collective/communicator-inl.h"
|
||||
#include "federated_comm.h" // for FederatedComm
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
inline Json FederatedTestConfig(std::int32_t n_workers, std::int32_t port, std::int32_t i) {
|
||||
Json config{Object{}};
|
||||
config["dmlc_communicator"] = std::string{"federated"};
|
||||
config["dmlc_task_id"] = std::to_string(i);
|
||||
config["dmlc_retry"] = 2;
|
||||
config["federated_world_size"] = n_workers;
|
||||
config["federated_rank"] = i;
|
||||
config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
|
||||
return config;
|
||||
}
|
||||
|
||||
template <typename WorkerFn>
|
||||
void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
void TestFederatedImpl(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
Json config{Object()};
|
||||
config["federated_secure"] = Boolean{false};
|
||||
config["n_workers"] = Integer{n_workers};
|
||||
@@ -30,16 +42,7 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] {
|
||||
Json config{Object{}};
|
||||
config["federated_world_size"] = n_workers;
|
||||
config["federated_rank"] = i;
|
||||
config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
|
||||
auto comm = std::make_shared<FederatedComm>(
|
||||
DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, std::to_string(i), config);
|
||||
|
||||
fn(comm, i);
|
||||
});
|
||||
workers.emplace_back([=] { fn(port, i); });
|
||||
}
|
||||
|
||||
for (auto& t : workers) {
|
||||
@@ -51,39 +54,33 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
}
|
||||
|
||||
template <typename WorkerFn>
|
||||
void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
TestFederatedImpl(n_workers, [&](std::int32_t port, std::int32_t i) {
|
||||
auto config = FederatedTestConfig(n_workers, port, i);
|
||||
auto comm = std::make_shared<FederatedComm>(
|
||||
DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, std::to_string(i), config);
|
||||
|
||||
fn(comm, i);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename WorkerFn>
|
||||
void TestFederatedGroup(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
Json config{Object()};
|
||||
config["federated_secure"] = Boolean{false};
|
||||
config["n_workers"] = Integer{n_workers};
|
||||
FederatedTracker tracker{config};
|
||||
auto fut = tracker.Run();
|
||||
TestFederatedImpl(n_workers, [&](std::int32_t port, std::int32_t i) {
|
||||
auto config = FederatedTestConfig(n_workers, port, i);
|
||||
std::shared_ptr<CommGroup> comm_group{CommGroup::Create(config)};
|
||||
fn(comm_group, i);
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
auto rc = tracker.WaitUntilReady();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] {
|
||||
Json config{Object{}};
|
||||
config["dmlc_communicator"] = std::string{"federated"};
|
||||
config["dmlc_task_id"] = std::to_string(i);
|
||||
config["dmlc_retry"] = 2;
|
||||
config["federated_world_size"] = n_workers;
|
||||
config["federated_rank"] = i;
|
||||
config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
|
||||
std::shared_ptr<CommGroup> comm_group{CommGroup::Create(config)};
|
||||
fn(comm_group, i);
|
||||
});
|
||||
}
|
||||
|
||||
for (auto& t : workers) {
|
||||
t.join();
|
||||
}
|
||||
|
||||
rc = tracker.Shutdown();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
template <typename WorkerFn>
|
||||
void TestFederatedGlobal(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
TestFederatedImpl(n_workers, [&](std::int32_t port, std::int32_t i) {
|
||||
auto config = FederatedTestConfig(n_workers, port, i);
|
||||
collective::Init(config);
|
||||
fn();
|
||||
collective::Finalize();
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <dmlc/omp.h>
|
||||
#include <grpcpp/server_builder.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/json.h>
|
||||
|
||||
#include <random>
|
||||
#include <thread> // for thread, sleep_for
|
||||
|
||||
#include "../../../plugin/federated/federated_server.h"
|
||||
#include "../../../src/collective/communicator-inl.h"
|
||||
#include "../../../src/common/threading_utils.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
class ServerForTest {
|
||||
std::string server_address_;
|
||||
std::unique_ptr<std::thread> server_thread_;
|
||||
std::unique_ptr<grpc::Server> server_;
|
||||
|
||||
public:
|
||||
explicit ServerForTest(std::size_t world_size) {
|
||||
server_thread_.reset(new std::thread([this, world_size] {
|
||||
grpc::ServerBuilder builder;
|
||||
xgboost::federated::FederatedService service{static_cast<std::int32_t>(world_size)};
|
||||
int selected_port;
|
||||
builder.AddListeningPort("localhost:0", grpc::InsecureServerCredentials(), &selected_port);
|
||||
builder.RegisterService(&service);
|
||||
server_ = builder.BuildAndStart();
|
||||
server_address_ = std::string("localhost:") + std::to_string(selected_port);
|
||||
server_->Wait();
|
||||
}));
|
||||
}
|
||||
|
||||
~ServerForTest() {
|
||||
using namespace std::chrono_literals;
|
||||
while (!server_) {
|
||||
std::this_thread::sleep_for(100ms);
|
||||
}
|
||||
server_->Shutdown();
|
||||
while (!server_thread_) {
|
||||
std::this_thread::sleep_for(100ms);
|
||||
}
|
||||
server_thread_->join();
|
||||
}
|
||||
|
||||
auto Address() const {
|
||||
using namespace std::chrono_literals;
|
||||
while (server_address_.empty()) {
|
||||
std::this_thread::sleep_for(100ms);
|
||||
}
|
||||
return server_address_;
|
||||
}
|
||||
};
|
||||
|
||||
class BaseFederatedTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override { server_ = std::make_unique<ServerForTest>(kWorldSize); }
|
||||
|
||||
void TearDown() override { server_.reset(nullptr); }
|
||||
|
||||
static int constexpr kWorldSize{2};
|
||||
std::unique_ptr<ServerForTest> server_;
|
||||
};
|
||||
|
||||
template <typename Function, typename... Args>
|
||||
void RunWithFederatedCommunicator(int32_t world_size, std::string const& server_address,
|
||||
Function&& function, Args&&... args) {
|
||||
auto run = [&](auto rank) {
|
||||
Json config{JsonObject()};
|
||||
config["xgboost_communicator"] = String("federated");
|
||||
config["federated_secure"] = false;
|
||||
config["federated_server_address"] = String(server_address);
|
||||
config["federated_world_size"] = world_size;
|
||||
config["federated_rank"] = rank;
|
||||
xgboost::collective::Init(config);
|
||||
|
||||
std::forward<Function>(function)(std::forward<Args>(args)...);
|
||||
|
||||
xgboost::collective::Finalize();
|
||||
};
|
||||
#if defined(_OPENMP)
|
||||
common::ParallelFor(world_size, world_size, run);
|
||||
#else
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < world_size; rank++) {
|
||||
threads.emplace_back(run, rank);
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
@@ -1,97 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <thrust/host_vector.h>
|
||||
|
||||
#include <ctime>
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
|
||||
#include "../../../plugin/federated/federated_communicator.h"
|
||||
#include "../../../src/collective/communicator-inl.cuh"
|
||||
#include "../../../src/collective/device_communicator_adapter.cuh"
|
||||
#include "../helpers.h"
|
||||
#include "./helpers.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
class FederatedAdapterTest : public BaseFederatedTest {};
|
||||
|
||||
TEST(FederatedAdapterSimpleTest, ThrowOnInvalidDeviceOrdinal) {
|
||||
auto construct = []() { DeviceCommunicatorAdapter adapter{-1}; };
|
||||
EXPECT_THROW(construct(), dmlc::Error);
|
||||
}
|
||||
|
||||
namespace {
|
||||
void VerifyAllReduceSum() {
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
auto const device = GPUIDX;
|
||||
int count = 3;
|
||||
common::SetDevice(device);
|
||||
thrust::device_vector<double> buffer(count, 0);
|
||||
thrust::sequence(buffer.begin(), buffer.end());
|
||||
collective::AllReduce<collective::Operation::kSum>(device, buffer.data().get(), count);
|
||||
thrust::host_vector<double> host_buffer = buffer;
|
||||
EXPECT_EQ(host_buffer.size(), count);
|
||||
for (auto i = 0; i < count; i++) {
|
||||
EXPECT_EQ(host_buffer[i], i * world_size);
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST_F(FederatedAdapterTest, MGPUAllReduceSum) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllReduceSum);
|
||||
}
|
||||
|
||||
namespace {
|
||||
void VerifyAllGather() {
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
auto const device = GPUIDX;
|
||||
common::SetDevice(device);
|
||||
thrust::device_vector<double> send_buffer(1, rank);
|
||||
thrust::device_vector<double> receive_buffer(world_size, 0);
|
||||
collective::AllGather(device, send_buffer.data().get(), receive_buffer.data().get(),
|
||||
sizeof(double));
|
||||
thrust::host_vector<double> host_buffer = receive_buffer;
|
||||
EXPECT_EQ(host_buffer.size(), world_size);
|
||||
for (auto i = 0; i < world_size; i++) {
|
||||
EXPECT_EQ(host_buffer[i], i);
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST_F(FederatedAdapterTest, MGPUAllGather) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllGather);
|
||||
}
|
||||
|
||||
namespace {
|
||||
void VerifyAllGatherV() {
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
auto const device = GPUIDX;
|
||||
int const count = rank + 2;
|
||||
common::SetDevice(device);
|
||||
thrust::device_vector<char> buffer(count, 0);
|
||||
thrust::sequence(buffer.begin(), buffer.end());
|
||||
std::vector<std::size_t> segments(world_size);
|
||||
dh::caching_device_vector<char> receive_buffer{};
|
||||
|
||||
collective::AllGatherV(device, buffer.data().get(), count, &segments, &receive_buffer);
|
||||
|
||||
EXPECT_EQ(segments[0], 2);
|
||||
EXPECT_EQ(segments[1], 3);
|
||||
thrust::host_vector<char> host_buffer = receive_buffer;
|
||||
EXPECT_EQ(host_buffer.size(), 5);
|
||||
int expected[] = {0, 1, 0, 1, 2};
|
||||
for (auto i = 0; i < 5; i++) {
|
||||
EXPECT_EQ(host_buffer[i], expected[i]);
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST_F(FederatedAdapterTest, MGPUAllGatherV) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllGatherV);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
@@ -1,161 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#include <dmlc/parameter.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
|
||||
#include "../../../plugin/federated/federated_communicator.h"
|
||||
#include "helpers.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
class FederatedCommunicatorTest : public BaseFederatedTest {
|
||||
public:
|
||||
static void VerifyAllgather(int rank, const std::string &server_address) {
|
||||
FederatedCommunicator comm{kWorldSize, rank, server_address};
|
||||
CheckAllgather(comm, rank);
|
||||
}
|
||||
|
||||
static void VerifyAllgatherV(int rank, const std::string &server_address) {
|
||||
FederatedCommunicator comm{kWorldSize, rank, server_address};
|
||||
CheckAllgatherV(comm, rank);
|
||||
}
|
||||
|
||||
static void VerifyAllreduce(int rank, const std::string &server_address) {
|
||||
FederatedCommunicator comm{kWorldSize, rank, server_address};
|
||||
CheckAllreduce(comm);
|
||||
}
|
||||
|
||||
static void VerifyBroadcast(int rank, const std::string &server_address) {
|
||||
FederatedCommunicator comm{kWorldSize, rank, server_address};
|
||||
CheckBroadcast(comm, rank);
|
||||
}
|
||||
|
||||
protected:
|
||||
static void CheckAllgather(FederatedCommunicator &comm, int rank) {
|
||||
std::string input{static_cast<char>('0' + rank)};
|
||||
auto output = comm.AllGather(input);
|
||||
for (auto i = 0; i < kWorldSize; i++) {
|
||||
EXPECT_EQ(output[i], static_cast<char>('0' + i));
|
||||
}
|
||||
}
|
||||
|
||||
static void CheckAllgatherV(FederatedCommunicator &comm, int rank) {
|
||||
std::vector<std::string_view> inputs{"Federated", " Learning!!!"};
|
||||
auto output = comm.AllGatherV(inputs[rank]);
|
||||
EXPECT_EQ(output, "Federated Learning!!!");
|
||||
}
|
||||
|
||||
static void CheckAllreduce(FederatedCommunicator &comm) {
|
||||
int buffer[] = {1, 2, 3, 4, 5};
|
||||
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum);
|
||||
int expected[] = {2, 4, 6, 8, 10};
|
||||
for (auto i = 0; i < 5; i++) {
|
||||
EXPECT_EQ(buffer[i], expected[i]);
|
||||
}
|
||||
}
|
||||
|
||||
static void CheckBroadcast(FederatedCommunicator &comm, int rank) {
|
||||
if (rank == 0) {
|
||||
std::string buffer{"hello"};
|
||||
comm.Broadcast(&buffer[0], buffer.size(), 0);
|
||||
EXPECT_EQ(buffer, "hello");
|
||||
} else {
|
||||
std::string buffer{" "};
|
||||
comm.Broadcast(&buffer[0], buffer.size(), 0);
|
||||
EXPECT_EQ(buffer, "hello");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) {
|
||||
auto construct = [] { FederatedCommunicator comm{0, 0, "localhost:0", "", "", ""}; };
|
||||
EXPECT_THROW(construct(), dmlc::Error);
|
||||
}
|
||||
|
||||
TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooSmall) {
|
||||
auto construct = [] { FederatedCommunicator comm{1, -1, "localhost:0", "", "", ""}; };
|
||||
EXPECT_THROW(construct(), dmlc::Error);
|
||||
}
|
||||
|
||||
TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooBig) {
|
||||
auto construct = [] { FederatedCommunicator comm{1, 1, "localhost:0", "", "", ""}; };
|
||||
EXPECT_THROW(construct(), dmlc::Error);
|
||||
}
|
||||
|
||||
TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) {
|
||||
auto construct = [] {
|
||||
Json config{JsonObject()};
|
||||
config["federated_server_address"] = std::string("localhost:0");
|
||||
config["federated_world_size"] = std::string("1");
|
||||
config["federated_rank"] = Integer(0);
|
||||
FederatedCommunicator::Create(config);
|
||||
};
|
||||
EXPECT_THROW(construct(), dmlc::Error);
|
||||
}
|
||||
|
||||
TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) {
|
||||
auto construct = [] {
|
||||
Json config{JsonObject()};
|
||||
config["federated_server_address"] = std::string("localhost:0");
|
||||
config["federated_world_size"] = 1;
|
||||
config["federated_rank"] = std::string("0");
|
||||
FederatedCommunicator::Create(config);
|
||||
};
|
||||
EXPECT_THROW(construct(), dmlc::Error);
|
||||
}
|
||||
|
||||
TEST(FederatedCommunicatorSimpleTest, GetWorldSizeAndRank) {
|
||||
FederatedCommunicator comm{6, 3, "localhost:0"};
|
||||
EXPECT_EQ(comm.GetWorldSize(), 6);
|
||||
EXPECT_EQ(comm.GetRank(), 3);
|
||||
}
|
||||
|
||||
TEST(FederatedCommunicatorSimpleTest, IsDistributed) {
|
||||
FederatedCommunicator comm{2, 1, "localhost:0"};
|
||||
EXPECT_TRUE(comm.IsDistributed());
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommunicatorTest, Allgather) {
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgather, rank, server_->Address());
|
||||
}
|
||||
for (auto &thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommunicatorTest, AllgatherV) {
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgatherV, rank, server_->Address());
|
||||
}
|
||||
for (auto &thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommunicatorTest, Allreduce) {
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_->Address());
|
||||
}
|
||||
for (auto &thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommunicatorTest, Broadcast) {
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||
threads.emplace_back(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_->Address());
|
||||
}
|
||||
for (auto &thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
@@ -6,16 +6,13 @@
|
||||
|
||||
#include <thread>
|
||||
|
||||
#include "../../../plugin/federated/federated_server.h"
|
||||
#include "../../../src/collective/communicator-inl.h"
|
||||
#include "../filesystem.h"
|
||||
#include "../helpers.h"
|
||||
#include "helpers.h"
|
||||
#include "federated/test_worker.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
class FederatedDataTest : public BaseFederatedTest {};
|
||||
|
||||
void VerifyLoadUri() {
|
||||
auto const rank = collective::GetRank();
|
||||
|
||||
@@ -47,7 +44,8 @@ void VerifyLoadUri() {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FederatedDataTest, LoadUri) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLoadUri);
|
||||
TEST(FederatedDataTest, LoadUri) {
|
||||
static int constexpr kWorldSize{2};
|
||||
collective::TestFederatedGlobal(kWorldSize, [] { VerifyLoadUri(); });
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
/*!
|
||||
* Copyright 2023 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2023-2024, XGBoost contributors
|
||||
*
|
||||
* Some other tests for federated learning are in the main test suite (test_learner.cc),
|
||||
* gaurded by the `XGBOOST_USE_FEDERATED`.
|
||||
*/
|
||||
#include <dmlc/parameter.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/objective.h>
|
||||
|
||||
#include "../../../plugin/federated/federated_server.h"
|
||||
#include "../../../src/collective/communicator-inl.h"
|
||||
#include "../../../src/common/linalg_op.h"
|
||||
#include "../../../src/common/linalg_op.h" // for begin, end
|
||||
#include "../helpers.h"
|
||||
#include "../objective_helpers.h" // for MakeObjNamesForTest, ObjTestNameGenerator
|
||||
#include "helpers.h"
|
||||
#include "federated/test_worker.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace {
|
||||
@@ -36,32 +38,16 @@ auto MakeModel(std::string tree_method, std::string device, std::string objectiv
|
||||
return model;
|
||||
}
|
||||
|
||||
void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json expected_model,
|
||||
std::string tree_method, std::string device, std::string objective) {
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
void VerifyObjective(std::size_t rows, std::size_t cols, float expected_base_score,
|
||||
Json expected_model, std::string const &tree_method, std::string device,
|
||||
std::string const &objective) {
|
||||
auto rank = collective::GetRank();
|
||||
std::shared_ptr<DMatrix> dmat{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)};
|
||||
|
||||
if (rank == 0) {
|
||||
auto &h_upper = dmat->Info().labels_upper_bound_.HostVector();
|
||||
auto &h_lower = dmat->Info().labels_lower_bound_.HostVector();
|
||||
h_lower.resize(rows);
|
||||
h_upper.resize(rows);
|
||||
for (size_t i = 0; i < rows; ++i) {
|
||||
h_lower[i] = 1;
|
||||
h_upper[i] = 10;
|
||||
}
|
||||
|
||||
if (objective.find("rank:") != std::string::npos) {
|
||||
auto h_label = dmat->Info().labels.HostView();
|
||||
std::size_t k = 0;
|
||||
for (auto &v : h_label) {
|
||||
v = k % 2 == 0;
|
||||
++k;
|
||||
}
|
||||
}
|
||||
MakeLabelForObjTest(dmat, objective);
|
||||
}
|
||||
std::shared_ptr<DMatrix> sliced{dmat->SliceCol(world_size, rank)};
|
||||
std::shared_ptr<DMatrix> sliced{dmat->SliceCol(collective::GetWorldSize(), rank)};
|
||||
|
||||
auto model = MakeModel(tree_method, device, objective, sliced);
|
||||
auto base_score = GetBaseScore(model);
|
||||
@@ -71,18 +57,15 @@ void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json e
|
||||
} // namespace
|
||||
|
||||
class VerticalFederatedLearnerTest : public ::testing::TestWithParam<std::string> {
|
||||
std::unique_ptr<ServerForTest> server_;
|
||||
static int constexpr kWorldSize{3};
|
||||
|
||||
protected:
|
||||
void SetUp() override { server_ = std::make_unique<ServerForTest>(kWorldSize); }
|
||||
void TearDown() override { server_.reset(nullptr); }
|
||||
|
||||
void Run(std::string tree_method, std::string device, std::string objective) {
|
||||
static auto constexpr kRows{16};
|
||||
static auto constexpr kCols{16};
|
||||
|
||||
std::shared_ptr<DMatrix> dmat{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)};
|
||||
MakeLabelForObjTest(dmat, objective);
|
||||
|
||||
auto &h_upper = dmat->Info().labels_upper_bound_.HostVector();
|
||||
auto &h_lower = dmat->Info().labels_lower_bound_.HostVector();
|
||||
@@ -103,9 +86,9 @@ class VerticalFederatedLearnerTest : public ::testing::TestWithParam<std::string
|
||||
|
||||
auto model = MakeModel(tree_method, device, objective, dmat);
|
||||
auto score = GetBaseScore(model);
|
||||
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyObjective, kRows, kCols,
|
||||
score, model, tree_method, device, objective);
|
||||
collective::TestFederatedGlobal(kWorldSize, [&]() {
|
||||
VerifyObjective(kRows, kCols, score, model, tree_method, device, objective);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,243 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2023 XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "../metric/test_auc.h"
|
||||
#include "../metric/test_elementwise_metric.h"
|
||||
#include "../metric/test_multiclass_metric.h"
|
||||
#include "../metric/test_rank_metric.h"
|
||||
#include "../metric/test_survival_metric.h"
|
||||
#include "helpers.h"
|
||||
|
||||
namespace {
|
||||
class FederatedMetricTest : public xgboost::BaseFederatedTest {};
|
||||
} // anonymous namespace
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
TEST_F(FederatedMetricTest, BinaryAUCRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyBinaryAUC,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, BinaryAUCColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyBinaryAUC,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiClassAUCRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassAUC,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiClassAUCColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassAUC,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, RankingAUCRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingAUC,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, RankingAUCColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingAUC,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, PRAUCRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPRAUC, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, PRAUCColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPRAUC, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiClassPRAUCRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassPRAUC,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiClassPRAUCColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassPRAUC,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, RankingPRAUCRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingPRAUC,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, RankingPRAUCColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingPRAUC,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, RMSERowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSE, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, RMSEColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSE, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, RMSLERowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSLE, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, RMSLEColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSLE, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MAERowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAE, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MAEColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAE, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MAPERowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAPE, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MAPEColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAPE, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MPHERowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMPHE, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MPHEColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMPHE, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, LogLossRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLogLoss, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, LogLossColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLogLoss, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, ErrorRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyError, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, ErrorColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyError, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, PoissonNegLogLikRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPoissonNegLogLik,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, PoissonNegLogLikColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPoissonNegLogLik,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiRMSERowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiRMSE,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiRMSEColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiRMSE,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, QuantileRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyQuantile,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, QuantileColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyQuantile,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiClassErrorRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassError,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiClassErrorColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassError,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiClassLogLossRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassLogLoss,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiClassLogLossColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassLogLoss,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, PrecisionRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPrecision,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, PrecisionColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPrecision,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, NDCGRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCG, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, NDCGColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCG, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MAPRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAP, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MAPColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAP, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, NDCGExpGainRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCGExpGain,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, NDCGExpGainColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCGExpGain,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
TEST_F(FederatedMetricTest, AFTNegLogLikRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAFTNegLogLik,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, AFTNegLogLikColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAFTNegLogLik,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, IntervalRegressionAccuracyRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyIntervalRegressionAccuracy,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, IntervalRegressionAccuracyColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyIntervalRegressionAccuracy,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
@@ -1,133 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2017-2020 XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
|
||||
#include "federated_client.h"
|
||||
#include "helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
class FederatedServerTest : public BaseFederatedTest {
|
||||
public:
|
||||
static void VerifyAllgather(int rank, const std::string& server_address) {
|
||||
federated::FederatedClient client{server_address, rank};
|
||||
CheckAllgather(client, rank);
|
||||
}
|
||||
|
||||
static void VerifyAllgatherV(int rank, const std::string& server_address) {
|
||||
federated::FederatedClient client{server_address, rank};
|
||||
CheckAllgatherV(client, rank);
|
||||
}
|
||||
|
||||
static void VerifyAllreduce(int rank, const std::string& server_address) {
|
||||
federated::FederatedClient client{server_address, rank};
|
||||
CheckAllreduce(client);
|
||||
}
|
||||
|
||||
static void VerifyBroadcast(int rank, const std::string& server_address) {
|
||||
federated::FederatedClient client{server_address, rank};
|
||||
CheckBroadcast(client, rank);
|
||||
}
|
||||
|
||||
static void VerifyMixture(int rank, const std::string& server_address) {
|
||||
federated::FederatedClient client{server_address, rank};
|
||||
for (auto i = 0; i < 10; i++) {
|
||||
CheckAllgather(client, rank);
|
||||
CheckAllreduce(client);
|
||||
CheckBroadcast(client, rank);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
static void CheckAllgather(federated::FederatedClient& client, int rank) {
|
||||
int data[] = {rank};
|
||||
std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
|
||||
auto reply = client.Allgather(send_buffer);
|
||||
auto const* result = reinterpret_cast<int const*>(reply.data());
|
||||
for (auto i = 0; i < kWorldSize; i++) {
|
||||
EXPECT_EQ(result[i], i);
|
||||
}
|
||||
}
|
||||
|
||||
static void CheckAllgatherV(federated::FederatedClient& client, int rank) {
|
||||
std::vector<std::string_view> inputs{"Hello,", " World!"};
|
||||
auto reply = client.AllgatherV(inputs[rank]);
|
||||
EXPECT_EQ(reply, "Hello, World!");
|
||||
}
|
||||
|
||||
static void CheckAllreduce(federated::FederatedClient& client) {
|
||||
int data[] = {1, 2, 3, 4, 5};
|
||||
std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
|
||||
auto reply = client.Allreduce(send_buffer, federated::INT32, federated::SUM);
|
||||
auto const* result = reinterpret_cast<int const*>(reply.data());
|
||||
int expected[] = {2, 4, 6, 8, 10};
|
||||
for (auto i = 0; i < 5; i++) {
|
||||
EXPECT_EQ(result[i], expected[i]);
|
||||
}
|
||||
}
|
||||
|
||||
static void CheckBroadcast(federated::FederatedClient& client, int rank) {
|
||||
std::string send_buffer{};
|
||||
if (rank == 0) {
|
||||
send_buffer = "hello broadcast";
|
||||
}
|
||||
auto reply = client.Broadcast(send_buffer, 0);
|
||||
EXPECT_EQ(reply, "hello broadcast") << "rank " << rank;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(FederatedServerTest, Allgather) {
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||
threads.emplace_back(&FederatedServerTest::VerifyAllgather, rank, server_->Address());
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FederatedServerTest, AllgatherV) {
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||
threads.emplace_back(&FederatedServerTest::VerifyAllgatherV, rank, server_->Address());
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FederatedServerTest, Allreduce) {
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||
threads.emplace_back(&FederatedServerTest::VerifyAllreduce, rank, server_->Address());
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FederatedServerTest, Broadcast) {
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||
threads.emplace_back(&FederatedServerTest::VerifyBroadcast, rank, server_->Address());
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FederatedServerTest, Mixture) {
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||
threads.emplace_back(&FederatedServerTest::VerifyMixture, rank, server_->Address());
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
Reference in New Issue
Block a user