Fix index type for bitfield. (#7541)

This commit is contained in:
Jiaming Yuan 2022-01-05 19:23:29 +08:00 committed by GitHub
parent 0df2ae63c7
commit 91c1a1c52f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 14 deletions

View File

@ -58,14 +58,15 @@ __forceinline__ __device__ BitFieldAtomicType AtomicAnd(BitFieldAtomicType* addr
template <typename VT, typename Direction, bool IsConst = false> template <typename VT, typename Direction, bool IsConst = false>
struct BitFieldContainer { struct BitFieldContainer {
using value_type = std::conditional_t<IsConst, VT const, VT>; // NOLINT using value_type = std::conditional_t<IsConst, VT const, VT>; // NOLINT
using index_type = size_t; // NOLINT
using pointer = value_type*; // NOLINT using pointer = value_type*; // NOLINT
static value_type constexpr kValueSize = sizeof(value_type) * 8; static index_type constexpr kValueSize = sizeof(value_type) * 8;
static value_type constexpr kOne = 1; // force correct type. static index_type constexpr kOne = 1; // force correct type.
struct Pos { struct Pos {
std::remove_const_t<value_type> int_pos {0}; index_type int_pos{0};
std::remove_const_t<value_type> bit_pos {0}; index_type bit_pos{0};
}; };
private: private:
@ -73,7 +74,7 @@ struct BitFieldContainer {
static_assert(!std::is_signed<VT>::value, "Must use unsiged type as underlying storage."); static_assert(!std::is_signed<VT>::value, "Must use unsiged type as underlying storage.");
public: public:
XGBOOST_DEVICE static Pos ToBitPos(value_type pos) { XGBOOST_DEVICE static Pos ToBitPos(index_type pos) {
Pos pos_v; Pos pos_v;
if (pos == 0) { if (pos == 0) {
return pos_v; return pos_v;
@ -96,7 +97,7 @@ struct BitFieldContainer {
/*\brief Compute the size of needed memory allocation. The returned value is in terms /*\brief Compute the size of needed memory allocation. The returned value is in terms
* of number of elements with `BitFieldContainer::value_type'. * of number of elements with `BitFieldContainer::value_type'.
*/ */
XGBOOST_DEVICE static size_t ComputeStorageSize(size_t size) { XGBOOST_DEVICE static size_t ComputeStorageSize(index_type size) {
return common::DivRoundUp(size, kValueSize); return common::DivRoundUp(size, kValueSize);
} }
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__)
@ -138,14 +139,14 @@ struct BitFieldContainer {
#endif // defined(__CUDA_ARCH__) #endif // defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__)
__device__ auto Set(value_type pos) { __device__ auto Set(index_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos)); Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos]; value_type& value = bits_[pos_v.int_pos];
value_type set_bit = kOne << pos_v.bit_pos; value_type set_bit = kOne << pos_v.bit_pos;
using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type; using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type;
atomicOr(reinterpret_cast<Type *>(&value), set_bit); atomicOr(reinterpret_cast<Type *>(&value), set_bit);
} }
__device__ void Clear(value_type pos) { __device__ void Clear(index_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos)); Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos]; value_type& value = bits_[pos_v.int_pos];
value_type clear_bit = ~(kOne << pos_v.bit_pos); value_type clear_bit = ~(kOne << pos_v.bit_pos);
@ -153,13 +154,13 @@ struct BitFieldContainer {
atomicAnd(reinterpret_cast<Type *>(&value), clear_bit); atomicAnd(reinterpret_cast<Type *>(&value), clear_bit);
} }
#else #else
void Set(value_type pos) { void Set(index_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos)); Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos]; value_type& value = bits_[pos_v.int_pos];
value_type set_bit = kOne << pos_v.bit_pos; value_type set_bit = kOne << pos_v.bit_pos;
value |= set_bit; value |= set_bit;
} }
void Clear(value_type pos) { void Clear(index_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos)); Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos]; value_type& value = bits_[pos_v.int_pos];
value_type clear_bit = ~(kOne << pos_v.bit_pos); value_type clear_bit = ~(kOne << pos_v.bit_pos);
@ -175,7 +176,7 @@ struct BitFieldContainer {
value_type result = test_bit & value; value_type result = test_bit & value;
return static_cast<bool>(result); return static_cast<bool>(result);
} }
XGBOOST_DEVICE bool Check(value_type pos) const { XGBOOST_DEVICE bool Check(index_type pos) const {
Pos pos_v = ToBitPos(pos); Pos pos_v = ToBitPos(pos);
return Check(pos_v); return Check(pos_v);
} }

View File

@ -38,6 +38,14 @@ TEST(BitField, Check) {
ASSERT_FALSE(bits.Check(i)); ASSERT_FALSE(bits.Check(i));
} }
} }
{
// regression test for correct index type.
std::vector<RBitField8::value_type> storage(33, 0);
storage[32] = static_cast<uint8_t>(1);
auto bits = RBitField8({storage.data(), storage.size()});
ASSERT_TRUE(bits.Check(256));
}
} }
template <typename BitFieldT, typename VT = typename BitFieldT::value_type> template <typename BitFieldT, typename VT = typename BitFieldT::value_type>