[coll] Broadcast. (#9659)

This commit is contained in:
Jiaming Yuan
2023-10-14 09:34:37 +08:00
committed by GitHub
parent 81a059864a
commit 53049b16b8
7 changed files with 220 additions and 15 deletions

View File

@@ -5,22 +5,21 @@
#ifndef XGBOOST_COMMON_BITFIELD_H_
#define XGBOOST_COMMON_BITFIELD_H_
#include <algorithm>
#include <bitset>
#include <cinttypes>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include <algorithm> // for min
#include <bitset> // for bitset
#include <cstdint> // for uint32_t, uint64_t, uint8_t
#include <ostream> // for ostream
#include <type_traits> // for conditional_t, is_signed_v
#if defined(__CUDACC__)
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include "device_helpers.cuh"
#endif // defined(__CUDACC__)
#include "xgboost/span.h"
#include "common.h"
#include "xgboost/span.h" // for Span
namespace xgboost {
@@ -75,7 +74,7 @@ struct BitFieldContainer {
private:
value_type* bits_{nullptr};
size_type n_values_{0};
static_assert(!std::is_signed<VT>::value, "Must use an unsiged type as the underlying storage.");
static_assert(!std::is_signed_v<VT>, "Must use an unsiged type as the underlying storage.");
public:
XGBOOST_DEVICE static Pos ToBitPos(index_type pos) {
@@ -240,11 +239,39 @@ struct RBitsPolicy : public BitFieldContainer<VT, RBitsPolicy<VT>> {
// Format: <Const><Direction>BitField<size of underlying type in bits>, underlying type
// must be unsigned.
using LBitField64 = BitFieldContainer<uint64_t, LBitsPolicy<uint64_t>>;
using RBitField8 = BitFieldContainer<uint8_t, RBitsPolicy<unsigned char>>;
using LBitField64 = BitFieldContainer<std::uint64_t, LBitsPolicy<std::uint64_t>>;
using RBitField8 = BitFieldContainer<std::uint8_t, RBitsPolicy<unsigned char>>;
using LBitField32 = BitFieldContainer<uint32_t, LBitsPolicy<uint32_t>>;
using CLBitField32 = BitFieldContainer<uint32_t, LBitsPolicy<uint32_t, true>, true>;
using LBitField32 = BitFieldContainer<std::uint32_t, LBitsPolicy<std::uint32_t>>;
using CLBitField32 = BitFieldContainer<std::uint32_t, LBitsPolicy<std::uint32_t, true>, true>;
using RBitField32 = BitFieldContainer<std::uint32_t, RBitsPolicy<std::uint32_t>>;
namespace detail {
inline std::uint32_t TrailingZeroBitsImpl(std::uint32_t value) {
auto n = sizeof(value) * 8;
std::uint32_t cnt{0};
for (decltype(n) i = 0; i < n; i++) {
if ((value >> i) & 1) {
break;
}
cnt++;
}
return cnt;
}
} // namespace detail
inline std::uint32_t TrailingZeroBits(std::uint32_t value) {
if (value == 0) {
return sizeof(value) * 8;
}
#if defined(__GNUC__)
return __builtin_ctz(value);
#elif defined(_MSC_VER)
return _tzcnt_u32(value);
#else
return detail::TrailingZeroBitsImpl(value);
#endif // __GNUC__
}
} // namespace xgboost
#endif // XGBOOST_COMMON_BITFIELD_H_

View File

@@ -6,7 +6,6 @@
#ifndef XGBOOST_COMMON_COMMON_H_
#define XGBOOST_COMMON_COMMON_H_
#include <algorithm> // for max
#include <array> // for array
#include <cmath> // for ceil
#include <cstddef> // for size_t
@@ -181,7 +180,7 @@ inline void SetDevice(std::int32_t device) {
#endif
/**
* Last index of a group in a CSR style of index pointer.
* @brief Last index of a group in a CSR style of index pointer.
*/
template <typename Indexable>
XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) {