Span: use size_t' for index_type, add front' and `back'. (#4935)
* Use `size_t' for index_type. Add `front' and `back'. * Remove a batch of `static_cast'.
This commit is contained in:
@@ -380,9 +380,7 @@ class DoubleBuffer {
|
||||
|
||||
T *Current() { return buff.Current(); }
|
||||
xgboost::common::Span<T> CurrentSpan() {
|
||||
return xgboost::common::Span<T>{
|
||||
buff.Current(),
|
||||
static_cast<typename xgboost::common::Span<T>::index_type>(Size())};
|
||||
return xgboost::common::Span<T>{buff.Current(), Size()};
|
||||
}
|
||||
|
||||
T *other() { return buff.Alternate(); }
|
||||
@@ -1120,17 +1118,16 @@ template <typename T,
|
||||
xgboost::common::Span<T> ToSpan(
|
||||
device_vector<T>& vec,
|
||||
IndexT offset = 0,
|
||||
IndexT size = -1) {
|
||||
size = size == -1 ? vec.size() : size;
|
||||
IndexT size = std::numeric_limits<size_t>::max()) {
|
||||
size = size == std::numeric_limits<size_t>::max() ? vec.size() : size;
|
||||
CHECK_LE(offset + size, vec.size());
|
||||
return {vec.data().get() + offset, static_cast<IndexT>(size)};
|
||||
return {vec.data().get() + offset, size};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
xgboost::common::Span<T> ToSpan(thrust::device_vector<T>& vec,
|
||||
size_t offset, size_t size) {
|
||||
using IndexT = typename xgboost::common::Span<T>::index_type;
|
||||
return ToSpan(vec, static_cast<IndexT>(offset), static_cast<IndexT>(size));
|
||||
return ToSpan(vec, offset, size);
|
||||
}
|
||||
|
||||
// thrust begin, similiar to std::begin
|
||||
|
||||
@@ -343,7 +343,7 @@ struct GHistIndexBlock {
|
||||
|
||||
// get i-th row
|
||||
inline GHistIndexRow operator[](size_t i) const {
|
||||
return {&index[0] + row_ptr[i], detail::ptrdiff_t(row_ptr[i + 1] - row_ptr[i])};
|
||||
return {&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -69,13 +69,12 @@ class HostDeviceVectorImpl {
|
||||
|
||||
common::Span<T> DeviceSpan() {
|
||||
LazySyncDevice(GPUAccess::kWrite);
|
||||
return {data_d_.data().get(), static_cast<typename common::Span<T>::index_type>(Size())};
|
||||
return {data_d_.data().get(), Size()};
|
||||
}
|
||||
|
||||
common::Span<const T> ConstDeviceSpan() {
|
||||
LazySyncDevice(GPUAccess::kRead);
|
||||
using SpanInd = typename common::Span<const T>::index_type;
|
||||
return {data_d_.data().get(), static_cast<SpanInd>(Size())};
|
||||
return {data_d_.data().get(), Size()};
|
||||
}
|
||||
|
||||
void Fill(T v) { // NOLINT
|
||||
|
||||
@@ -165,6 +165,14 @@ class ArrayInterfaceHandler {
|
||||
auto typestr = get<String const>(j_mask.at("typestr"));
|
||||
// For now this is just 1, we can support different size of interger in mask.
|
||||
int64_t const type_length = typestr.at(2) - 48;
|
||||
|
||||
if (typestr.at(1) == 't') {
|
||||
CHECK_EQ(type_length, 1) << "mask with bitfield type should be of 1 byte per bitfield.";
|
||||
} else if (typestr.at(1) == 'i') {
|
||||
CHECK_EQ(type_length, 1) << "mask with integer type should be of 1 byte per integer.";
|
||||
} else {
|
||||
LOG(FATAL) << "mask must be of integer type or bit field type.";
|
||||
}
|
||||
/*
|
||||
* shape represents how many bits is in the mask. (This is a grey area, don't be
|
||||
* suprised if it suddently represents something else when supporting a new
|
||||
@@ -175,10 +183,10 @@ class ArrayInterfaceHandler {
|
||||
*
|
||||
* And that's the only requirement.
|
||||
*/
|
||||
int64_t const n_bits = get<Integer>(j_shape.at(0));
|
||||
size_t const n_bits = static_cast<size_t>(get<Integer>(j_shape.at(0)));
|
||||
// The size of span required to cover all bits. Here with 8 bits bitfield, we
|
||||
// assume 1 byte alignment.
|
||||
int64_t const span_size = RBitField8::ComputeStorageSize(n_bits);
|
||||
size_t const span_size = RBitField8::ComputeStorageSize(n_bits);
|
||||
|
||||
if (j_mask.find("strides") != j_mask.cend()) {
|
||||
auto strides = get<Array const>(column.at("strides"));
|
||||
@@ -186,14 +194,6 @@ class ArrayInterfaceHandler {
|
||||
CHECK_EQ(get<Integer>(strides.at(0)), type_length) << ColumnarErrors::Contigious();
|
||||
}
|
||||
|
||||
if (typestr.at(1) == 't') {
|
||||
CHECK_EQ(typestr.at(2), '1') << "mask with bitfield type should be of 1 byte per bitfield.";
|
||||
} else if (typestr.at(1) == 'i') {
|
||||
CHECK_EQ(typestr.at(2), '1') << "mask with integer type should be of 1 byte per integer.";
|
||||
} else {
|
||||
LOG(FATAL) << "mask must be of integer type or bit field type.";
|
||||
}
|
||||
|
||||
s_mask = {p_mask, span_size};
|
||||
return n_bits;
|
||||
}
|
||||
@@ -219,7 +219,7 @@ class ArrayInterfaceHandler {
|
||||
CHECK_EQ(get<Integer>(strides.at(0)), sizeof(T)) << ColumnarErrors::Contigious();
|
||||
}
|
||||
|
||||
auto length = get<Integer const>(j_shape.at(0));
|
||||
auto length = static_cast<size_t>(get<Integer const>(j_shape.at(0)));
|
||||
|
||||
T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
|
||||
return common::Span<T>{p_data, length};
|
||||
|
||||
Reference in New Issue
Block a user