Workaround CUDA warning. (#8696)

This commit is contained in:
Jiaming Yuan 2023-01-19 09:16:08 +08:00 committed by GitHub
parent 6933240837
commit 7a068af1a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -7,10 +7,11 @@
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
#include <algorithm>
#include <cstddef> // std::size_t
#include <cstdint>
#include <map>
#include <string>
#include <type_traits> // std::alignment_of
#include <type_traits> // std::alignment_of,std::remove_pointer_t
#include <utility>
#include <vector>
@ -402,11 +403,9 @@ class ArrayInterface {
data = ArrayInterfaceHandler::ExtractData(array, n);
static_assert(allow_mask ? D == 1 : D >= 1, "Masked ndarray is not supported.");
this->DispatchCall([&](auto const *data_typed_ptr) {
auto ptr = reinterpret_cast<uintptr_t>(data);
auto alignment = std::alignment_of<std::remove_pointer_t<decltype(data_typed_ptr)>>::value;
CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment.";
});
auto alignment = this->ElementAlignment();
auto ptr = reinterpret_cast<uintptr_t>(this->data);
CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment.";
if (allow_mask) {
common::Span<RBitField8::value_type> s_mask;
@ -540,9 +539,15 @@ class ArrayInterface {
return func(reinterpret_cast<uint64_t const *>(data));
}
XGBOOST_DEVICE size_t ElementSize() {
return this->DispatchCall(
[](auto *p_values) { return sizeof(std::remove_pointer_t<decltype(p_values)>); });
XGBOOST_DEVICE std::size_t ElementSize() const {
return this->DispatchCall([](auto *typed_data_ptr) {
return sizeof(std::remove_pointer_t<decltype(typed_data_ptr)>);
});
}
XGBOOST_DEVICE std::size_t ElementAlignment() const {
return this->DispatchCall([](auto *typed_data_ptr) {
return std::alignment_of<std::remove_pointer_t<decltype(typed_data_ptr)>>::value;
});
}
template <typename T = float, typename... Index>