Make sure input numpy array is aligned. (#8690)

- use `np.require` to specify that the alignment is required.
- scipy csr as well.
- validate input pointer in `ArrayInterface`.
This commit is contained in:
Jiaming Yuan
2023-01-18 08:12:13 +08:00
committed by GitHub
parent 175986b739
commit 31b9cbab3d
5 changed files with 56 additions and 22 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2019-2021 by Contributors
/**
* Copyright 2019-2023 by XGBoost Contributors
* \file array_interface.h
* \brief View of __array_interface__
*/
@@ -7,9 +7,10 @@
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
#include <algorithm>
#include <cinttypes>
#include <cstdint>
#include <map>
#include <string>
#include <type_traits> // std::alignment_of
#include <utility>
#include <vector>
@@ -400,6 +401,13 @@ class ArrayInterface {
data = ArrayInterfaceHandler::ExtractData(array, n);
static_assert(allow_mask ? D == 1 : D >= 1, "Masked ndarray is not supported.");
this->DispatchCall([&](auto const *data_typed_ptr) {
auto ptr = reinterpret_cast<uintptr_t>(data);
auto alignment = std::alignment_of<std::remove_pointer_t<decltype(data_typed_ptr)>>::value;
CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment.";
});
if (allow_mask) {
common::Span<RBitField8::value_type> s_mask;
size_t n_bits = ArrayInterfaceHandler::ExtractMask(array, &s_mask);