[Breaking] Switch from rabit to the collective communicator (#8257)

* Switch from rabit to the collective communicator

* fix size_t specialization

* really fix size_t

* try again

* add include

* more include

* fix lint errors

* remove rabit includes

* fix pylint error

* return dict from communicator context

* fix communicator shutdown

* fix dask test

* reset communicator mocklist

* fix distributed tests

* do not save device communicator

* fix jvm gpu tests

* add python test for federated communicator

* Update gputreeshap submodule

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Rong Ou
2022-10-05 15:39:01 -07:00
committed by GitHub
parent e47b3a3da3
commit 668b8a0ea4
79 changed files with 805 additions and 2212 deletions

View File

@@ -19,7 +19,6 @@
#include <thrust/unique.h>
#include <thrust/binary_search.h>
#include <rabit/rabit.h>
#include <cub/cub.cuh>
#include <cub/util_allocator.cuh>
@@ -36,6 +35,7 @@
#include "xgboost/span.h"
#include "xgboost/global_config.h"
#include "../collective/communicator-inl.h"
#include "common.h"
#include "algorithm.cuh"
@@ -404,7 +404,7 @@ inline detail::MemoryLogger &GlobalMemoryLogger() {
// dh::DebugSyncDevice(__FILE__, __LINE__);
inline void DebugSyncDevice(std::string file="", int32_t line = -1) {
if (file != "" && line != -1) {
auto rank = rabit::GetRank();
auto rank = xgboost::collective::GetRank();
LOG(DEBUG) << "R:" << rank << ": " << file << ":" << line;
}
safe_cuda(cudaDeviceSynchronize());
@@ -423,7 +423,7 @@ using XGBBaseDeviceAllocator = thrust::device_malloc_allocator<T>;
inline void ThrowOOMError(std::string const& err, size_t bytes) {
auto device = CurrentDevice();
auto rank = rabit::GetRank();
auto rank = xgboost::collective::GetRank();
std::stringstream ss;
ss << "Memory allocation error on worker " << rank << ": " << err << "\n"
<< "- Free memory: " << AvailableMemory(device) << "\n"
@@ -737,512 +737,6 @@ using TypedDiscard =
std::conditional_t<HasThrustMinorVer<12>(), detail::TypedDiscardCTK114<T>,
detail::TypedDiscard<T>>;
/**
* \class AllReducer
*
* \brief All reducer class that manages its own communication group and
* streams. Must be initialised before use. If XGBoost is compiled without NCCL,
* this falls back to use Rabit.
*/
template <typename AllReducer>
class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
public:
virtual ~AllReducerBase() = default;
/**
* \brief Initialise with the desired device ordinal for this allreducer.
*
* \param device_ordinal The device ordinal.
*/
void Init(int _device_ordinal) {
device_ordinal_ = _device_ordinal;
dh::safe_cuda(cudaSetDevice(device_ordinal_));
if (rabit::GetWorldSize() == 1) {
return;
}
this->Underlying().DoInit(_device_ordinal);
initialised_ = true;
}
/**
* \brief Allgather implemented as grouped calls to Broadcast. This way we can accept
* different size of data on different workers.
*
* \param data Buffer storing the input data.
* \param length_bytes Size of input data in bytes.
* \param segments Size of data on each worker.
* \param recvbuf Buffer storing the result of data from all workers.
*/
void AllGather(void const *data, size_t length_bytes, std::vector<size_t> *segments,
dh::caching_device_vector<char> *recvbuf) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllGather(data, length_bytes, segments, recvbuf);
}
/**
* \brief Allgather. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param data Buffer storing the input data.
* \param length Size of input data in bytes.
* \param recvbuf Buffer storing the result of data from all workers.
*/
void AllGather(uint32_t const *data, size_t length,
dh::caching_device_vector<uint32_t> *recvbuf) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllGather(data, length, recvbuf);
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void AllReduceSum(const double *sendbuff, double *recvbuff, int count) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
allreduce_bytes_ += count * sizeof(double);
allreduce_calls_ += 1;
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void AllReduceSum(const float *sendbuff, float *recvbuff, int count) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
allreduce_bytes_ += count * sizeof(float);
allreduce_calls_ += 1;
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing streams or comms.
*
* \param count Number of.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of.
*/
void AllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
allreduce_bytes_ += count * sizeof(int64_t);
allreduce_calls_ += 1;
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void AllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
allreduce_bytes_ += count * sizeof(uint32_t);
allreduce_calls_ += 1;
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void AllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
allreduce_bytes_ += count * sizeof(uint64_t);
allreduce_calls_ += 1;
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* Specialization for size_t, which is implementation defined so it might or might not
* be one of uint64_t/uint32_t/unsigned long long/unsigned long.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
template <typename T = size_t,
std::enable_if_t<std::is_same<size_t, T>::value &&
!std::is_same<size_t, unsigned long long>::value> // NOLINT
* = nullptr>
void AllReduceSum(const T *sendbuff, T *recvbuff, int count) { // NOLINT
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
static_assert(sizeof(unsigned long long) == sizeof(uint64_t), ""); // NOLINT
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
allreduce_bytes_ += count * sizeof(T);
allreduce_calls_ += 1;
}
/**
* \fn void Synchronize()
*
* \brief Synchronizes the entire communication group.
*/
void Synchronize() {
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoSynchronize();
}
protected:
bool initialised_{false};
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.
private:
int device_ordinal_{-1};
};
#ifdef XGBOOST_USE_NCCL
class NcclAllReducer : public AllReducerBase<NcclAllReducer> {
public:
friend class AllReducerBase<NcclAllReducer>;
~NcclAllReducer() override;
private:
/**
* \brief Initialise with the desired device ordinal for this communication
* group.
*
* \param device_ordinal The device ordinal.
*/
void DoInit(int _device_ordinal);
/**
* \brief Allgather implemented as grouped calls to Broadcast. This way we can accept
* different size of data on different workers.
*
* \param data Buffer storing the input data.
* \param length_bytes Size of input data in bytes.
* \param segments Size of data on each worker.
* \param recvbuf Buffer storing the result of data from all workers.
*/
void DoAllGather(void const *data, size_t length_bytes, std::vector<size_t> *segments,
dh::caching_device_vector<char> *recvbuf);
/**
* \brief Allgather. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param data Buffer storing the input data.
* \param length Size of input data in bytes.
* \param recvbuf Buffer storing the result of data from all workers.
*/
void DoAllGather(uint32_t const *data, size_t length,
dh::caching_device_vector<uint32_t> *recvbuf) {
size_t world = rabit::GetWorldSize();
recvbuf->resize(length * world);
safe_nccl(ncclAllGather(data, recvbuf->data().get(), length, ncclUint32, comm_, stream_));
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void DoAllReduceSum(const double *sendbuff, double *recvbuff, int count) {
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum, comm_, stream_));
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void DoAllReduceSum(const float *sendbuff, float *recvbuff, int count) {
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclSum, comm_, stream_));
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing streams or comms.
*
* \param count Number of.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of.
*/
void DoAllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclInt64, ncclSum, comm_, stream_));
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void DoAllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint32, ncclSum, comm_, stream_));
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void DoAllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_));
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* Specialization for size_t, which is implementation defined so it might or might not
* be one of uint64_t/uint32_t/unsigned long long/unsigned long.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
template <typename T = size_t,
std::enable_if_t<std::is_same<size_t, T>::value &&
!std::is_same<size_t, unsigned long long>::value> // NOLINT
* = nullptr>
void DoAllReduceSum(const T *sendbuff, T *recvbuff, int count) { // NOLINT
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_));
}
/**
* \brief Synchronizes the entire communication group.
*/
void DoSynchronize() { dh::safe_cuda(cudaStreamSynchronize(stream_)); }
/**
* \fn ncclUniqueId GetUniqueId()
*
* \brief Gets the Unique ID from NCCL to be used in setting up interprocess
* communication
*
* \return the Unique ID
*/
ncclUniqueId GetUniqueId() {
static const int kRootRank = 0;
ncclUniqueId id;
if (rabit::GetRank() == kRootRank) {
dh::safe_nccl(ncclGetUniqueId(&id));
}
rabit::Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
return id;
}
ncclComm_t comm_;
cudaStream_t stream_;
ncclUniqueId id_;
};
using AllReducer = NcclAllReducer;
#else
class RabitAllReducer : public AllReducerBase<RabitAllReducer> {
public:
friend class AllReducerBase<RabitAllReducer>;
private:
/**
* \brief Initialise with the desired device ordinal for this allreducer.
*
* \param device_ordinal The device ordinal.
*/
static void DoInit(int _device_ordinal);
/**
* \brief Allgather implemented as grouped calls to Broadcast. This way we can accept
* different size of data on different workers.
*
* \param data Buffer storing the input data.
* \param length_bytes Size of input data in bytes.
* \param segments Size of data on each worker.
* \param recvbuf Buffer storing the result of data from all workers.
*/
void DoAllGather(void const *data, size_t length_bytes, std::vector<size_t> *segments,
dh::caching_device_vector<char> *recvbuf);
/**
* \brief Allgather. Use in exactly the same way as NCCL.
*
* \param data Buffer storing the input data.
* \param length Size of input data in bytes.
* \param recvbuf Buffer storing the result of data from all workers.
*/
void DoAllGather(uint32_t *data, size_t length, dh::caching_device_vector<uint32_t> *recvbuf) {
size_t world = rabit::GetWorldSize();
auto total_size = length * world;
recvbuf->resize(total_size);
sendrecvbuf_.reserve(total_size);
auto rank = rabit::GetRank();
safe_cuda(cudaMemcpy(sendrecvbuf_.data() + rank * length, data, length, cudaMemcpyDefault));
rabit::Allgather(sendrecvbuf_.data(), total_size, rank * length, length, length);
safe_cuda(cudaMemcpy(data, sendrecvbuf_.data(), total_size, cudaMemcpyDefault));
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void DoAllReduceSum(const double *sendbuff, double *recvbuff, int count) {
RabitAllReduceSum(sendbuff, recvbuff, count);
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void DoAllReduceSum(const float *sendbuff, float *recvbuff, int count) {
RabitAllReduceSum(sendbuff, recvbuff, count);
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void DoAllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
RabitAllReduceSum(sendbuff, recvbuff, count);
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void DoAllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
RabitAllReduceSum(sendbuff, recvbuff, count);
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void DoAllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
RabitAllReduceSum(sendbuff, recvbuff, count);
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL.
*
* Specialization for size_t, which is implementation defined so it might or might not
* be one of uint64_t/uint32_t/unsigned long long/unsigned long.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
template <typename T = size_t,
std::enable_if_t<std::is_same<size_t, T>::value &&
!std::is_same<size_t, unsigned long long>::value> // NOLINT
* = nullptr>
void DoAllReduceSum(const T *sendbuff, T *recvbuff, int count) { // NOLINT
RabitAllReduceSum(sendbuff, recvbuff, count);
}
/**
* \brief Synchronizes the allreducer.
*/
void DoSynchronize() {}
/**
* \brief Allreduce. Use in exactly the same way as NCCL.
*
* Copy the device buffer to host, call rabit allreduce, then copy the buffer back
* to device.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
template <typename T>
void RabitAllReduceSum(const T *sendbuff, T *recvbuff, int count) {
auto total_size = count * sizeof(T);
sendrecvbuf_.reserve(total_size);
safe_cuda(cudaMemcpy(sendrecvbuf_.data(), sendbuff, total_size, cudaMemcpyDefault));
rabit::Allreduce<rabit::op::Sum>(reinterpret_cast<T*>(sendrecvbuf_.data()), count);
safe_cuda(cudaMemcpy(recvbuff, sendrecvbuf_.data(), total_size, cudaMemcpyDefault));
}
/// Host buffer used to call rabit functions.
std::vector<char> sendrecvbuf_{};
};
using AllReducer = RabitAllReducer;
#endif
template <typename VectorT, typename T = typename VectorT::value_type,
typename IndexT = typename xgboost::common::Span<T>::index_type>
xgboost::common::Span<T> ToSpan(