[coll] Broadcast. (#9659)
This commit is contained in:
@@ -5,22 +5,21 @@
|
||||
#ifndef XGBOOST_COMMON_BITFIELD_H_
|
||||
#define XGBOOST_COMMON_BITFIELD_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <bitset>
|
||||
#include <cinttypes>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm> // for min
|
||||
#include <bitset> // for bitset
|
||||
#include <cstdint> // for uint32_t, uint64_t, uint8_t
|
||||
#include <ostream> // for ostream
|
||||
#include <type_traits> // for conditional_t, is_signed_v
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
#include <thrust/copy.h>
|
||||
#include <thrust/device_ptr.h>
|
||||
|
||||
#include "device_helpers.cuh"
|
||||
#endif // defined(__CUDACC__)
|
||||
|
||||
#include "xgboost/span.h"
|
||||
#include "common.h"
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@@ -75,7 +74,7 @@ struct BitFieldContainer {
|
||||
private:
|
||||
value_type* bits_{nullptr};
|
||||
size_type n_values_{0};
|
||||
static_assert(!std::is_signed<VT>::value, "Must use an unsiged type as the underlying storage.");
|
||||
static_assert(!std::is_signed_v<VT>, "Must use an unsiged type as the underlying storage.");
|
||||
|
||||
public:
|
||||
XGBOOST_DEVICE static Pos ToBitPos(index_type pos) {
|
||||
@@ -240,11 +239,39 @@ struct RBitsPolicy : public BitFieldContainer<VT, RBitsPolicy<VT>> {
|
||||
|
||||
// Format: <Const><Direction>BitField<size of underlying type in bits>, underlying type
|
||||
// must be unsigned.
|
||||
using LBitField64 = BitFieldContainer<uint64_t, LBitsPolicy<uint64_t>>;
|
||||
using RBitField8 = BitFieldContainer<uint8_t, RBitsPolicy<unsigned char>>;
|
||||
using LBitField64 = BitFieldContainer<std::uint64_t, LBitsPolicy<std::uint64_t>>;
|
||||
using RBitField8 = BitFieldContainer<std::uint8_t, RBitsPolicy<unsigned char>>;
|
||||
|
||||
using LBitField32 = BitFieldContainer<uint32_t, LBitsPolicy<uint32_t>>;
|
||||
using CLBitField32 = BitFieldContainer<uint32_t, LBitsPolicy<uint32_t, true>, true>;
|
||||
using LBitField32 = BitFieldContainer<std::uint32_t, LBitsPolicy<std::uint32_t>>;
|
||||
using CLBitField32 = BitFieldContainer<std::uint32_t, LBitsPolicy<std::uint32_t, true>, true>;
|
||||
using RBitField32 = BitFieldContainer<std::uint32_t, RBitsPolicy<std::uint32_t>>;
|
||||
|
||||
namespace detail {
|
||||
inline std::uint32_t TrailingZeroBitsImpl(std::uint32_t value) {
|
||||
auto n = sizeof(value) * 8;
|
||||
std::uint32_t cnt{0};
|
||||
for (decltype(n) i = 0; i < n; i++) {
|
||||
if ((value >> i) & 1) {
|
||||
break;
|
||||
}
|
||||
cnt++;
|
||||
}
|
||||
return cnt;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
inline std::uint32_t TrailingZeroBits(std::uint32_t value) {
|
||||
if (value == 0) {
|
||||
return sizeof(value) * 8;
|
||||
}
|
||||
#if defined(__GNUC__)
|
||||
return __builtin_ctz(value);
|
||||
#elif defined(_MSC_VER)
|
||||
return _tzcnt_u32(value);
|
||||
#else
|
||||
return detail::TrailingZeroBitsImpl(value);
|
||||
#endif // __GNUC__
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_COMMON_BITFIELD_H_
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
#ifndef XGBOOST_COMMON_COMMON_H_
|
||||
#define XGBOOST_COMMON_COMMON_H_
|
||||
|
||||
#include <algorithm> // for max
|
||||
#include <array> // for array
|
||||
#include <cmath> // for ceil
|
||||
#include <cstddef> // for size_t
|
||||
@@ -181,7 +180,7 @@ inline void SetDevice(std::int32_t device) {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Last index of a group in a CSR style of index pointer.
|
||||
* @brief Last index of a group in a CSR style of index pointer.
|
||||
*/
|
||||
template <typename Indexable>
|
||||
XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) {
|
||||
|
||||
Reference in New Issue
Block a user