Allow unaligned pointer if the array is empty. (#10418)
Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
bbff74d2ff
commit
49e25cfb36
@ -6,28 +6,25 @@
|
||||
#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
|
||||
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint>
|
||||
#include <limits> // for numeric_limits
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <algorithm> // for all_of, transform, fill
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t, int64_t, ...
|
||||
#include <limits> // for numeric_limits
|
||||
#include <map> // for map
|
||||
#include <string> // for string
|
||||
#include <type_traits> // for alignment_of, remove_pointer_t, invoke_result_t
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/bitfield.h" // for RBitField8
|
||||
#include "../common/common.h"
|
||||
#include "../common/bitfield.h" // for RBitField8
|
||||
#include "../common/error_msg.h" // for NoF128
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/linalg.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/json.h" // for Json
|
||||
#include "xgboost/linalg.h" // for CalcStride, TensorView
|
||||
#include "xgboost/logging.h" // for CHECK
|
||||
#include "xgboost/span.h" // for Span
|
||||
#include "xgboost/string_view.h" // for StringView
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
#include "cuda_fp16.h"
|
||||
#include "cuda_fp16.h" // for __half
|
||||
#endif
|
||||
|
||||
namespace xgboost {
|
||||
@ -410,7 +407,7 @@ class ArrayInterface {
|
||||
auto typestr = get<String const>(array.at("typestr"));
|
||||
this->AssignType(StringView{typestr});
|
||||
ArrayInterfaceHandler::ExtractShape(array, shape);
|
||||
size_t itemsize = typestr[2] - '0';
|
||||
std::size_t itemsize = typestr[2] - '0';
|
||||
is_contiguous = ArrayInterfaceHandler::ExtractStride(array, itemsize, shape, strides);
|
||||
n = linalg::detail::CalcSize(shape);
|
||||
|
||||
@ -419,7 +416,9 @@ class ArrayInterface {
|
||||
|
||||
auto alignment = this->ElementAlignment();
|
||||
auto ptr = reinterpret_cast<uintptr_t>(this->data);
|
||||
CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment.";
|
||||
if (!std::all_of(this->shape, this->shape + D, [](auto v) { return v == 0; })) {
|
||||
CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment.";
|
||||
}
|
||||
|
||||
if (allow_mask) {
|
||||
common::Span<RBitField8::value_type> s_mask;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user