[coll] allgatherv. (#9688)

Jiaming Yuan 2023-10-19 03:13:50 +08:00 committed by GitHub
parent ea9f09716b
commit 5d1bcde719
8 changed files with 157 additions and 35 deletions

View File

@@ -3,13 +3,16 @@
  */
 #include "allgather.h"

-#include <algorithm>  // for min
+#include <algorithm>  // for min, copy_n
 #include <cstddef>    // for size_t
-#include <cstdint>    // for int8_t
+#include <cstdint>    // for int8_t, int32_t, int64_t
 #include <memory>     // for shared_ptr
+#include <numeric>    // for partial_sum
+#include <vector>     // for vector

 #include "comm.h"                       // for Comm, Channel
+#include "xgboost/collective/result.h"  // for Result
 #include "xgboost/span.h"               // for Span

 namespace xgboost::collective::cpu_impl {
 Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size_t segment_size,
@@ -39,4 +42,47 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
   return Success();
 }
+
+[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
+                                    common::Span<std::int8_t const> data,
+                                    common::Span<std::int8_t> erased_result) {
+  auto world = comm.World();
+  auto rank = comm.Rank();
+  auto prev = BootstrapPrev(rank, comm.World());
+  auto next = BootstrapNext(rank, comm.World());
+  auto prev_ch = comm.Chan(prev);
+  auto next_ch = comm.Chan(next);
+
+  // get worker offset
+  std::vector<std::int64_t> offset(world + 1, 0);
+  std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
+  CHECK_EQ(*offset.cbegin(), 0);
+
+  // copy data
+  auto current = erased_result.subspan(offset[rank], data.size_bytes());
+  auto erased_data = EraseType(data);
+  std::copy_n(erased_data.data(), erased_data.size(), current.data());
+
+  for (std::int32_t r = 0; r < world; ++r) {
+    auto send_rank = (rank + world - r) % world;
+    auto send_off = offset[send_rank];
+    auto send_size = sizes[send_rank];
+    auto send_seg = erased_result.subspan(send_off, send_size);
+    next_ch->SendAll(send_seg);
+
+    auto recv_rank = (rank + world - r - 1) % world;
+    auto recv_off = offset[recv_rank];
+    auto recv_size = sizes[recv_rank];
+    auto recv_seg = erased_result.subspan(recv_off, recv_size);
+    prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
+
+    auto rc = prev_ch->Block();
+    if (!rc.OK()) {
+      return rc;
+    }
+  }
+  return comm.Block();
+}
 }  // namespace xgboost::collective::cpu_impl
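A note on the ring schedule implemented above: in round r each rank forwards segment (rank + world - r) % world to its next neighbour and receives segment (rank + world - r - 1) % world from its previous one, so after the world rounds every worker holds the full concatenation of all variable-length segments. Below is a minimal single-process sketch of that schedule; there is no Comm, Channel, or networking, plain vectors stand in for the byte spans, and every name is illustrative rather than taken from xgboost.

#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::int32_t world = 3;
  // Variable-length contributions: rank r owns r + 1 values in this example.
  std::vector<std::int64_t> sizes{1, 2, 3};
  std::vector<std::int64_t> offset(world + 1, 0);
  std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);

  // buf[rank] plays the role of erased_result on that rank; every rank starts
  // with only its own segment filled in and -1 everywhere else.
  std::vector<std::vector<int>> buf(world, std::vector<int>(offset[world], -1));
  for (std::int32_t rank = 0; rank < world; ++rank) {
    for (std::int64_t i = 0; i < sizes[rank]; ++i) {
      buf[rank][offset[rank] + i] = rank;
    }
  }

  // One iteration per round: rank forwards segment (rank + world - r) % world
  // to its next neighbour, the same send_rank arithmetic used above.
  for (std::int32_t r = 0; r < world; ++r) {
    for (std::int32_t rank = 0; rank < world; ++rank) {
      auto send_rank = (rank + world - r) % world;
      auto next = (rank + 1) % world;
      for (std::int64_t i = 0; i < sizes[send_rank]; ++i) {
        buf[next][offset[send_rank] + i] = buf[rank][offset[send_rank] + i];
      }
    }
  }

  // Every rank now holds the full concatenation 0 1 1 2 2 2.
  for (std::int32_t rank = 0; rank < world; ++rank) {
    std::printf("rank %d:", rank);
    for (auto v : buf[rank]) {
      std::printf(" %d", v);
    }
    std::printf("\n");
  }
  return 0;
}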

View File

@@ -2,12 +2,16 @@
  * Copyright 2023, XGBoost Contributors
  */
 #pragma once
 #include <cstddef>      // for size_t
 #include <cstdint>      // for int32_t
 #include <memory>       // for shared_ptr
+#include <numeric>      // for accumulate
+#include <type_traits>  // for remove_cv_t
+#include <vector>       // for vector

-#include "comm.h"          // for Comm, Channel
+#include "comm.h"                       // for Comm, Channel, EraseType
+#include "xgboost/collective/result.h"  // for Result
 #include "xgboost/span.h"               // for Span

 namespace xgboost::collective {
 namespace cpu_impl {
@@ -19,14 +23,16 @@ namespace cpu_impl {
                                    std::size_t segment_size, std::int32_t worker_off,
                                    std::shared_ptr<Channel> prev_ch,
                                    std::shared_ptr<Channel> next_ch);
+
+[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
+                                    common::Span<std::int8_t const> data,
+                                    common::Span<std::int8_t> erased_result);
 }  // namespace cpu_impl

 template <typename T>
 [[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
-  auto n_total_bytes = data.size_bytes();
   auto n_bytes = sizeof(T) * size;
-  auto erased =
-      common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(data.data()), n_total_bytes};
+  auto erased = EraseType(data);

   auto rank = comm.Rank();
   auto prev = BootstrapPrev(rank, comm.World());
@@ -40,4 +46,27 @@ template <typename T>
   }
   return comm.Block();
 }
+
+template <typename T>
+[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<T> data,
+                                    std::vector<std::remove_cv_t<T>>* p_out) {
+  auto world = comm.World();
+  auto rank = comm.Rank();
+
+  std::vector<std::int64_t> sizes(world, 0);
+  sizes[rank] = data.size_bytes();
+  auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()}, 1);
+  if (!rc.OK()) {
+    return rc;
+  }
+
+  std::vector<T>& result = *p_out;
+  auto n_total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0);
+  result.resize(n_total_bytes / sizeof(T));
+  auto h_result = common::Span{result.data(), result.size()};
+  auto erased_result = EraseType(h_result);
+  auto erased_data = EraseType(data);
+  return cpu_impl::RingAllgatherV(comm, sizes, erased_data, erased_result);
+}
 }  // namespace xgboost::collective
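The templated RingAllgatherV above works in two phases: a fixed-size RingAllgather first shares every rank's byte count (one std::int64_t per rank), then the output vector is resized to the combined length and the byte-level cpu_impl::RingAllgatherV performs the variable-length exchange. EraseType itself is not part of this diff; judging from the reinterpret_cast it replaces, it presumably views a typed span as bytes. A small standalone sketch of that assumption and of the size bookkeeping, using C++20 std::span as a stand-in for common::Span (all names here are illustrative, not the real xgboost API):

#include <cstdint>
#include <cstdio>
#include <numeric>
#include <span>
#include <vector>

// Plausible stand-in for EraseType: view a typed buffer as raw bytes,
// much like the reinterpret_cast the commit removes from RingAllgather.
template <typename T>
std::span<std::int8_t> EraseType(std::span<T> data) {
  return {reinterpret_cast<std::int8_t*>(data.data()), data.size_bytes()};
}

int main() {
  // Suppose world = 4 and rank r contributes (r + 1) int32 values.
  std::int32_t world = 4;
  std::vector<std::int64_t> sizes(world, 0);
  for (std::int32_t r = 0; r < world; ++r) {
    sizes[r] = (r + 1) * static_cast<std::int64_t>(sizeof(std::int32_t));
  }

  // The total byte count determines the element count of the gathered output,
  // mirroring the accumulate followed by resize(n_total_bytes / sizeof(T)) above.
  auto n_total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), std::int64_t{0});
  std::vector<std::int32_t> result(n_total_bytes / sizeof(std::int32_t));
  auto erased_result = EraseType(std::span{result.data(), result.size()});

  std::printf("total bytes: %lld, elements: %zu, erased bytes: %zu\n",
              static_cast<long long>(n_total_bytes), result.size(), erased_result.size());
  return 0;
}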

View File

@@ -3,6 +3,7 @@
  */
 #include "broadcast.h"

+#include <cmath>    // for ceil, log2
 #include <cstdint>  // for int32_t, int8_t
 #include <utility>  // for move

View File

@@ -11,8 +11,10 @@
 #include "allgather.h"
 #include "protocol.h"                   // for kMagic
+#include "xgboost/base.h"               // for XGBOOST_STRICT_R_MODE
 #include "xgboost/collective/socket.h"  // for TCPSocket
 #include "xgboost/json.h"               // for Json, Object
+#include "xgboost/string_view.h"        // for StringView

 namespace xgboost::collective {
 Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,

View File

@@ -2,20 +2,16 @@
  * Copyright 2023, XGBoost Contributors
  */
 #pragma once
 #include <chrono>              // for seconds
-#include <condition_variable>  // for condition_variable
 #include <cstddef>             // for size_t
 #include <cstdint>             // for int32_t
 #include <memory>              // for shared_ptr
-#include <mutex>               // for mutex
-#include <queue>               // for queue
 #include <string>              // for string
 #include <thread>              // for thread
 #include <type_traits>         // for remove_const_t
 #include <utility>             // for move
 #include <vector>              // for vector

-#include "../common/timer.h"
 #include "loop.h"                       // for Loop
 #include "protocol.h"                   // for PeerInfo
 #include "xgboost/collective/result.h"  // for Result

View File

@@ -175,6 +175,7 @@ template class HostDeviceVector<GradientPair>;
 template class HostDeviceVector<GradientPairPrecise>;
 template class HostDeviceVector<int32_t>;  // bst_node_t
 template class HostDeviceVector<uint8_t>;
+template class HostDeviceVector<int8_t>;
 template class HostDeviceVector<FeatureType>;
 template class HostDeviceVector<Entry>;
 template class HostDeviceVector<uint64_t>;  // bst_row_t

View File

@@ -409,6 +409,7 @@ template class HostDeviceVector<GradientPair>;
 template class HostDeviceVector<GradientPairPrecise>;
 template class HostDeviceVector<int32_t>;  // bst_node_t
 template class HostDeviceVector<uint8_t>;
+template class HostDeviceVector<int8_t>;
 template class HostDeviceVector<FeatureType>;
 template class HostDeviceVector<Entry>;
 template class HostDeviceVector<uint64_t>;  // bst_row_t

View File

@@ -1,18 +1,23 @@
 /**
  * Copyright 2023, XGBoost Contributors
  */
-#include <gtest/gtest.h>
-#include <xgboost/span.h>  // for Span
+#include <gtest/gtest.h>   // for ASSERT_EQ
+#include <xgboost/span.h>  // for Span, oper...

+#include <algorithm>  // for min
+#include <chrono>     // for seconds
+#include <cstddef>    // for size_t
 #include <cstdint>    // for int32_t
 #include <numeric>    // for iota
 #include <string>     // for string
 #include <thread>     // for thread
 #include <vector>     // for vector

-#include "../../../src/collective/allgather.h"
-#include "../../../src/collective/tracker.h"  // for GetHostAddress, Tracker
-#include "test_worker.h"                      // for TestDistributed
+#include "../../../src/collective/allgather.h"  // for RingAllgather
+#include "../../../src/collective/comm.h"       // for RabitComm
+#include "gtest/gtest.h"                        // for AssertionR...
+#include "test_worker.h"                        // for TestDistri...
+#include "xgboost/collective/result.h"          // for Result

 namespace xgboost::collective {
 namespace {
@@ -57,6 +62,38 @@ class Worker : public WorkerForTest {
       }
     }
   }
+
+  void TestV() {
+    {
+      // basic test
+      std::int32_t n{comm_.Rank()};
+      std::vector<std::int32_t> result;
+      auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result);
+      ASSERT_TRUE(rc.OK()) << rc.Report();
+      for (std::int32_t i = 0; i < comm_.World(); ++i) {
+        ASSERT_EQ(result[i], i);
+      }
+    }
+    {
+      // V test
+      std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
+      std::vector<std::int32_t> result;
+      auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result);
+      ASSERT_TRUE(rc.OK()) << rc.Report();
+      ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2);
+
+      std::int32_t k{0};
+      for (std::int32_t r = 0; r < comm_.World(); ++r) {
+        auto seg = common::Span{result.data(), result.size()}.subspan(k, (r + 1));
+        if (comm_.Rank() == 0) {
+          for (auto v : seg) {
+            ASSERT_EQ(v, r);
+          }
+          k += seg.size();
+        }
+      }
+    }
+  }
 };
 }  // namespace
@@ -68,4 +105,13 @@ TEST_F(AllgatherTest, Basic) {
     worker.Run();
   });
 }
+
+TEST_F(AllgatherTest, V) {
+  std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
+  TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
+                                 std::int32_t r) {
+    Worker worker{host, port, timeout, n_workers, r};
+    worker.TestV();
+  });
+}
 }  // namespace xgboost::collective
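For reference, the layout the "V test" above asserts: with world workers, rank r contributes r + 1 copies of the value r, so the gathered vector is the triangular sequence 0, 1, 1, 2, 2, 2, ... of length world * (world + 1) / 2. A standalone sketch that builds this expected result for world = 4 (illustrative only, not part of the test suite):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::int32_t world = 4;
  std::vector<std::int32_t> expected;
  for (std::int32_t r = 0; r < world; ++r) {
    // Rank r contributes r + 1 copies of r, matching the "V test" data.
    expected.insert(expected.end(), r + 1, r);
  }
  // Prints: size 10 -> 0 1 1 2 2 2 3 3 3 3, i.e. world * (world + 1) / 2 values.
  std::printf("size %zu ->", expected.size());
  for (auto v : expected) {
    std::printf(" %d", v);
  }
  std::printf("\n");
  return 0;
}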