Update collective implementation. (#10152)
* Update collective implementation.
  - Clean up resources during `Finalize` to avoid handling threads in the destructor.
  - Calculate the size for allgather automatically.
  - Use a simple allgather for small allreduce (smaller than the number of workers).
This commit is contained in:
@@ -34,7 +34,7 @@ class Worker : public WorkerForTest {
|
||||
std::vector<std::int32_t> data(comm_.World(), 0);
|
||||
data[comm_.Rank()] = comm_.Rank();
|
||||
|
||||
auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()}, 1);
|
||||
auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()});
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
|
||||
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
||||
@@ -51,7 +51,7 @@ class Worker : public WorkerForTest {
|
||||
auto seg = s_data.subspan(comm_.Rank() * n, n);
|
||||
std::iota(seg.begin(), seg.end(), comm_.Rank());
|
||||
|
||||
auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()}, n);
|
||||
auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()});
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
|
||||
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
||||
@@ -104,7 +104,7 @@ class Worker : public WorkerForTest {
|
||||
|
||||
std::vector<std::int64_t> sizes(comm_.World(), 0);
|
||||
sizes[comm_.Rank()] = s_data.size_bytes();
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
std::shared_ptr<Coll> pcoll{new Coll{}};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
#include <gtest/gtest.h>
|
||||
@@ -33,7 +33,7 @@ class Worker : public NCCLWorkerForTest {
|
||||
// get size
|
||||
std::vector<std::int64_t> sizes(comm_.World(), -1);
|
||||
sizes[comm_.Rank()] = s_data.size_bytes();
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
// create result
|
||||
dh::device_vector<std::int32_t> result(comm_.World(), -1);
|
||||
@@ -57,7 +57,7 @@ class Worker : public NCCLWorkerForTest {
|
||||
// get size
|
||||
std::vector<std::int64_t> sizes(nccl_comm_->World(), 0);
|
||||
sizes[comm_.Rank()] = dh::ToSpan(data).size_bytes();
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
auto n_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0);
|
||||
// create result
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
|
||||
#include "../../../src/collective/allreduce.h"
|
||||
#include "../../../src/collective/coll.h" // for Coll
|
||||
#include "../../../src/collective/tracker.h"
|
||||
#include "../../../src/common/type.h" // for EraseType
|
||||
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <thrust/host_vector.h> // for host_vector
|
||||
|
||||
#include "../../../src/common/common.h"
|
||||
#include "../../../src/common/common.h" // for AllVisibleGPUs
|
||||
#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector
|
||||
#include "../../../src/common/type.h" // for EraseType
|
||||
#include "test_worker.cuh" // for NCCLWorkerForTest
|
||||
|
||||
Reference in New Issue
Block a user