[Breaking] Switch from rabit to the collective communicator (#8257)
* Switch from rabit to the collective communicator * fix size_t specialization * really fix size_t * try again * add include * more include * fix lint errors * remove rabit includes * fix pylint error * return dict from communicator context * fix communicator shutdown * fix dask test * reset communicator mocklist * fix distributed tests * do not save device communicator * fix jvm gpu tests * add python test for federated communicator * Update gputreeshap submodule Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -11,11 +11,10 @@
|
||||
#include <utility>
|
||||
#include <tuple>
|
||||
|
||||
#include "rabit/rabit.h"
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "auc.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../collective/device_communicator.cuh"
|
||||
#include "../common/ranking_utils.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -46,9 +45,8 @@ struct DeviceAUCCache {
|
||||
dh::device_vector<size_t> unique_idx;
|
||||
// p^T: transposed prediction matrix, used by MultiClassAUC
|
||||
dh::device_vector<float> predts_t;
|
||||
std::unique_ptr<dh::AllReducer> reducer;
|
||||
|
||||
void Init(common::Span<float const> predts, bool is_multi, int32_t device) {
|
||||
void Init(common::Span<float const> predts, bool is_multi) {
|
||||
if (sorted_idx.size() != predts.size()) {
|
||||
sorted_idx.resize(predts.size());
|
||||
fptp.resize(sorted_idx.size());
|
||||
@@ -58,10 +56,6 @@ struct DeviceAUCCache {
|
||||
predts_t.resize(sorted_idx.size());
|
||||
}
|
||||
}
|
||||
if (is_multi && !reducer) {
|
||||
reducer.reset(new dh::AllReducer);
|
||||
reducer->Init(device);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -72,7 +66,7 @@ void InitCacheOnce(common::Span<float const> predts, int32_t device,
|
||||
if (!cache) {
|
||||
cache.reset(new DeviceAUCCache);
|
||||
}
|
||||
cache->Init(predts, is_multi, device);
|
||||
cache->Init(predts, is_multi);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -205,9 +199,11 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
|
||||
common::Span<double> tp, common::Span<double> auc,
|
||||
std::shared_ptr<DeviceAUCCache> cache, size_t n_classes) {
|
||||
dh::XGBDeviceAllocator<char> alloc;
|
||||
if (rabit::IsDistributed()) {
|
||||
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), dh::CurrentDevice());
|
||||
cache->reducer->AllReduceSum(results.data(), results.data(), results.size());
|
||||
if (collective::IsDistributed()) {
|
||||
int32_t device = dh::CurrentDevice();
|
||||
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
|
||||
auto* communicator = collective::Communicator::GetDevice(device);
|
||||
communicator->AllReduceSum(results.data(), results.size());
|
||||
}
|
||||
auto reduce_in = dh::MakeTransformIterator<Pair>(
|
||||
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
|
||||
|
||||
Reference in New Issue
Block a user