Update collective implementation. (#10152)

* Update collective implementation.

- Cleanup resource during `Finalize` to avoid handling threads in destructor.
- Calculate the size for allgather automatically.
- Use simple allgather for small (smaller than the number of worker) allreduce.
This commit is contained in:
Jiaming Yuan
2024-03-30 18:57:31 +08:00
committed by GitHub
parent 230010d9a0
commit 8bad677c2f
31 changed files with 233 additions and 127 deletions

View File

@@ -34,7 +34,7 @@ class Worker : public WorkerForTest {
std::vector<std::int32_t> data(comm_.World(), 0);
data[comm_.Rank()] = comm_.Rank();
auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()}, 1);
auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()});
ASSERT_TRUE(rc.OK()) << rc.Report();
for (std::int32_t r = 0; r < comm_.World(); ++r) {
@@ -51,7 +51,7 @@ class Worker : public WorkerForTest {
auto seg = s_data.subspan(comm_.Rank() * n, n);
std::iota(seg.begin(), seg.end(), comm_.Rank());
auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()}, n);
auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()});
ASSERT_TRUE(rc.OK()) << rc.Report();
for (std::int32_t r = 0; r < comm_.World(); ++r) {
@@ -104,7 +104,7 @@ class Worker : public WorkerForTest {
std::vector<std::int64_t> sizes(comm_.World(), 0);
sizes[comm_.Rank()] = s_data.size_bytes();
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
ASSERT_TRUE(rc.OK()) << rc.Report();
std::shared_ptr<Coll> pcoll{new Coll{}};

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#if defined(XGBOOST_USE_NCCL)
#include <gtest/gtest.h>
@@ -33,7 +33,7 @@ class Worker : public NCCLWorkerForTest {
// get size
std::vector<std::int64_t> sizes(comm_.World(), -1);
sizes[comm_.Rank()] = s_data.size_bytes();
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
ASSERT_TRUE(rc.OK()) << rc.Report();
// create result
dh::device_vector<std::int32_t> result(comm_.World(), -1);
@@ -57,7 +57,7 @@ class Worker : public NCCLWorkerForTest {
// get size
std::vector<std::int64_t> sizes(nccl_comm_->World(), 0);
sizes[comm_.Rank()] = dh::ToSpan(data).size_bytes();
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
ASSERT_TRUE(rc.OK()) << rc.Report();
auto n_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0);
// create result

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
@@ -7,7 +7,6 @@
#include "../../../src/collective/allreduce.h"
#include "../../../src/collective/coll.h" // for Coll
#include "../../../src/collective/tracker.h"
#include "../../../src/common/type.h" // for EraseType
#include "test_worker.h" // for WorkerForTest, TestDistributed

View File

@@ -5,7 +5,7 @@
#include <gtest/gtest.h>
#include <thrust/host_vector.h> // for host_vector
#include "../../../src/common/common.h"
#include "../../../src/common/common.h" // for AllVisibleGPUs
#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector
#include "../../../src/common/type.h" // for EraseType
#include "test_worker.cuh" // for NCCLWorkerForTest