Serialize expand entry for allgather. (#9702)

This commit is contained in:
Jiaming Yuan
2023-10-24 14:33:28 +08:00
committed by GitHub
parent ee8b29c843
commit 7a02facc9d
14 changed files with 336 additions and 76 deletions

View File

@@ -9,7 +9,7 @@
#include <bitset> // for bitset
#include <cstdint> // for uint32_t, uint64_t, uint8_t
#include <ostream> // for ostream
#include <type_traits> // for conditional_t, is_signed_v
#include <type_traits> // for conditional_t, is_signed_v, add_const_t
#if defined(__CUDACC__)
#include <thrust/copy.h>

24
src/common/type.h Normal file
View File

@@ -0,0 +1,24 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int8_t
#include <type_traits> // for is_const_v, add_const_t, conditional_t, add_pointer_t
#include "xgboost/span.h" // for Span
namespace xgboost::common {
template <typename T, typename U = std::conditional_t<std::is_const_v<T>,
std::add_const_t<std::int8_t>, std::int8_t>>
common::Span<U> EraseType(common::Span<T> data) {
auto n_total_bytes = data.size_bytes();
auto erased = common::Span{reinterpret_cast<std::add_pointer_t<U>>(data.data()), n_total_bytes};
return erased;
}
template <typename T, typename U>
common::Span<T> RestoreType(common::Span<U> data) {
auto n_total_bytes = data.size_bytes();
auto restored = common::Span{reinterpret_cast<T*>(data.data()), n_total_bytes / sizeof(T)};
return restored;
}
} // namespace xgboost::common