Make sure input numpy array is aligned. (#8690)
- use `np.require` to specify that the alignment is required. - scipy csr as well. - validate input pointer in `ArrayInterface`.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019-2021 by Contributors
|
||||
/**
|
||||
* Copyright 2019-2023 by XGBoost Contributors
|
||||
* \file array_interface.h
|
||||
* \brief View of __array_interface__
|
||||
*/
|
||||
@@ -7,9 +7,10 @@
|
||||
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <type_traits> // std::alignment_of
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
@@ -400,6 +401,13 @@ class ArrayInterface {
|
||||
|
||||
data = ArrayInterfaceHandler::ExtractData(array, n);
|
||||
static_assert(allow_mask ? D == 1 : D >= 1, "Masked ndarray is not supported.");
|
||||
|
||||
this->DispatchCall([&](auto const *data_typed_ptr) {
|
||||
auto ptr = reinterpret_cast<uintptr_t>(data);
|
||||
auto alignment = std::alignment_of<std::remove_pointer_t<decltype(data_typed_ptr)>>::value;
|
||||
CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment.";
|
||||
});
|
||||
|
||||
if (allow_mask) {
|
||||
common::Span<RBitField8::value_type> s_mask;
|
||||
size_t n_bits = ArrayInterfaceHandler::ExtractMask(array, &s_mask);
|
||||
|
||||
Reference in New Issue
Block a user