Span: use size_t' for index_type, add front' and `back'. (#4935)
* Use `size_t' for index_type. Add `front' and `back'. * Remove a batch of `static_cast'.
This commit is contained in:
parent
a9053aff83
commit
b61d534472
@ -31,7 +31,8 @@
|
|||||||
|
|
||||||
#include <xgboost/logging.h> // CHECK
|
#include <xgboost/logging.h> // CHECK
|
||||||
|
|
||||||
#include <cinttypes> // int64_t
|
#include <cinttypes> // size_t
|
||||||
|
#include <numeric> // numeric_limits
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
@ -97,18 +98,20 @@ namespace detail {
|
|||||||
* represent ptrdiff_t, which is just int64_t. So we make it determinstic
|
* represent ptrdiff_t, which is just int64_t. So we make it determinstic
|
||||||
* here.
|
* here.
|
||||||
*/
|
*/
|
||||||
using ptrdiff_t = int64_t; // NOLINT
|
using ptrdiff_t = typename std::conditional<std::is_same<std::ptrdiff_t, std::int64_t>::value,
|
||||||
|
std::ptrdiff_t, std::int64_t>::type;
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
#if defined(_MSC_VER) && _MSC_VER < 1910
|
#if defined(_MSC_VER) && _MSC_VER < 1910
|
||||||
constexpr const detail::ptrdiff_t dynamic_extent = -1; // NOLINT
|
constexpr const std::size_t
|
||||||
|
dynamic_extent = std::numeric_limits<std::size_t>::max(); // NOLINT
|
||||||
#else
|
#else
|
||||||
constexpr detail::ptrdiff_t dynamic_extent = -1; // NOLINT
|
constexpr std::size_t dynamic_extent = std::numeric_limits<std::size_t>::max(); // NOLINT
|
||||||
#endif // defined(_MSC_VER) && _MSC_VER < 1910
|
#endif // defined(_MSC_VER) && _MSC_VER < 1910
|
||||||
|
|
||||||
enum class byte : unsigned char {}; // NOLINT
|
enum class byte : unsigned char {}; // NOLINT
|
||||||
|
|
||||||
template <class ElementType, detail::ptrdiff_t Extent>
|
template <class ElementType, std::size_t Extent>
|
||||||
class Span;
|
class Span;
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
@ -119,8 +122,8 @@ class SpanIterator {
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
using iterator_category = std::random_access_iterator_tag; // NOLINT
|
using iterator_category = std::random_access_iterator_tag; // NOLINT
|
||||||
using value_type = typename std::remove_cv<ElementType>::type; // NOLINT
|
using value_type = typename SpanType::value_type; // NOLINT
|
||||||
using difference_type = typename SpanType::index_type; // NOLINT
|
using difference_type = detail::ptrdiff_t; // NOLINT
|
||||||
|
|
||||||
using reference = typename std::conditional< // NOLINT
|
using reference = typename std::conditional< // NOLINT
|
||||||
IsConst, const ElementType, ElementType>::type&;
|
IsConst, const ElementType, ElementType>::type&;
|
||||||
@ -153,7 +156,7 @@ class SpanIterator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
XGBOOST_DEVICE SpanIterator& operator++() {
|
XGBOOST_DEVICE SpanIterator& operator++() {
|
||||||
SPAN_CHECK(0 <= index_ && index_ != span_->size());
|
SPAN_CHECK(index_ != span_->size());
|
||||||
index_++;
|
index_++;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
@ -182,7 +185,7 @@ class SpanIterator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
XGBOOST_DEVICE SpanIterator& operator+=(difference_type n) {
|
XGBOOST_DEVICE SpanIterator& operator+=(difference_type n) {
|
||||||
SPAN_CHECK((index_ + n) >= 0 && (index_ + n) <= span_->size());
|
SPAN_CHECK((index_ + n) <= span_->size());
|
||||||
index_ += n;
|
index_ += n;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
@ -234,7 +237,7 @@ class SpanIterator {
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
const SpanType *span_;
|
const SpanType *span_;
|
||||||
detail::ptrdiff_t index_;
|
typename SpanType::index_type index_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -248,24 +251,22 @@ class SpanIterator {
|
|||||||
* - Otherwise, if Extent is not dynamic_extent, Extent - Offset;
|
* - Otherwise, if Extent is not dynamic_extent, Extent - Offset;
|
||||||
* - Otherwise, dynamic_extent.
|
* - Otherwise, dynamic_extent.
|
||||||
*/
|
*/
|
||||||
template <detail::ptrdiff_t Extent,
|
template <std::size_t Extent, std::size_t Offset, std::size_t Count>
|
||||||
detail::ptrdiff_t Offset,
|
|
||||||
detail::ptrdiff_t Count>
|
|
||||||
struct ExtentValue : public std::integral_constant<
|
struct ExtentValue : public std::integral_constant<
|
||||||
detail::ptrdiff_t, Count != dynamic_extent ?
|
std::size_t, Count != dynamic_extent ?
|
||||||
Count : (Extent != dynamic_extent ? Extent - Offset : Extent)> {};
|
Count : (Extent != dynamic_extent ? Extent - Offset : Extent)> {};
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* If N is dynamic_extent, the extent of the returned span E is also
|
* If N is dynamic_extent, the extent of the returned span E is also
|
||||||
* dynamic_extent; otherwise it is detail::ptrdiff_t(sizeof(T)) * N.
|
* dynamic_extent; otherwise it is std::size_t(sizeof(T)) * N.
|
||||||
*/
|
*/
|
||||||
template <typename T, detail::ptrdiff_t Extent>
|
template <typename T, std::size_t Extent>
|
||||||
struct ExtentAsBytesValue : public std::integral_constant<
|
struct ExtentAsBytesValue : public std::integral_constant<
|
||||||
detail::ptrdiff_t,
|
std::size_t,
|
||||||
Extent == dynamic_extent ?
|
Extent == dynamic_extent ?
|
||||||
Extent : static_cast<detail::ptrdiff_t>(sizeof(T) * Extent)> {};
|
Extent : sizeof(T) * Extent> {};
|
||||||
|
|
||||||
template <detail::ptrdiff_t From, detail::ptrdiff_t To>
|
template <std::size_t From, std::size_t To>
|
||||||
struct IsAllowedExtentConversion : public std::integral_constant<
|
struct IsAllowedExtentConversion : public std::integral_constant<
|
||||||
bool, From == To || From == dynamic_extent || To == dynamic_extent> {};
|
bool, From == To || From == dynamic_extent || To == dynamic_extent> {};
|
||||||
|
|
||||||
@ -276,7 +277,7 @@ struct IsAllowedElementTypeConversion : public std::integral_constant<
|
|||||||
template <class T>
|
template <class T>
|
||||||
struct IsSpanOracle : std::false_type {};
|
struct IsSpanOracle : std::false_type {};
|
||||||
|
|
||||||
template <class T, detail::ptrdiff_t Extent>
|
template <class T, std::size_t Extent>
|
||||||
struct IsSpanOracle<Span<T, Extent>> : std::true_type {};
|
struct IsSpanOracle<Span<T, Extent>> : std::true_type {};
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
@ -385,12 +386,12 @@ XGBOOST_DEVICE bool LexicographicalCompare(InputIt1 first1, InputIt1 last1,
|
|||||||
* passing iterator.
|
* passing iterator.
|
||||||
*/
|
*/
|
||||||
template <typename T,
|
template <typename T,
|
||||||
detail::ptrdiff_t Extent = dynamic_extent>
|
std::size_t Extent = dynamic_extent>
|
||||||
class Span {
|
class Span {
|
||||||
public:
|
public:
|
||||||
using element_type = T; // NOLINT
|
using element_type = T; // NOLINT
|
||||||
using value_type = typename std::remove_cv<T>::type; // NOLINT
|
using value_type = typename std::remove_cv<T>::type; // NOLINT
|
||||||
using index_type = detail::ptrdiff_t; // NOLINT
|
using index_type = std::size_t; // NOLINT
|
||||||
using difference_type = detail::ptrdiff_t; // NOLINT
|
using difference_type = detail::ptrdiff_t; // NOLINT
|
||||||
using pointer = T*; // NOLINT
|
using pointer = T*; // NOLINT
|
||||||
using reference = T&; // NOLINT
|
using reference = T&; // NOLINT
|
||||||
@ -406,13 +407,12 @@ class Span {
|
|||||||
|
|
||||||
XGBOOST_DEVICE Span(pointer _ptr, index_type _count) :
|
XGBOOST_DEVICE Span(pointer _ptr, index_type _count) :
|
||||||
size_(_count), data_(_ptr) {
|
size_(_count), data_(_ptr) {
|
||||||
SPAN_CHECK(_count >= 0);
|
SPAN_CHECK(!(Extent != dynamic_extent && _count != Extent));
|
||||||
SPAN_CHECK(_ptr || _count == 0);
|
SPAN_CHECK(_ptr || _count == 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
XGBOOST_DEVICE Span(pointer _first, pointer _last) :
|
XGBOOST_DEVICE Span(pointer _first, pointer _last) :
|
||||||
size_(_last - _first), data_(_first) {
|
size_(_last - _first), data_(_first) {
|
||||||
SPAN_CHECK(size_ >= 0);
|
|
||||||
SPAN_CHECK(data_ || size_ == 0);
|
SPAN_CHECK(data_ || size_ == 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -441,7 +441,7 @@ class Span {
|
|||||||
XGBOOST_DEVICE Span(const Container& _cont) : size_(_cont.size()), // NOLINT
|
XGBOOST_DEVICE Span(const Container& _cont) : size_(_cont.size()), // NOLINT
|
||||||
data_(_cont.data()) {}
|
data_(_cont.data()) {}
|
||||||
|
|
||||||
template <class U, detail::ptrdiff_t OtherExtent,
|
template <class U, std::size_t OtherExtent,
|
||||||
class = typename std::enable_if<
|
class = typename std::enable_if<
|
||||||
detail::IsAllowedElementTypeConversion<U, T>::value &&
|
detail::IsAllowedElementTypeConversion<U, T>::value &&
|
||||||
detail::IsAllowedExtentConversion<OtherExtent, Extent>::value>>
|
detail::IsAllowedExtentConversion<OtherExtent, Extent>::value>>
|
||||||
@ -491,8 +491,18 @@ class Span {
|
|||||||
return const_reverse_iterator{cbegin()};
|
return const_reverse_iterator{cbegin()};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// element access
|
||||||
|
|
||||||
|
XGBOOST_DEVICE reference front() const {
|
||||||
|
return (*this)[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
XGBOOST_DEVICE reference back() const {
|
||||||
|
return (*this)[size() - 1];
|
||||||
|
}
|
||||||
|
|
||||||
XGBOOST_DEVICE reference operator[](index_type _idx) const {
|
XGBOOST_DEVICE reference operator[](index_type _idx) const {
|
||||||
SPAN_CHECK(_idx >= 0 && _idx < size());
|
SPAN_CHECK(_idx < size());
|
||||||
return data()[_idx];
|
return data()[_idx];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -517,27 +527,27 @@ class Span {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Subviews
|
// Subviews
|
||||||
template <detail::ptrdiff_t Count >
|
template <std::size_t Count>
|
||||||
XGBOOST_DEVICE Span<element_type, Count> first() const { // NOLINT
|
XGBOOST_DEVICE Span<element_type, Count> first() const { // NOLINT
|
||||||
SPAN_CHECK(Count >= 0 && Count <= size());
|
SPAN_CHECK(Count <= size());
|
||||||
return {data(), Count};
|
return {data(), Count};
|
||||||
}
|
}
|
||||||
|
|
||||||
XGBOOST_DEVICE Span<element_type, dynamic_extent> first( // NOLINT
|
XGBOOST_DEVICE Span<element_type, dynamic_extent> first( // NOLINT
|
||||||
detail::ptrdiff_t _count) const {
|
std::size_t _count) const {
|
||||||
SPAN_CHECK(_count >= 0 && _count <= size());
|
SPAN_CHECK(_count <= size());
|
||||||
return {data(), _count};
|
return {data(), _count};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <detail::ptrdiff_t Count >
|
template <std::size_t Count>
|
||||||
XGBOOST_DEVICE Span<element_type, Count> last() const { // NOLINT
|
XGBOOST_DEVICE Span<element_type, Count> last() const { // NOLINT
|
||||||
SPAN_CHECK(Count >=0 && size() - Count >= 0);
|
SPAN_CHECK(Count <= size());
|
||||||
return {data() + size() - Count, Count};
|
return {data() + size() - Count, Count};
|
||||||
}
|
}
|
||||||
|
|
||||||
XGBOOST_DEVICE Span<element_type, dynamic_extent> last( // NOLINT
|
XGBOOST_DEVICE Span<element_type, dynamic_extent> last( // NOLINT
|
||||||
detail::ptrdiff_t _count) const {
|
std::size_t _count) const {
|
||||||
SPAN_CHECK(_count >= 0 && _count <= size());
|
SPAN_CHECK(_count <= size());
|
||||||
return subspan(size() - _count, _count);
|
return subspan(size() - _count, _count);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -545,24 +555,22 @@ class Span {
|
|||||||
* If Count is std::dynamic_extent, r.size() == this->size() - Offset;
|
* If Count is std::dynamic_extent, r.size() == this->size() - Offset;
|
||||||
* Otherwise r.size() == Count.
|
* Otherwise r.size() == Count.
|
||||||
*/
|
*/
|
||||||
template <detail::ptrdiff_t Offset,
|
template <std::size_t Offset,
|
||||||
detail::ptrdiff_t Count = dynamic_extent>
|
std::size_t Count = dynamic_extent>
|
||||||
XGBOOST_DEVICE auto subspan() const -> // NOLINT
|
XGBOOST_DEVICE auto subspan() const -> // NOLINT
|
||||||
Span<element_type,
|
Span<element_type,
|
||||||
detail::ExtentValue<Extent, Offset, Count>::value> {
|
detail::ExtentValue<Extent, Offset, Count>::value> {
|
||||||
SPAN_CHECK(Offset >= 0 && (Offset < size() || size() == 0));
|
SPAN_CHECK(Offset < size() || size() == 0);
|
||||||
SPAN_CHECK(Count == dynamic_extent ||
|
SPAN_CHECK(Count == dynamic_extent || (Offset + Count <= size()));
|
||||||
(Count >= 0 && Offset + Count <= size()));
|
|
||||||
|
|
||||||
return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count};
|
return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count};
|
||||||
}
|
}
|
||||||
|
|
||||||
XGBOOST_DEVICE Span<element_type, dynamic_extent> subspan( // NOLINT
|
XGBOOST_DEVICE Span<element_type, dynamic_extent> subspan( // NOLINT
|
||||||
detail::ptrdiff_t _offset,
|
index_type _offset,
|
||||||
detail::ptrdiff_t _count = dynamic_extent) const {
|
index_type _count = dynamic_extent) const {
|
||||||
SPAN_CHECK(_offset >= 0 && (_offset < size() || size() == 0));
|
SPAN_CHECK(_offset < size() || size() == 0);
|
||||||
SPAN_CHECK((_count == dynamic_extent) ||
|
SPAN_CHECK((_count == dynamic_extent) || (_offset + _count <= size()));
|
||||||
(_count >= 0 && _offset + _count <= size()));
|
|
||||||
|
|
||||||
return {data() + _offset, _count ==
|
return {data() + _offset, _count ==
|
||||||
dynamic_extent ? size() - _offset : _count};
|
dynamic_extent ? size() - _offset : _count};
|
||||||
@ -573,7 +581,7 @@ class Span {
|
|||||||
pointer data_;
|
pointer data_;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class T, detail::ptrdiff_t X, class U, detail::ptrdiff_t Y>
|
template <class T, std::size_t X, class U, std::size_t Y>
|
||||||
XGBOOST_DEVICE bool operator==(Span<T, X> l, Span<U, Y> r) {
|
XGBOOST_DEVICE bool operator==(Span<T, X> l, Span<U, Y> r) {
|
||||||
if (l.size() != r.size()) {
|
if (l.size() != r.size()) {
|
||||||
return false;
|
return false;
|
||||||
@ -587,23 +595,23 @@ XGBOOST_DEVICE bool operator==(Span<T, X> l, Span<U, Y> r) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class T, detail::ptrdiff_t X, class U, detail::ptrdiff_t Y>
|
template <class T, std::size_t X, class U, std::size_t Y>
|
||||||
XGBOOST_DEVICE constexpr bool operator!=(Span<T, X> l, Span<U, Y> r) {
|
XGBOOST_DEVICE constexpr bool operator!=(Span<T, X> l, Span<U, Y> r) {
|
||||||
return !(l == r);
|
return !(l == r);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class T, detail::ptrdiff_t X, class U, detail::ptrdiff_t Y>
|
template <class T, std::size_t X, class U, std::size_t Y>
|
||||||
XGBOOST_DEVICE constexpr bool operator<(Span<T, X> l, Span<U, Y> r) {
|
XGBOOST_DEVICE constexpr bool operator<(Span<T, X> l, Span<U, Y> r) {
|
||||||
return detail::LexicographicalCompare(l.begin(), l.end(),
|
return detail::LexicographicalCompare(l.begin(), l.end(),
|
||||||
r.begin(), r.end());
|
r.begin(), r.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class T, detail::ptrdiff_t X, class U, detail::ptrdiff_t Y>
|
template <class T, std::size_t X, class U, std::size_t Y>
|
||||||
XGBOOST_DEVICE constexpr bool operator<=(Span<T, X> l, Span<U, Y> r) {
|
XGBOOST_DEVICE constexpr bool operator<=(Span<T, X> l, Span<U, Y> r) {
|
||||||
return !(l > r);
|
return !(l > r);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class T, detail::ptrdiff_t X, class U, detail::ptrdiff_t Y>
|
template <class T, std::size_t X, class U, std::size_t Y>
|
||||||
XGBOOST_DEVICE constexpr bool operator>(Span<T, X> l, Span<U, Y> r) {
|
XGBOOST_DEVICE constexpr bool operator>(Span<T, X> l, Span<U, Y> r) {
|
||||||
return detail::LexicographicalCompare<
|
return detail::LexicographicalCompare<
|
||||||
typename Span<T, X>::iterator, typename Span<U, Y>::iterator,
|
typename Span<T, X>::iterator, typename Span<U, Y>::iterator,
|
||||||
@ -611,18 +619,18 @@ XGBOOST_DEVICE constexpr bool operator>(Span<T, X> l, Span<U, Y> r) {
|
|||||||
r.begin(), r.end());
|
r.begin(), r.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class T, detail::ptrdiff_t X, class U, detail::ptrdiff_t Y>
|
template <class T, std::size_t X, class U, std::size_t Y>
|
||||||
XGBOOST_DEVICE constexpr bool operator>=(Span<T, X> l, Span<U, Y> r) {
|
XGBOOST_DEVICE constexpr bool operator>=(Span<T, X> l, Span<U, Y> r) {
|
||||||
return !(l < r);
|
return !(l < r);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class T, detail::ptrdiff_t E>
|
template <class T, std::size_t E>
|
||||||
XGBOOST_DEVICE auto as_bytes(Span<T, E> s) __span_noexcept -> // NOLINT
|
XGBOOST_DEVICE auto as_bytes(Span<T, E> s) __span_noexcept -> // NOLINT
|
||||||
Span<const byte, detail::ExtentAsBytesValue<T, E>::value> {
|
Span<const byte, detail::ExtentAsBytesValue<T, E>::value> {
|
||||||
return {reinterpret_cast<const byte*>(s.data()), s.size_bytes()};
|
return {reinterpret_cast<const byte*>(s.data()), s.size_bytes()};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class T, detail::ptrdiff_t E>
|
template <class T, std::size_t E>
|
||||||
XGBOOST_DEVICE auto as_writable_bytes(Span<T, E> s) __span_noexcept -> // NOLINT
|
XGBOOST_DEVICE auto as_writable_bytes(Span<T, E> s) __span_noexcept -> // NOLINT
|
||||||
Span<byte, detail::ExtentAsBytesValue<T, E>::value> {
|
Span<byte, detail::ExtentAsBytesValue<T, E>::value> {
|
||||||
return {reinterpret_cast<byte*>(s.data()), s.size_bytes()};
|
return {reinterpret_cast<byte*>(s.data()), s.size_bytes()};
|
||||||
|
|||||||
@ -380,9 +380,7 @@ class DoubleBuffer {
|
|||||||
|
|
||||||
T *Current() { return buff.Current(); }
|
T *Current() { return buff.Current(); }
|
||||||
xgboost::common::Span<T> CurrentSpan() {
|
xgboost::common::Span<T> CurrentSpan() {
|
||||||
return xgboost::common::Span<T>{
|
return xgboost::common::Span<T>{buff.Current(), Size()};
|
||||||
buff.Current(),
|
|
||||||
static_cast<typename xgboost::common::Span<T>::index_type>(Size())};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
T *other() { return buff.Alternate(); }
|
T *other() { return buff.Alternate(); }
|
||||||
@ -1120,17 +1118,16 @@ template <typename T,
|
|||||||
xgboost::common::Span<T> ToSpan(
|
xgboost::common::Span<T> ToSpan(
|
||||||
device_vector<T>& vec,
|
device_vector<T>& vec,
|
||||||
IndexT offset = 0,
|
IndexT offset = 0,
|
||||||
IndexT size = -1) {
|
IndexT size = std::numeric_limits<size_t>::max()) {
|
||||||
size = size == -1 ? vec.size() : size;
|
size = size == std::numeric_limits<size_t>::max() ? vec.size() : size;
|
||||||
CHECK_LE(offset + size, vec.size());
|
CHECK_LE(offset + size, vec.size());
|
||||||
return {vec.data().get() + offset, static_cast<IndexT>(size)};
|
return {vec.data().get() + offset, size};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
xgboost::common::Span<T> ToSpan(thrust::device_vector<T>& vec,
|
xgboost::common::Span<T> ToSpan(thrust::device_vector<T>& vec,
|
||||||
size_t offset, size_t size) {
|
size_t offset, size_t size) {
|
||||||
using IndexT = typename xgboost::common::Span<T>::index_type;
|
return ToSpan(vec, offset, size);
|
||||||
return ToSpan(vec, static_cast<IndexT>(offset), static_cast<IndexT>(size));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// thrust begin, similiar to std::begin
|
// thrust begin, similiar to std::begin
|
||||||
|
|||||||
@ -343,7 +343,7 @@ struct GHistIndexBlock {
|
|||||||
|
|
||||||
// get i-th row
|
// get i-th row
|
||||||
inline GHistIndexRow operator[](size_t i) const {
|
inline GHistIndexRow operator[](size_t i) const {
|
||||||
return {&index[0] + row_ptr[i], detail::ptrdiff_t(row_ptr[i + 1] - row_ptr[i])};
|
return {&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -69,13 +69,12 @@ class HostDeviceVectorImpl {
|
|||||||
|
|
||||||
common::Span<T> DeviceSpan() {
|
common::Span<T> DeviceSpan() {
|
||||||
LazySyncDevice(GPUAccess::kWrite);
|
LazySyncDevice(GPUAccess::kWrite);
|
||||||
return {data_d_.data().get(), static_cast<typename common::Span<T>::index_type>(Size())};
|
return {data_d_.data().get(), Size()};
|
||||||
}
|
}
|
||||||
|
|
||||||
common::Span<const T> ConstDeviceSpan() {
|
common::Span<const T> ConstDeviceSpan() {
|
||||||
LazySyncDevice(GPUAccess::kRead);
|
LazySyncDevice(GPUAccess::kRead);
|
||||||
using SpanInd = typename common::Span<const T>::index_type;
|
return {data_d_.data().get(), Size()};
|
||||||
return {data_d_.data().get(), static_cast<SpanInd>(Size())};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Fill(T v) { // NOLINT
|
void Fill(T v) { // NOLINT
|
||||||
|
|||||||
@ -165,6 +165,14 @@ class ArrayInterfaceHandler {
|
|||||||
auto typestr = get<String const>(j_mask.at("typestr"));
|
auto typestr = get<String const>(j_mask.at("typestr"));
|
||||||
// For now this is just 1, we can support different size of interger in mask.
|
// For now this is just 1, we can support different size of interger in mask.
|
||||||
int64_t const type_length = typestr.at(2) - 48;
|
int64_t const type_length = typestr.at(2) - 48;
|
||||||
|
|
||||||
|
if (typestr.at(1) == 't') {
|
||||||
|
CHECK_EQ(type_length, 1) << "mask with bitfield type should be of 1 byte per bitfield.";
|
||||||
|
} else if (typestr.at(1) == 'i') {
|
||||||
|
CHECK_EQ(type_length, 1) << "mask with integer type should be of 1 byte per integer.";
|
||||||
|
} else {
|
||||||
|
LOG(FATAL) << "mask must be of integer type or bit field type.";
|
||||||
|
}
|
||||||
/*
|
/*
|
||||||
* shape represents how many bits is in the mask. (This is a grey area, don't be
|
* shape represents how many bits is in the mask. (This is a grey area, don't be
|
||||||
* suprised if it suddently represents something else when supporting a new
|
* suprised if it suddently represents something else when supporting a new
|
||||||
@ -175,10 +183,10 @@ class ArrayInterfaceHandler {
|
|||||||
*
|
*
|
||||||
* And that's the only requirement.
|
* And that's the only requirement.
|
||||||
*/
|
*/
|
||||||
int64_t const n_bits = get<Integer>(j_shape.at(0));
|
size_t const n_bits = static_cast<size_t>(get<Integer>(j_shape.at(0)));
|
||||||
// The size of span required to cover all bits. Here with 8 bits bitfield, we
|
// The size of span required to cover all bits. Here with 8 bits bitfield, we
|
||||||
// assume 1 byte alignment.
|
// assume 1 byte alignment.
|
||||||
int64_t const span_size = RBitField8::ComputeStorageSize(n_bits);
|
size_t const span_size = RBitField8::ComputeStorageSize(n_bits);
|
||||||
|
|
||||||
if (j_mask.find("strides") != j_mask.cend()) {
|
if (j_mask.find("strides") != j_mask.cend()) {
|
||||||
auto strides = get<Array const>(column.at("strides"));
|
auto strides = get<Array const>(column.at("strides"));
|
||||||
@ -186,14 +194,6 @@ class ArrayInterfaceHandler {
|
|||||||
CHECK_EQ(get<Integer>(strides.at(0)), type_length) << ColumnarErrors::Contigious();
|
CHECK_EQ(get<Integer>(strides.at(0)), type_length) << ColumnarErrors::Contigious();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (typestr.at(1) == 't') {
|
|
||||||
CHECK_EQ(typestr.at(2), '1') << "mask with bitfield type should be of 1 byte per bitfield.";
|
|
||||||
} else if (typestr.at(1) == 'i') {
|
|
||||||
CHECK_EQ(typestr.at(2), '1') << "mask with integer type should be of 1 byte per integer.";
|
|
||||||
} else {
|
|
||||||
LOG(FATAL) << "mask must be of integer type or bit field type.";
|
|
||||||
}
|
|
||||||
|
|
||||||
s_mask = {p_mask, span_size};
|
s_mask = {p_mask, span_size};
|
||||||
return n_bits;
|
return n_bits;
|
||||||
}
|
}
|
||||||
@ -219,7 +219,7 @@ class ArrayInterfaceHandler {
|
|||||||
CHECK_EQ(get<Integer>(strides.at(0)), sizeof(T)) << ColumnarErrors::Contigious();
|
CHECK_EQ(get<Integer>(strides.at(0)), sizeof(T)) << ColumnarErrors::Contigious();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto length = get<Integer const>(j_shape.at(0));
|
auto length = static_cast<size_t>(get<Integer const>(j_shape.at(0)));
|
||||||
|
|
||||||
T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
|
T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
|
||||||
return common::Span<T>{p_data, length};
|
return common::Span<T>{p_data, length};
|
||||||
|
|||||||
@ -98,7 +98,8 @@ TEST(Span, FromPtrLen) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
EXPECT_ANY_THROW(Span<float> tmp (arr, -1););
|
auto lazy = [=]() {Span<float const, 16> tmp (arr, 5);};
|
||||||
|
EXPECT_ANY_THROW(lazy());
|
||||||
}
|
}
|
||||||
|
|
||||||
// dynamic extent
|
// dynamic extent
|
||||||
@ -298,6 +299,32 @@ TEST(Span, Obversers) {
|
|||||||
ASSERT_EQ(status, 1);
|
ASSERT_EQ(status, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Span, FrontBack) {
|
||||||
|
{
|
||||||
|
float arr[4] {0, 1, 2, 3};
|
||||||
|
Span<float, 4> s(arr);
|
||||||
|
ASSERT_EQ(s.front(), 0);
|
||||||
|
ASSERT_EQ(s.back(), 3);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
std::vector<double> arr {0, 1, 2, 3};
|
||||||
|
Span<double> s(arr);
|
||||||
|
ASSERT_EQ(s.front(), 0);
|
||||||
|
ASSERT_EQ(s.back(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
Span<float, 0> s;
|
||||||
|
EXPECT_ANY_THROW(s.front());
|
||||||
|
EXPECT_ANY_THROW(s.back());
|
||||||
|
}
|
||||||
|
{
|
||||||
|
Span<float> s;
|
||||||
|
EXPECT_ANY_THROW(s.front());
|
||||||
|
EXPECT_ANY_THROW(s.back());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST(Span, FirstLast) {
|
TEST(Span, FirstLast) {
|
||||||
// static extent
|
// static extent
|
||||||
{
|
{
|
||||||
@ -310,11 +337,11 @@ TEST(Span, FirstLast) {
|
|||||||
ASSERT_EQ(first.size(), 4);
|
ASSERT_EQ(first.size(), 4);
|
||||||
ASSERT_EQ(first.data(), arr);
|
ASSERT_EQ(first.data(), arr);
|
||||||
|
|
||||||
for (int64_t i = 0; i < first.size(); ++i) {
|
for (size_t i = 0; i < first.size(); ++i) {
|
||||||
ASSERT_EQ(first[i], arr[i]);
|
ASSERT_EQ(first[i], arr[i]);
|
||||||
}
|
}
|
||||||
|
auto constexpr kOne = static_cast<Span<float, 4>::index_type>(-1);
|
||||||
EXPECT_ANY_THROW(s.first<-1>());
|
EXPECT_ANY_THROW(s.first<kOne>());
|
||||||
EXPECT_ANY_THROW(s.first<17>());
|
EXPECT_ANY_THROW(s.first<17>());
|
||||||
EXPECT_ANY_THROW(s.first<32>());
|
EXPECT_ANY_THROW(s.first<32>());
|
||||||
}
|
}
|
||||||
@ -329,11 +356,11 @@ TEST(Span, FirstLast) {
|
|||||||
ASSERT_EQ(last.size(), 4);
|
ASSERT_EQ(last.size(), 4);
|
||||||
ASSERT_EQ(last.data(), arr + 12);
|
ASSERT_EQ(last.data(), arr + 12);
|
||||||
|
|
||||||
for (int64_t i = 0; i < last.size(); ++i) {
|
for (size_t i = 0; i < last.size(); ++i) {
|
||||||
ASSERT_EQ(last[i], arr[i+12]);
|
ASSERT_EQ(last[i], arr[i+12]);
|
||||||
}
|
}
|
||||||
|
auto constexpr kOne = static_cast<Span<float, 4>::index_type>(-1);
|
||||||
EXPECT_ANY_THROW(s.last<-1>());
|
EXPECT_ANY_THROW(s.last<kOne>());
|
||||||
EXPECT_ANY_THROW(s.last<17>());
|
EXPECT_ANY_THROW(s.last<17>());
|
||||||
EXPECT_ANY_THROW(s.last<32>());
|
EXPECT_ANY_THROW(s.last<32>());
|
||||||
}
|
}
|
||||||
@ -348,7 +375,7 @@ TEST(Span, FirstLast) {
|
|||||||
ASSERT_EQ(first.size(), 4);
|
ASSERT_EQ(first.size(), 4);
|
||||||
ASSERT_EQ(first.data(), s.data());
|
ASSERT_EQ(first.data(), s.data());
|
||||||
|
|
||||||
for (int64_t i = 0; i < first.size(); ++i) {
|
for (size_t i = 0; i < first.size(); ++i) {
|
||||||
ASSERT_EQ(first[i], s[i]);
|
ASSERT_EQ(first[i], s[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -368,7 +395,7 @@ TEST(Span, FirstLast) {
|
|||||||
ASSERT_EQ(last.size(), 4);
|
ASSERT_EQ(last.size(), 4);
|
||||||
ASSERT_EQ(last.data(), s.data() + 12);
|
ASSERT_EQ(last.data(), s.data() + 12);
|
||||||
|
|
||||||
for (int64_t i = 0; i < last.size(); ++i) {
|
for (size_t i = 0; i < last.size(); ++i) {
|
||||||
ASSERT_EQ(s[12 + i], last[i]);
|
ASSERT_EQ(s[12 + i], last[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -397,7 +424,8 @@ TEST(Span, Subspan) {
|
|||||||
EXPECT_ANY_THROW(s1.subspan(-1, 0));
|
EXPECT_ANY_THROW(s1.subspan(-1, 0));
|
||||||
EXPECT_ANY_THROW(s1.subspan(16, 0));
|
EXPECT_ANY_THROW(s1.subspan(16, 0));
|
||||||
|
|
||||||
EXPECT_ANY_THROW(s1.subspan<-1>());
|
auto constexpr kOne = static_cast<Span<int, 4>::index_type>(-1);
|
||||||
|
EXPECT_ANY_THROW(s1.subspan<kOne>());
|
||||||
EXPECT_ANY_THROW(s1.subspan<16>());
|
EXPECT_ANY_THROW(s1.subspan<16>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -240,16 +240,16 @@ TEST(GPUSpan, ElementAccess) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
__global__ void TestFirstDynamicKernel(Span<float> _span) {
|
__global__ void TestFirstDynamicKernel(Span<float> _span) {
|
||||||
_span.first<-1>();
|
_span.first<static_cast<Span<float>::index_type>(-1)>();
|
||||||
}
|
}
|
||||||
__global__ void TestFirstStaticKernel(Span<float> _span) {
|
__global__ void TestFirstStaticKernel(Span<float> _span) {
|
||||||
_span.first(-1);
|
_span.first(static_cast<Span<float>::index_type>(-1));
|
||||||
}
|
}
|
||||||
__global__ void TestLastDynamicKernel(Span<float> _span) {
|
__global__ void TestLastDynamicKernel(Span<float> _span) {
|
||||||
_span.last<-1>();
|
_span.last<static_cast<Span<float>::index_type>(-1)>();
|
||||||
}
|
}
|
||||||
__global__ void TestLastStaticKernel(Span<float> _span) {
|
__global__ void TestLastStaticKernel(Span<float> _span) {
|
||||||
_span.last(-1);
|
_span.last(static_cast<Span<float>::index_type>(-1));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(GPUSpan, FirstLast) {
|
TEST(GPUSpan, FirstLast) {
|
||||||
@ -312,6 +312,41 @@ TEST(GPUSpan, FirstLast) {
|
|||||||
output = testing::internal::GetCapturedStdout();
|
output = testing::internal::GetCapturedStdout();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__global__ void TestFrontKernel(Span<float> _span) {
|
||||||
|
_span.front();
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void TestBackKernel(Span<float> _span) {
|
||||||
|
_span.back();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(GPUSpan, FrontBack) {
|
||||||
|
dh::safe_cuda(cudaSetDevice(0));
|
||||||
|
|
||||||
|
Span<float> s;
|
||||||
|
auto lambda_test_front = [=]() {
|
||||||
|
// make sure the termination happens inside this test.
|
||||||
|
try {
|
||||||
|
TestFrontKernel<<<1, 1>>>(s);
|
||||||
|
dh::safe_cuda(cudaDeviceSynchronize());
|
||||||
|
dh::safe_cuda(cudaGetLastError());
|
||||||
|
} catch (dmlc::Error const& e) {
|
||||||
|
std::terminate();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
EXPECT_DEATH(lambda_test_front(), "");
|
||||||
|
|
||||||
|
auto lambda_test_back = [=]() {
|
||||||
|
try {
|
||||||
|
TestBackKernel<<<1, 1>>>(s);
|
||||||
|
dh::safe_cuda(cudaDeviceSynchronize());
|
||||||
|
dh::safe_cuda(cudaGetLastError());
|
||||||
|
} catch (dmlc::Error const& e) {
|
||||||
|
std::terminate();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
EXPECT_DEATH(lambda_test_back(), "");
|
||||||
|
}
|
||||||
|
|
||||||
__global__ void TestSubspanDynamicKernel(Span<float> _span) {
|
__global__ void TestSubspanDynamicKernel(Span<float> _span) {
|
||||||
_span.subspan(16, 0);
|
_span.subspan(16, 0);
|
||||||
|
|||||||
@ -50,7 +50,7 @@ TEST(SparsePage, PushCSC) {
|
|||||||
inst = page[1];
|
inst = page[1];
|
||||||
ASSERT_EQ(inst.size(), 6);
|
ASSERT_EQ(inst.size(), 6);
|
||||||
std::vector<size_t> indices_sol {1, 2, 3};
|
std::vector<size_t> indices_sol {1, 2, 3};
|
||||||
for (int64_t i = 0; i < inst.size(); ++i) {
|
for (size_t i = 0; i < inst.size(); ++i) {
|
||||||
ASSERT_EQ(inst[i].index, indices_sol[i % 3]);
|
ASSERT_EQ(inst[i].index, indices_sol[i % 3]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -76,7 +76,7 @@ TEST(SparsePage, PushCSCAfterTranspose) {
|
|||||||
// how the dmatrix has been created
|
// how the dmatrix has been created
|
||||||
for (size_t i = 0; i < page.Size(); ++i) {
|
for (size_t i = 0; i < page.Size(); ++i) {
|
||||||
auto inst = page[i];
|
auto inst = page[i];
|
||||||
for (int j = 1; j < inst.size(); ++j) {
|
for (size_t j = 1; j < inst.size(); ++j) {
|
||||||
ASSERT_EQ(inst[0].fvalue, inst[j].fvalue);
|
ASSERT_EQ(inst[0].fvalue, inst[j].fvalue);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -73,7 +73,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
|||||||
ASSERT_LT(gmat_row_offset, gmat.index.size());
|
ASSERT_LT(gmat_row_offset, gmat.index.size());
|
||||||
SparsePage::Inst inst = batch[i];
|
SparsePage::Inst inst = batch[i];
|
||||||
ASSERT_EQ(gmat.row_ptr[rid] + inst.size(), gmat.row_ptr[rid + 1]);
|
ASSERT_EQ(gmat.row_ptr[rid] + inst.size(), gmat.row_ptr[rid + 1]);
|
||||||
for (int64_t j = 0; j < inst.size(); ++j) {
|
for (size_t j = 0; j < inst.size(); ++j) {
|
||||||
// Each entry of GHistIndexMatrix represents a bin ID
|
// Each entry of GHistIndexMatrix represents a bin ID
|
||||||
const size_t bin_id = gmat.index[gmat_row_offset + j];
|
const size_t bin_id = gmat.index[gmat_row_offset + j];
|
||||||
const size_t fid = inst[j].index;
|
const size_t fid = inst[j].index;
|
||||||
@ -129,7 +129,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Now validate the computed histogram returned by BuildHist
|
// Now validate the computed histogram returned by BuildHist
|
||||||
for (int64_t i = 0; i < hist_[nid].size(); ++i) {
|
for (size_t i = 0; i < hist_[nid].size(); ++i) {
|
||||||
GradientPairPrecise sol = histogram_expected[i];
|
GradientPairPrecise sol = histogram_expected[i];
|
||||||
ASSERT_NEAR(sol.GetGrad(), hist_[nid][i].GetGrad(), kEps);
|
ASSERT_NEAR(sol.GetGrad(), hist_[nid][i].GetGrad(), kEps);
|
||||||
ASSERT_NEAR(sol.GetHess(), hist_[nid][i].GetHess(), kEps);
|
ASSERT_NEAR(sol.GetHess(), hist_[nid][i].GetHess(), kEps);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user