Serialize expand entry for allgather. (#9702)
This commit is contained in:
@@ -9,7 +9,8 @@
|
||||
#include <type_traits> // for remove_cv_t
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "comm.h" // for Comm, Channel, EraseType
|
||||
#include "../common/type.h" // for EraseType
|
||||
#include "comm.h" // for Comm, Channel
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
@@ -33,7 +34,7 @@ namespace cpu_impl {
|
||||
template <typename T>
|
||||
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
|
||||
auto n_bytes = sizeof(T) * size;
|
||||
auto erased = EraseType(data);
|
||||
auto erased = common::EraseType(data);
|
||||
|
||||
auto rank = comm.Rank();
|
||||
auto prev = BootstrapPrev(rank, comm.World());
|
||||
@@ -65,8 +66,8 @@ template <typename T>
|
||||
auto n_total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0);
|
||||
result.resize(n_total_bytes / sizeof(T));
|
||||
auto h_result = common::Span{result.data(), result.size()};
|
||||
auto erased_result = EraseType(h_result);
|
||||
auto erased_data = EraseType(data);
|
||||
auto erased_result = common::EraseType(h_result);
|
||||
auto erased_data = common::EraseType(data);
|
||||
std::vector<std::int64_t> offset(world + 1);
|
||||
|
||||
return cpu_impl::RingAllgatherV(comm, sizes, erased_data,
|
||||
|
||||
@@ -4,8 +4,9 @@
|
||||
#pragma once
|
||||
#include <cstdint> // for int8_t
|
||||
#include <functional> // for function
|
||||
#include <type_traits> // for is_invocable_v
|
||||
#include <type_traits> // for is_invocable_v, enable_if_t
|
||||
|
||||
#include "../common/type.h" // for EraseType, RestoreType
|
||||
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
||||
#include "comm.h" // for Comm, RestoreType
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
@@ -23,14 +24,14 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
|
||||
template <typename T, typename Fn>
|
||||
std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>, Result> Allreduce(
|
||||
Comm const& comm, common::Span<T> data, Fn redop) {
|
||||
auto erased = EraseType(data);
|
||||
auto erased = common::EraseType(data);
|
||||
auto type = ToDType<T>::kType;
|
||||
|
||||
auto erased_fn = [type, redop](common::Span<std::int8_t const> lhs,
|
||||
common::Span<std::int8_t> out) {
|
||||
CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
|
||||
auto lhs_t = RestoreType<T const>(lhs);
|
||||
auto rhs_t = RestoreType<T>(out);
|
||||
auto lhs_t = common::RestoreType<T const>(lhs);
|
||||
auto rhs_t = common::RestoreType<T>(out);
|
||||
redop(lhs_t, rhs_t);
|
||||
};
|
||||
|
||||
|
||||
@@ -137,20 +137,4 @@ class Channel {
|
||||
};
|
||||
|
||||
enum class Op { kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5 };
|
||||
|
||||
template <typename T, typename U = std::conditional_t<std::is_const_v<T>,
|
||||
std::add_const_t<std::int8_t>, std::int8_t>>
|
||||
common::Span<U> EraseType(common::Span<T> data) {
|
||||
auto n_total_bytes = data.size_bytes();
|
||||
auto erased = common::Span{reinterpret_cast<std::add_pointer_t<U>>(data.data()), n_total_bytes};
|
||||
return erased;
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
common::Span<T> RestoreType(common::Span<U> data) {
|
||||
static_assert(std::is_same_v<std::remove_const_t<U>, std::int8_t>);
|
||||
auto n_total_bytes = data.size_bytes();
|
||||
auto restored = common::Span{reinterpret_cast<T*>(data.data()), n_total_bytes / sizeof(T)};
|
||||
return restored;
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -327,6 +327,5 @@ inline SpecialAllgatherVResult<T> SpecialAllgatherV(std::vector<T> const &inputs
|
||||
|
||||
return {offsets, all_sizes, all_inputs};
|
||||
}
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user