Revamp the rabit implementation. (#10112)

This PR replaces the original RABIT implementation with a new one, parts of which have already been merged into XGBoost. The new implementation features:
- Federated learning for both CPU and GPU.
- NCCL support.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhaustive tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
Jiaming Yuan
2024-05-20 11:56:23 +08:00
committed by GitHub
parent ba9b4cb1ee
commit a5a58102e5
195 changed files with 2768 additions and 9234 deletions

View File

@@ -5,11 +5,12 @@
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include <rabit/rabit.h>
#pragma GCC diagnostic pop
#include "../sycl/device_manager.h"
#include "../../src/collective/communicator-inl.h"
namespace xgboost {
namespace sycl {
@@ -21,22 +22,23 @@ namespace sycl {
}
bool not_use_default_selector = (device_spec.ordinal != kDefaultOrdinal) ||
(rabit::IsDistributed());
(collective::IsDistributed());
if (not_use_default_selector) {
DeviceRegister& device_register = GetDevicesRegister();
const int device_idx = rabit::IsDistributed() ? rabit::GetRank() : device_spec.ordinal;
const int device_idx =
collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal;
if (device_spec.IsSyclDefault()) {
auto& devices = device_register.devices;
CHECK_LT(device_idx, devices.size());
return devices[device_idx];
auto& devices = device_register.devices;
CHECK_LT(device_idx, devices.size());
return devices[device_idx];
} else if (device_spec.IsSyclCPU()) {
auto& cpu_devices = device_register.cpu_devices;
CHECK_LT(device_idx, cpu_devices.size());
return cpu_devices[device_idx];
auto& cpu_devices = device_register.cpu_devices;
CHECK_LT(device_idx, cpu_devices.size());
return cpu_devices[device_idx];
} else {
auto& gpu_devices = device_register.gpu_devices;
CHECK_LT(device_idx, gpu_devices.size());
return gpu_devices[device_idx];
auto& gpu_devices = device_register.gpu_devices;
CHECK_LT(device_idx, gpu_devices.size());
return gpu_devices[device_idx];
}
} else {
if (device_spec.IsSyclCPU()) {
@@ -62,24 +64,25 @@ namespace sycl {
}
bool not_use_default_selector = (device_spec.ordinal != kDefaultOrdinal) ||
(rabit::IsDistributed());
(collective::IsDistributed());
std::lock_guard<std::mutex> guard(queue_registering_mutex);
if (not_use_default_selector) {
DeviceRegister& device_register = GetDevicesRegister();
const int device_idx = rabit::IsDistributed() ? rabit::GetRank() : device_spec.ordinal;
if (device_spec.IsSyclDefault()) {
auto& devices = device_register.devices;
CHECK_LT(device_idx, devices.size());
queue_register[device_spec.Name()] = ::sycl::queue(devices[device_idx]);
} else if (device_spec.IsSyclCPU()) {
auto& cpu_devices = device_register.cpu_devices;
CHECK_LT(device_idx, cpu_devices.size());
queue_register[device_spec.Name()] = ::sycl::queue(cpu_devices[device_idx]);;
} else if (device_spec.IsSyclGPU()) {
auto& gpu_devices = device_register.gpu_devices;
CHECK_LT(device_idx, gpu_devices.size());
queue_register[device_spec.Name()] = ::sycl::queue(gpu_devices[device_idx]);
}
DeviceRegister& device_register = GetDevicesRegister();
const int device_idx =
collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal;
if (device_spec.IsSyclDefault()) {
auto& devices = device_register.devices;
CHECK_LT(device_idx, devices.size());
queue_register[device_spec.Name()] = ::sycl::queue(devices[device_idx]);
} else if (device_spec.IsSyclCPU()) {
auto& cpu_devices = device_register.cpu_devices;
CHECK_LT(device_idx, cpu_devices.size());
queue_register[device_spec.Name()] = ::sycl::queue(cpu_devices[device_idx]);
} else if (device_spec.IsSyclGPU()) {
auto& gpu_devices = device_register.gpu_devices;
CHECK_LT(device_idx, gpu_devices.size());
queue_register[device_spec.Name()] = ::sycl::queue(gpu_devices[device_idx]);
}
} else {
if (device_spec.IsSyclCPU()) {
queue_register[device_spec.Name()] = ::sycl::queue(::sycl::cpu_selector_v);

View File

@@ -6,7 +6,6 @@
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include <rabit/rabit.h>
#pragma GCC diagnostic pop
#include <vector>

View File

@@ -9,7 +9,6 @@
#include <xgboost/logging.h>
#include <xgboost/objective.h>
#pragma GCC diagnostic pop
#include <rabit/rabit.h>
#include <cmath>
#include <memory>

View File

@@ -4,7 +4,6 @@
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include <rabit/rabit.h>
#pragma GCC diagnostic pop
#include <cstddef>