diff --git a/src/data/array_interface.h b/src/data/array_interface.h index 8a4661712..e9045899b 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -7,10 +7,11 @@ #define XGBOOST_DATA_ARRAY_INTERFACE_H_ #include +#include // std::size_t #include #include #include -#include // std::alignment_of +#include // std::alignment_of,std::remove_pointer_t #include #include @@ -402,11 +403,9 @@ class ArrayInterface { data = ArrayInterfaceHandler::ExtractData(array, n); static_assert(allow_mask ? D == 1 : D >= 1, "Masked ndarray is not supported."); - this->DispatchCall([&](auto const *data_typed_ptr) { - auto ptr = reinterpret_cast(data); - auto alignment = std::alignment_of>::value; - CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment."; - }); + auto alignment = this->ElementAlignment(); + auto ptr = reinterpret_cast(this->data); + CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment."; if (allow_mask) { common::Span s_mask; @@ -540,9 +539,15 @@ class ArrayInterface { return func(reinterpret_cast(data)); } - XGBOOST_DEVICE size_t ElementSize() { - return this->DispatchCall( - [](auto *p_values) { return sizeof(std::remove_pointer_t); }); + XGBOOST_DEVICE std::size_t ElementSize() const { + return this->DispatchCall([](auto *typed_data_ptr) { + return sizeof(std::remove_pointer_t); + }); + } + XGBOOST_DEVICE std::size_t ElementAlignment() const { + return this->DispatchCall([](auto *typed_data_ptr) { + return std::alignment_of>::value; + }); } template