Allow unaligned pointer if the array is empty. (#10418)

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan 2024-06-15 19:10:21 +08:00 committed by GitHub
parent bbff74d2ff
commit 49e25cfb36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -6,28 +6,25 @@
#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_ #ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
#define XGBOOST_DATA_ARRAY_INTERFACE_H_ #define XGBOOST_DATA_ARRAY_INTERFACE_H_
#include <algorithm> #include <algorithm> // for all_of, transform, fill
#include <cstddef> // for size_t #include <cstddef> // for size_t
#include <cstdint> #include <cstdint> // for int32_t, int64_t, ...
#include <limits> // for numeric_limits #include <limits> // for numeric_limits
#include <map> #include <map> // for map
#include <string> #include <string> // for string
#include <type_traits> // for alignment_of, remove_pointer_t, invoke_result_t #include <type_traits> // for alignment_of, remove_pointer_t, invoke_result_t
#include <utility> #include <vector> // for vector
#include <vector>
#include "../common/bitfield.h" // for RBitField8 #include "../common/bitfield.h" // for RBitField8
#include "../common/common.h"
#include "../common/error_msg.h" // for NoF128 #include "../common/error_msg.h" // for NoF128
#include "xgboost/base.h" #include "xgboost/json.h" // for Json
#include "xgboost/data.h" #include "xgboost/linalg.h" // for CalcStride, TensorView
#include "xgboost/json.h" #include "xgboost/logging.h" // for CHECK
#include "xgboost/linalg.h" #include "xgboost/span.h" // for Span
#include "xgboost/logging.h" #include "xgboost/string_view.h" // for StringView
#include "xgboost/span.h"
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
#include "cuda_fp16.h" #include "cuda_fp16.h" // for __half
#endif #endif
namespace xgboost { namespace xgboost {
@ -410,7 +407,7 @@ class ArrayInterface {
auto typestr = get<String const>(array.at("typestr")); auto typestr = get<String const>(array.at("typestr"));
this->AssignType(StringView{typestr}); this->AssignType(StringView{typestr});
ArrayInterfaceHandler::ExtractShape(array, shape); ArrayInterfaceHandler::ExtractShape(array, shape);
size_t itemsize = typestr[2] - '0'; std::size_t itemsize = typestr[2] - '0';
is_contiguous = ArrayInterfaceHandler::ExtractStride(array, itemsize, shape, strides); is_contiguous = ArrayInterfaceHandler::ExtractStride(array, itemsize, shape, strides);
n = linalg::detail::CalcSize(shape); n = linalg::detail::CalcSize(shape);
@ -419,7 +416,9 @@ class ArrayInterface {
auto alignment = this->ElementAlignment(); auto alignment = this->ElementAlignment();
auto ptr = reinterpret_cast<uintptr_t>(this->data); auto ptr = reinterpret_cast<uintptr_t>(this->data);
if (!std::all_of(this->shape, this->shape + D, [](auto v) { return v == 0; })) {
CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment."; CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment.";
}
if (allow_mask) { if (allow_mask) {
common::Span<RBitField8::value_type> s_mask; common::Span<RBitField8::value_type> s_mask;