[Breaking] Switch from rabit to the collective communicator (#8257)

* Switch from rabit to the collective communicator

* fix size_t specialization

* really fix size_t

* try again

* add include

* more include

* fix lint errors

* remove rabit includes

* fix pylint error

* return dict from communicator context

* fix communicator shutdown

* fix dask test

* reset communicator mocklist

* fix distributed tests

* do not save device communicator

* fix jvm gpu tests

* add python test for federated communicator

* Update gputreeshap submodule

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Rong Ou
2022-10-05 15:39:01 -07:00
committed by GitHub
parent e47b3a3da3
commit 668b8a0ea4
79 changed files with 805 additions and 2212 deletions

View File

@@ -11,11 +11,10 @@
#include <utility>
#include <tuple>
#include "rabit/rabit.h"
#include "xgboost/span.h"
#include "xgboost/data.h"
#include "auc.h"
#include "../common/device_helpers.cuh"
#include "../collective/device_communicator.cuh"
#include "../common/ranking_utils.cuh"
namespace xgboost {
@@ -46,9 +45,8 @@ struct DeviceAUCCache {
dh::device_vector<size_t> unique_idx;
// p^T: transposed prediction matrix, used by MultiClassAUC
dh::device_vector<float> predts_t;
std::unique_ptr<dh::AllReducer> reducer;
void Init(common::Span<float const> predts, bool is_multi, int32_t device) {
void Init(common::Span<float const> predts, bool is_multi) {
if (sorted_idx.size() != predts.size()) {
sorted_idx.resize(predts.size());
fptp.resize(sorted_idx.size());
@@ -58,10 +56,6 @@ struct DeviceAUCCache {
predts_t.resize(sorted_idx.size());
}
}
if (is_multi && !reducer) {
reducer.reset(new dh::AllReducer);
reducer->Init(device);
}
}
};
@@ -72,7 +66,7 @@ void InitCacheOnce(common::Span<float const> predts, int32_t device,
if (!cache) {
cache.reset(new DeviceAUCCache);
}
cache->Init(predts, is_multi, device);
cache->Init(predts, is_multi);
}
/**
@@ -205,9 +199,11 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
common::Span<double> tp, common::Span<double> auc,
std::shared_ptr<DeviceAUCCache> cache, size_t n_classes) {
dh::XGBDeviceAllocator<char> alloc;
if (rabit::IsDistributed()) {
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), dh::CurrentDevice());
cache->reducer->AllReduceSum(results.data(), results.data(), results.size());
if (collective::IsDistributed()) {
int32_t device = dh::CurrentDevice();
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
auto* communicator = collective::Communicator::GetDevice(device);
communicator->AllReduceSum(results.data(), results.size());
}
auto reduce_in = dh::MakeTransformIterator<Pair>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {