Workaround CUDA warning. (#8696)

This commit is contained in:
Jiaming Yuan 2023-01-19 09:16:08 +08:00 committed by GitHub
parent 6933240837
commit 7a068af1a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -7,10 +7,11 @@
#define XGBOOST_DATA_ARRAY_INTERFACE_H_ #define XGBOOST_DATA_ARRAY_INTERFACE_H_
#include <algorithm> #include <algorithm>
#include <cstddef> // std::size_t
#include <cstdint> #include <cstdint>
#include <map> #include <map>
#include <string> #include <string>
#include <type_traits> // std::alignment_of #include <type_traits> // std::alignment_of,std::remove_pointer_t
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -402,11 +403,9 @@ class ArrayInterface {
data = ArrayInterfaceHandler::ExtractData(array, n); data = ArrayInterfaceHandler::ExtractData(array, n);
static_assert(allow_mask ? D == 1 : D >= 1, "Masked ndarray is not supported."); static_assert(allow_mask ? D == 1 : D >= 1, "Masked ndarray is not supported.");
this->DispatchCall([&](auto const *data_typed_ptr) { auto alignment = this->ElementAlignment();
auto ptr = reinterpret_cast<uintptr_t>(data); auto ptr = reinterpret_cast<uintptr_t>(this->data);
auto alignment = std::alignment_of<std::remove_pointer_t<decltype(data_typed_ptr)>>::value; CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment.";
CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment.";
});
if (allow_mask) { if (allow_mask) {
common::Span<RBitField8::value_type> s_mask; common::Span<RBitField8::value_type> s_mask;
@ -540,9 +539,15 @@ class ArrayInterface {
return func(reinterpret_cast<uint64_t const *>(data)); return func(reinterpret_cast<uint64_t const *>(data));
} }
XGBOOST_DEVICE size_t ElementSize() { XGBOOST_DEVICE std::size_t ElementSize() const {
return this->DispatchCall( return this->DispatchCall([](auto *typed_data_ptr) {
[](auto *p_values) { return sizeof(std::remove_pointer_t<decltype(p_values)>); }); return sizeof(std::remove_pointer_t<decltype(typed_data_ptr)>);
});
}
XGBOOST_DEVICE std::size_t ElementAlignment() const {
return this->DispatchCall([](auto *typed_data_ptr) {
return std::alignment_of<std::remove_pointer_t<decltype(typed_data_ptr)>>::value;
});
} }
template <typename T = float, typename... Index> template <typename T = float, typename... Index>