Fix index type for bitfield. (#7541)
This commit is contained in:
parent
0df2ae63c7
commit
91c1a1c52f
@ -58,14 +58,15 @@ __forceinline__ __device__ BitFieldAtomicType AtomicAnd(BitFieldAtomicType* addr
|
|||||||
template <typename VT, typename Direction, bool IsConst = false>
|
template <typename VT, typename Direction, bool IsConst = false>
|
||||||
struct BitFieldContainer {
|
struct BitFieldContainer {
|
||||||
using value_type = std::conditional_t<IsConst, VT const, VT>; // NOLINT
|
using value_type = std::conditional_t<IsConst, VT const, VT>; // NOLINT
|
||||||
using pointer = value_type*; // NOLINT
|
using index_type = size_t; // NOLINT
|
||||||
|
using pointer = value_type*; // NOLINT
|
||||||
|
|
||||||
static value_type constexpr kValueSize = sizeof(value_type) * 8;
|
static index_type constexpr kValueSize = sizeof(value_type) * 8;
|
||||||
static value_type constexpr kOne = 1; // force correct type.
|
static index_type constexpr kOne = 1; // force correct type.
|
||||||
|
|
||||||
struct Pos {
|
struct Pos {
|
||||||
std::remove_const_t<value_type> int_pos {0};
|
index_type int_pos{0};
|
||||||
std::remove_const_t<value_type> bit_pos {0};
|
index_type bit_pos{0};
|
||||||
};
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -73,13 +74,13 @@ struct BitFieldContainer {
|
|||||||
static_assert(!std::is_signed<VT>::value, "Must use unsiged type as underlying storage.");
|
static_assert(!std::is_signed<VT>::value, "Must use unsiged type as underlying storage.");
|
||||||
|
|
||||||
public:
|
public:
|
||||||
XGBOOST_DEVICE static Pos ToBitPos(value_type pos) {
|
XGBOOST_DEVICE static Pos ToBitPos(index_type pos) {
|
||||||
Pos pos_v;
|
Pos pos_v;
|
||||||
if (pos == 0) {
|
if (pos == 0) {
|
||||||
return pos_v;
|
return pos_v;
|
||||||
}
|
}
|
||||||
pos_v.int_pos = pos / kValueSize;
|
pos_v.int_pos = pos / kValueSize;
|
||||||
pos_v.bit_pos = pos % kValueSize;
|
pos_v.bit_pos = pos % kValueSize;
|
||||||
return pos_v;
|
return pos_v;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -96,7 +97,7 @@ struct BitFieldContainer {
|
|||||||
/*\brief Compute the size of needed memory allocation. The returned value is in terms
|
/*\brief Compute the size of needed memory allocation. The returned value is in terms
|
||||||
* of number of elements with `BitFieldContainer::value_type'.
|
* of number of elements with `BitFieldContainer::value_type'.
|
||||||
*/
|
*/
|
||||||
XGBOOST_DEVICE static size_t ComputeStorageSize(size_t size) {
|
XGBOOST_DEVICE static size_t ComputeStorageSize(index_type size) {
|
||||||
return common::DivRoundUp(size, kValueSize);
|
return common::DivRoundUp(size, kValueSize);
|
||||||
}
|
}
|
||||||
#if defined(__CUDA_ARCH__)
|
#if defined(__CUDA_ARCH__)
|
||||||
@ -138,14 +139,14 @@ struct BitFieldContainer {
|
|||||||
#endif // defined(__CUDA_ARCH__)
|
#endif // defined(__CUDA_ARCH__)
|
||||||
|
|
||||||
#if defined(__CUDA_ARCH__)
|
#if defined(__CUDA_ARCH__)
|
||||||
__device__ auto Set(value_type pos) {
|
__device__ auto Set(index_type pos) {
|
||||||
Pos pos_v = Direction::Shift(ToBitPos(pos));
|
Pos pos_v = Direction::Shift(ToBitPos(pos));
|
||||||
value_type& value = bits_[pos_v.int_pos];
|
value_type& value = bits_[pos_v.int_pos];
|
||||||
value_type set_bit = kOne << pos_v.bit_pos;
|
value_type set_bit = kOne << pos_v.bit_pos;
|
||||||
using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type;
|
using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type;
|
||||||
atomicOr(reinterpret_cast<Type *>(&value), set_bit);
|
atomicOr(reinterpret_cast<Type *>(&value), set_bit);
|
||||||
}
|
}
|
||||||
__device__ void Clear(value_type pos) {
|
__device__ void Clear(index_type pos) {
|
||||||
Pos pos_v = Direction::Shift(ToBitPos(pos));
|
Pos pos_v = Direction::Shift(ToBitPos(pos));
|
||||||
value_type& value = bits_[pos_v.int_pos];
|
value_type& value = bits_[pos_v.int_pos];
|
||||||
value_type clear_bit = ~(kOne << pos_v.bit_pos);
|
value_type clear_bit = ~(kOne << pos_v.bit_pos);
|
||||||
@ -153,13 +154,13 @@ struct BitFieldContainer {
|
|||||||
atomicAnd(reinterpret_cast<Type *>(&value), clear_bit);
|
atomicAnd(reinterpret_cast<Type *>(&value), clear_bit);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
void Set(value_type pos) {
|
void Set(index_type pos) {
|
||||||
Pos pos_v = Direction::Shift(ToBitPos(pos));
|
Pos pos_v = Direction::Shift(ToBitPos(pos));
|
||||||
value_type& value = bits_[pos_v.int_pos];
|
value_type& value = bits_[pos_v.int_pos];
|
||||||
value_type set_bit = kOne << pos_v.bit_pos;
|
value_type set_bit = kOne << pos_v.bit_pos;
|
||||||
value |= set_bit;
|
value |= set_bit;
|
||||||
}
|
}
|
||||||
void Clear(value_type pos) {
|
void Clear(index_type pos) {
|
||||||
Pos pos_v = Direction::Shift(ToBitPos(pos));
|
Pos pos_v = Direction::Shift(ToBitPos(pos));
|
||||||
value_type& value = bits_[pos_v.int_pos];
|
value_type& value = bits_[pos_v.int_pos];
|
||||||
value_type clear_bit = ~(kOne << pos_v.bit_pos);
|
value_type clear_bit = ~(kOne << pos_v.bit_pos);
|
||||||
@ -175,7 +176,7 @@ struct BitFieldContainer {
|
|||||||
value_type result = test_bit & value;
|
value_type result = test_bit & value;
|
||||||
return static_cast<bool>(result);
|
return static_cast<bool>(result);
|
||||||
}
|
}
|
||||||
XGBOOST_DEVICE bool Check(value_type pos) const {
|
XGBOOST_DEVICE bool Check(index_type pos) const {
|
||||||
Pos pos_v = ToBitPos(pos);
|
Pos pos_v = ToBitPos(pos);
|
||||||
return Check(pos_v);
|
return Check(pos_v);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -38,6 +38,14 @@ TEST(BitField, Check) {
|
|||||||
ASSERT_FALSE(bits.Check(i));
|
ASSERT_FALSE(bits.Check(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// regression test for correct index type.
|
||||||
|
std::vector<RBitField8::value_type> storage(33, 0);
|
||||||
|
storage[32] = static_cast<uint8_t>(1);
|
||||||
|
auto bits = RBitField8({storage.data(), storage.size()});
|
||||||
|
ASSERT_TRUE(bits.Check(256));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename BitFieldT, typename VT = typename BitFieldT::value_type>
|
template <typename BitFieldT, typename VT = typename BitFieldT::value_type>
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user