Revamp the rabit implementation. (#10112)

This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL support.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhaustive tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
Jiaming Yuan
2024-05-20 11:56:23 +08:00
committed by GitHub
parent ba9b4cb1ee
commit a5a58102e5
195 changed files with 2768 additions and 9234 deletions

View File

@@ -615,7 +615,12 @@ auto DispatchDType(ArrayInterfaceHandler::Type dtype, Fn dispatch) {
case ArrayInterfaceHandler::kF16: {
using T = long double;
CHECK(sizeof(T) == 16) << error::NoF128();
return dispatch(T{});
// Avoid invalid type.
if constexpr (sizeof(T) == 16) {
return dispatch(T{});
} else {
return dispatch(double{});
}
}
case ArrayInterfaceHandler::kI1: {
return dispatch(std::int8_t{});

View File

@@ -18,7 +18,8 @@
#include <type_traits> // for remove_pointer_t, remove_reference
#include "../collective/communicator-inl.h" // for GetRank, GetWorldSize, Allreduce, IsFederated
#include "../collective/communicator.h" // for Operation
#include "../collective/allgather.h"
#include "../collective/allreduce.h"
#include "../common/algorithm.h" // for StableSort
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData
@@ -601,41 +602,42 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
}
void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) {
if (size != 0 && this->num_col_ != 0 && !IsColumnSplit()) {
bool is_col_split = this->IsColumnSplit();
if (size != 0 && this->num_col_ != 0 && !is_col_split) {
CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns.";
CHECK(info);
}
if (!std::strcmp(key, "feature_type")) {
feature_type_names.clear();
for (size_t i = 0; i < size; ++i) {
auto elem = info[i];
feature_type_names.emplace_back(elem);
}
if (IsColumnSplit()) {
feature_type_names = collective::AllgatherStrings(feature_type_names);
CHECK_EQ(feature_type_names.size(), num_col_)
// Gather column info when data is split by columns
auto gather_columns = [is_col_split, key, n_columns = this->num_col_](auto const& inputs) {
if (is_col_split) {
std::remove_const_t<std::remove_reference_t<decltype(inputs)>> result;
auto rc = collective::AllgatherStrings(inputs, &result);
collective::SafeColl(rc);
CHECK_EQ(result.size(), n_columns)
<< "Length of " << key << " must be equal to number of columns.";
return result;
}
return inputs;
};
if (StringView{key} == "feature_type") { // NOLINT
this->feature_type_names.clear();
std::copy(info, info + size, std::back_inserter(feature_type_names));
feature_type_names = gather_columns(feature_type_names);
auto& h_feature_types = feature_types.HostVector();
this->has_categorical_ = LoadFeatureType(feature_type_names, &h_feature_types);
} else if (!std::strcmp(key, "feature_name")) {
if (IsColumnSplit()) {
std::vector<std::string> local_feature_names{};
} else if (StringView{key} == "feature_name") { // NOLINT
feature_names.clear();
if (is_col_split) {
auto const rank = collective::GetRank();
for (std::size_t i = 0; i < size; ++i) {
auto elem = std::to_string(rank) + "." + info[i];
local_feature_names.emplace_back(elem);
}
feature_names = collective::AllgatherStrings(local_feature_names);
CHECK_EQ(feature_names.size(), num_col_)
<< "Length of " << key << " must be equal to number of columns.";
std::transform(info, info + size, std::back_inserter(feature_names),
[rank](char const* elem) { return std::to_string(rank) + "." + elem; });
} else {
feature_names.clear();
for (size_t i = 0; i < size; ++i) {
feature_names.emplace_back(info[i]);
}
std::copy(info, info + size, std::back_inserter(feature_names));
}
feature_names = gather_columns(feature_names);
} else {
LOG(FATAL) << "Unknown feature info name: " << key;
}
@@ -728,12 +730,10 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
}
}
void MetaInfo::SynchronizeNumberOfColumns(Context const*) {
if (IsColumnSplit()) {
collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
} else {
collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
}
void MetaInfo::SynchronizeNumberOfColumns(Context const* ctx) {
auto op = IsColumnSplit() ? collective::Op::kSum : collective::Op::kMax;
auto rc = collective::Allreduce(ctx, linalg::MakeVec(&num_col_, 1), op);
collective::SafeColl(rc);
}
namespace {

View File

@@ -9,11 +9,12 @@
#include <type_traits> // for underlying_type_t
#include <vector> // for vector
#include "../collective/communicator-inl.h"
#include "../common/categorical.h" // common::IsCat
#include "../collective/allreduce.h" // for Allreduce
#include "../collective/communicator-inl.h" // for IsDistributed
#include "../common/categorical.h" // common::IsCat
#include "../common/column_matrix.h"
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
#include "batch_utils.h" // for RegenGHist
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
#include "batch_utils.h" // for RegenGHist
#include "gradient_index.h"
#include "proxy_dmatrix.h"
#include "simple_batch_iterator.h"
@@ -95,13 +96,13 @@ void GetCutsFromRef(Context const* ctx, std::shared_ptr<DMatrix> ref, bst_featur
namespace {
// Synchronize feature type in case of empty DMatrix
void SyncFeatureType(Context const*, std::vector<FeatureType>* p_h_ft) {
void SyncFeatureType(Context const* ctx, std::vector<FeatureType>* p_h_ft) {
if (!collective::IsDistributed()) {
return;
}
auto& h_ft = *p_h_ft;
auto n_ft = h_ft.size();
collective::Allreduce<collective::Operation::kMax>(&n_ft, 1);
bst_idx_t n_ft = h_ft.size();
collective::SafeColl(collective::Allreduce(ctx, &n_ft, collective::Op::kMax));
if (!h_ft.empty()) {
// Check correct size if this is not an empty DMatrix.
CHECK_EQ(h_ft.size(), n_ft);
@@ -109,7 +110,8 @@ void SyncFeatureType(Context const*, std::vector<FeatureType>* p_h_ft) {
if (n_ft > 0) {
h_ft.resize(n_ft);
auto ptr = reinterpret_cast<std::underlying_type_t<FeatureType>*>(h_ft.data());
collective::Allreduce<collective::Operation::kMax>(ptr, h_ft.size());
collective::SafeColl(
collective::Allreduce(ctx, linalg::MakeVec(ptr, h_ft.size()), collective::Op::kMax));
}
}
} // anonymous namespace
@@ -175,7 +177,7 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
// We use do while here as the first batch is fetched in ctor
if (n_features == 0) {
n_features = num_cols();
collective::Allreduce<collective::Operation::kMax>(&n_features, 1);
collective::SafeColl(collective::Allreduce(ctx, &n_features, collective::Op::kMax));
column_sizes.clear();
column_sizes.resize(n_features, 0);
info_.num_col_ = n_features;

View File

@@ -1,20 +1,18 @@
/**
* Copyright 2020-2023, XGBoost contributors
* Copyright 2020-2024, XGBoost contributors
*/
#include <algorithm>
#include <memory>
#include <type_traits>
#include "../collective/allreduce.h"
#include "../common/hist_util.cuh"
#include "batch_utils.h" // for RegenGHist
#include "device_adapter.cuh"
#include "ellpack_page.cuh"
#include "gradient_index.h"
#include "iterative_dmatrix.h"
#include "proxy_dmatrix.cuh"
#include "proxy_dmatrix.h"
#include "simple_batch_iterator.h"
#include "sparse_page_source.h"
namespace xgboost::data {
void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
@@ -63,7 +61,8 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
dh::safe_cuda(cudaSetDevice(get_device().ordinal));
if (cols == 0) {
cols = num_cols();
collective::Allreduce<collective::Operation::kMax>(&cols, 1);
auto rc = collective::Allreduce(ctx, linalg::MakeVec(&cols, 1), collective::Op::kMax);
SafeColl(rc);
this->info_.num_col_ = cols;
} else {
CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";

View File

@@ -171,12 +171,13 @@ decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_
} else {
LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
}
if constexpr (get_value) {
return std::invoke_result_t<
Fn, decltype(std::declval<std::shared_ptr<ArrayAdapter>>()->Value())>();
} else {
return std::invoke_result_t<Fn, decltype(std::declval<std::shared_ptr<ArrayAdapter>>())>();
}
}
if constexpr (get_value) {
return std::invoke_result_t<Fn,
decltype(std::declval<std::shared_ptr<ArrayAdapter>>()->Value())>();
} else {
return std::invoke_result_t<Fn, decltype(std::declval<std::shared_ptr<ArrayAdapter>>())>();
}
}

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2014~2023, XGBoost Contributors
* Copyright 2014-2024, XGBoost Contributors
* \file simple_dmatrix.cc
* \brief the input data structure for gradient boosting
* \author Tianqi Chen
@@ -13,6 +13,7 @@
#include <vector>
#include "../collective/communicator-inl.h" // for GetWorldSize, GetRank, Allgather
#include "../collective/allgather.h"
#include "../common/error_msg.h" // for InconsistentMaxBin
#include "./simple_batch_iterator.h"
#include "adapter.h"
@@ -76,8 +77,11 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
void SimpleDMatrix::ReindexFeatures(Context const* ctx) {
if (info_.IsColumnSplit() && collective::GetWorldSize() > 1) {
auto const cols = collective::Allgather(info_.num_col_);
auto const offset = std::accumulate(cols.cbegin(), cols.cbegin() + collective::GetRank(), 0ul);
std::vector<std::uint64_t> buffer(collective::GetWorldSize());
buffer[collective::GetRank()] = info_.num_col_;
auto rc = collective::Allgather(ctx, linalg::MakeVec(buffer.data(), buffer.size()));
SafeColl(rc);
auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0);
if (offset == 0) {
return;
}

View File

@@ -11,6 +11,7 @@
#include <future> // for async
#include <memory> // for unique_ptr
#include <mutex> // for mutex
#include <numeric> // for partial_sum
#include <string> // for string
#include <utility> // for pair, move
#include <vector> // for vector