[Breaking] Switch from rabit to the collective communicator (#8257)
* Switch from rabit to the collective communicator * fix size_t specialization * really fix size_t * try again * add include * more include * fix lint errors * remove rabit includes * fix pylint error * return dict from communicator context * fix communicator shutdown * fix dask test * reset communicator mocklist * fix distributed tests * do not save device communicator * fix jvm gpu tests * add python test for federated communicator * Update gputreeshap submodule Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -46,8 +46,8 @@ template <bool use_column>
|
||||
void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
std::string msg {"Skipping AllReduce test"};
|
||||
int32_t constexpr kWorkers = 4;
|
||||
InitRabitContext(msg, kWorkers);
|
||||
auto world = rabit::GetWorldSize();
|
||||
InitCommunicatorContext(msg, kWorkers);
|
||||
auto world = collective::GetWorldSize();
|
||||
if (world != 1) {
|
||||
ASSERT_EQ(world, kWorkers);
|
||||
} else {
|
||||
@@ -65,7 +65,7 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
|
||||
// Generate cuts for distributed environment.
|
||||
auto sparsity = 0.5f;
|
||||
auto rank = rabit::GetRank();
|
||||
auto rank = collective::GetRank();
|
||||
std::vector<FeatureType> ft(cols);
|
||||
for (size_t i = 0; i < ft.size(); ++i) {
|
||||
ft[i] = (i % 2 == 0) ? FeatureType::kNumerical : FeatureType::kCategorical;
|
||||
@@ -99,8 +99,8 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
sketch_distributed.MakeCuts(&distributed_cuts);
|
||||
|
||||
// Generate cuts for single node environment
|
||||
rabit::Finalize();
|
||||
CHECK_EQ(rabit::GetWorldSize(), 1);
|
||||
collective::Finalize();
|
||||
CHECK_EQ(collective::GetWorldSize(), 1);
|
||||
std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
|
||||
m->Info().num_row_ = world * rows;
|
||||
ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
|
||||
@@ -184,8 +184,8 @@ TEST(Quantile, SameOnAllWorkers) {
|
||||
#if defined(__unix__)
|
||||
std::string msg{"Skipping Quantile AllreduceBasic test"};
|
||||
int32_t constexpr kWorkers = 4;
|
||||
InitRabitContext(msg, kWorkers);
|
||||
auto world = rabit::GetWorldSize();
|
||||
InitCommunicatorContext(msg, kWorkers);
|
||||
auto world = collective::GetWorldSize();
|
||||
if (world != 1) {
|
||||
CHECK_EQ(world, kWorkers);
|
||||
} else {
|
||||
@@ -196,7 +196,7 @@ TEST(Quantile, SameOnAllWorkers) {
|
||||
constexpr size_t kRows = 1000, kCols = 100;
|
||||
RunWithSeedsAndBins(
|
||||
kRows, [=](int32_t seed, size_t n_bins, MetaInfo const&) {
|
||||
auto rank = rabit::GetRank();
|
||||
auto rank = collective::GetRank();
|
||||
HostDeviceVector<float> storage;
|
||||
std::vector<FeatureType> ft(kCols);
|
||||
for (size_t i = 0; i < ft.size(); ++i) {
|
||||
@@ -217,12 +217,12 @@ TEST(Quantile, SameOnAllWorkers) {
|
||||
std::vector<float> cut_min_values(cuts.MinValues().size() * world, 0);
|
||||
|
||||
size_t value_size = cuts.Values().size();
|
||||
rabit::Allreduce<rabit::op::Max>(&value_size, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&value_size, 1);
|
||||
size_t ptr_size = cuts.Ptrs().size();
|
||||
rabit::Allreduce<rabit::op::Max>(&ptr_size, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&ptr_size, 1);
|
||||
CHECK_EQ(ptr_size, kCols + 1);
|
||||
size_t min_value_size = cuts.MinValues().size();
|
||||
rabit::Allreduce<rabit::op::Max>(&min_value_size, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&min_value_size, 1);
|
||||
CHECK_EQ(min_value_size, kCols);
|
||||
|
||||
size_t value_offset = value_size * rank;
|
||||
@@ -235,9 +235,9 @@ TEST(Quantile, SameOnAllWorkers) {
|
||||
std::copy(cuts.MinValues().cbegin(), cuts.MinValues().cend(),
|
||||
cut_min_values.begin() + min_values_offset);
|
||||
|
||||
rabit::Allreduce<rabit::op::Sum>(cut_values.data(), cut_values.size());
|
||||
rabit::Allreduce<rabit::op::Sum>(cut_ptrs.data(), cut_ptrs.size());
|
||||
rabit::Allreduce<rabit::op::Sum>(cut_min_values.data(), cut_min_values.size());
|
||||
collective::Allreduce<collective::Operation::kSum>(cut_values.data(), cut_values.size());
|
||||
collective::Allreduce<collective::Operation::kSum>(cut_ptrs.data(), cut_ptrs.size());
|
||||
collective::Allreduce<collective::Operation::kSum>(cut_min_values.data(), cut_min_values.size());
|
||||
|
||||
for (int32_t i = 0; i < world; i++) {
|
||||
for (size_t j = 0; j < value_size; ++j) {
|
||||
@@ -256,7 +256,7 @@ TEST(Quantile, SameOnAllWorkers) {
|
||||
}
|
||||
}
|
||||
});
|
||||
rabit::Finalize();
|
||||
collective::Finalize();
|
||||
#endif // defined(__unix__)
|
||||
}
|
||||
} // namespace common
|
||||
|
||||
Reference in New Issue
Block a user