* [backport] Make sure the input NumPy array is aligned. (#8690)
  - Use `np.require` to specify that alignment is required.
  - Apply the same handling to SciPy CSR input.
  - Validate the input pointer in `ArrayInterface`.
* Work around a CUDA warning. (#8696)
* Backport the alignment-related changes from the half type support.
* Fix import.
This commit is contained in:
parent
68d86336d7
commit
2f22f8d49b
@ -2172,6 +2172,7 @@ class Booster:
|
|||||||
)
|
)
|
||||||
return _prediction_output(shape, dims, preds, False)
|
return _prediction_output(shape, dims, preds, False)
|
||||||
|
|
||||||
|
# pylint: disable=too-many-statements
|
||||||
def inplace_predict(
|
def inplace_predict(
|
||||||
self,
|
self,
|
||||||
data: DataType,
|
data: DataType,
|
||||||
@ -2192,10 +2193,10 @@ class Booster:
|
|||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
booster.set_param({'predictor': 'gpu_predictor'})
|
booster.set_param({"predictor": "gpu_predictor"})
|
||||||
booster.inplace_predict(cupy_array)
|
booster.inplace_predict(cupy_array)
|
||||||
|
|
||||||
booster.set_param({'predictor': 'cpu_predictor})
|
booster.set_param({"predictor": "cpu_predictor"})
|
||||||
booster.inplace_predict(numpy_array)
|
booster.inplace_predict(numpy_array)
|
||||||
|
|
||||||
.. versionadded:: 1.1.0
|
.. versionadded:: 1.1.0
|
||||||
@ -2301,14 +2302,16 @@ class Booster:
|
|||||||
)
|
)
|
||||||
return _prediction_output(shape, dims, preds, False)
|
return _prediction_output(shape, dims, preds, False)
|
||||||
if isinstance(data, scipy.sparse.csr_matrix):
|
if isinstance(data, scipy.sparse.csr_matrix):
|
||||||
csr = data
|
from .data import _transform_scipy_csr
|
||||||
|
|
||||||
|
data = _transform_scipy_csr(data)
|
||||||
_check_call(
|
_check_call(
|
||||||
_LIB.XGBoosterPredictFromCSR(
|
_LIB.XGBoosterPredictFromCSR(
|
||||||
self.handle,
|
self.handle,
|
||||||
_array_interface(csr.indptr),
|
_array_interface(data.indptr),
|
||||||
_array_interface(csr.indices),
|
_array_interface(data.indices),
|
||||||
_array_interface(csr.data),
|
_array_interface(data.data),
|
||||||
c_bst_ulong(csr.shape[1]),
|
c_bst_ulong(data.shape[1]),
|
||||||
from_pystr_to_cstr(json.dumps(args)),
|
from_pystr_to_cstr(json.dumps(args)),
|
||||||
p_handle,
|
p_handle,
|
||||||
ctypes.byref(shape),
|
ctypes.byref(shape),
|
||||||
|
|||||||
@ -30,6 +30,7 @@ from .core import (
|
|||||||
c_array,
|
c_array,
|
||||||
c_str,
|
c_str,
|
||||||
from_pystr_to_cstr,
|
from_pystr_to_cstr,
|
||||||
|
make_jcargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
DispatchedDataBackendReturnType = Tuple[
|
DispatchedDataBackendReturnType = Tuple[
|
||||||
@ -80,6 +81,21 @@ def _array_interface(data: np.ndarray) -> bytes:
|
|||||||
return interface_str
|
return interface_str
|
||||||
|
|
||||||
|
|
||||||
|
def _transform_scipy_csr(data: DataType) -> DataType:
|
||||||
|
from scipy.sparse import csr_matrix
|
||||||
|
|
||||||
|
indptr, _ = _ensure_np_dtype(data.indptr, data.indptr.dtype)
|
||||||
|
indices, _ = _ensure_np_dtype(data.indices, data.indices.dtype)
|
||||||
|
values, _ = _ensure_np_dtype(data.data, data.data.dtype)
|
||||||
|
if (
|
||||||
|
indptr is not data.indptr
|
||||||
|
or indices is not data.indices
|
||||||
|
or values is not data.data
|
||||||
|
):
|
||||||
|
data = csr_matrix((values, indices, indptr), shape=data.shape)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def _from_scipy_csr(
|
def _from_scipy_csr(
|
||||||
data: DataType,
|
data: DataType,
|
||||||
missing: FloatCompatible,
|
missing: FloatCompatible,
|
||||||
@ -93,18 +109,14 @@ def _from_scipy_csr(
|
|||||||
f"length mismatch: {len(data.indices)} vs {len(data.data)}"
|
f"length mismatch: {len(data.indices)} vs {len(data.data)}"
|
||||||
)
|
)
|
||||||
handle = ctypes.c_void_p()
|
handle = ctypes.c_void_p()
|
||||||
args = {
|
data = _transform_scipy_csr(data)
|
||||||
"missing": float(missing),
|
|
||||||
"nthread": int(nthread),
|
|
||||||
}
|
|
||||||
config = bytes(json.dumps(args), "utf-8")
|
|
||||||
_check_call(
|
_check_call(
|
||||||
_LIB.XGDMatrixCreateFromCSR(
|
_LIB.XGDMatrixCreateFromCSR(
|
||||||
_array_interface(data.indptr),
|
_array_interface(data.indptr),
|
||||||
_array_interface(data.indices),
|
_array_interface(data.indices),
|
||||||
_array_interface(data.data),
|
_array_interface(data.data),
|
||||||
c_bst_ulong(data.shape[1]),
|
c_bst_ulong(data.shape[1]),
|
||||||
config,
|
make_jcargs(missing=float(missing), nthread=int(nthread)),
|
||||||
ctypes.byref(handle),
|
ctypes.byref(handle),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -153,12 +165,13 @@ def _is_numpy_array(data: DataType) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def _ensure_np_dtype(
|
def _ensure_np_dtype(
|
||||||
data: DataType,
|
data: DataType, dtype: Optional[NumpyDType]
|
||||||
dtype: Optional[NumpyDType]
|
|
||||||
) -> Tuple[np.ndarray, Optional[NumpyDType]]:
|
) -> Tuple[np.ndarray, Optional[NumpyDType]]:
|
||||||
if data.dtype.hasobject or data.dtype in [np.float16, np.bool_]:
|
if data.dtype.hasobject or data.dtype in [np.float16, np.bool_]:
|
||||||
data = data.astype(np.float32, copy=False)
|
|
||||||
dtype = np.float32
|
dtype = np.float32
|
||||||
|
data = data.astype(dtype, copy=False)
|
||||||
|
if not data.flags.aligned:
|
||||||
|
data = np.require(data, requirements="A")
|
||||||
return data, dtype
|
return data, dtype
|
||||||
|
|
||||||
|
|
||||||
@ -1197,11 +1210,13 @@ def _proxy_transform(
|
|||||||
data, _ = _ensure_np_dtype(data, data.dtype)
|
data, _ = _ensure_np_dtype(data, data.dtype)
|
||||||
return data, None, feature_names, feature_types
|
return data, None, feature_names, feature_types
|
||||||
if _is_scipy_csr(data):
|
if _is_scipy_csr(data):
|
||||||
|
data = _transform_scipy_csr(data)
|
||||||
return data, None, feature_names, feature_types
|
return data, None, feature_names, feature_types
|
||||||
if _is_pandas_df(data):
|
if _is_pandas_df(data):
|
||||||
arr, feature_names, feature_types = _transform_pandas_df(
|
arr, feature_names, feature_types = _transform_pandas_df(
|
||||||
data, enable_categorical, feature_names, feature_types
|
data, enable_categorical, feature_names, feature_types
|
||||||
)
|
)
|
||||||
|
arr, _ = _ensure_np_dtype(arr, arr.dtype)
|
||||||
return arr, None, feature_names, feature_types
|
return arr, None, feature_names, feature_types
|
||||||
raise TypeError("Value type is not supported for data iterator:" + str(type(data)))
|
raise TypeError("Value type is not supported for data iterator:" + str(type(data)))
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2019-2021 by Contributors
|
* Copyright 2019-2023 by XGBoost Contributors
|
||||||
* \file array_interface.h
|
* \file array_interface.h
|
||||||
* \brief View of __array_interface__
|
* \brief View of __array_interface__
|
||||||
*/
|
*/
|
||||||
@ -7,9 +7,11 @@
|
|||||||
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
|
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cinttypes>
|
#include <cstddef> // std::size_t
|
||||||
|
#include <cstdint>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <type_traits> // std::alignment_of,std::remove_pointer_t
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -394,6 +396,11 @@ class ArrayInterface {
|
|||||||
|
|
||||||
data = ArrayInterfaceHandler::ExtractData(array, n);
|
data = ArrayInterfaceHandler::ExtractData(array, n);
|
||||||
static_assert(allow_mask ? D == 1 : D >= 1, "Masked ndarray is not supported.");
|
static_assert(allow_mask ? D == 1 : D >= 1, "Masked ndarray is not supported.");
|
||||||
|
|
||||||
|
auto alignment = this->ElementAlignment();
|
||||||
|
auto ptr = reinterpret_cast<uintptr_t>(this->data);
|
||||||
|
CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment.";
|
||||||
|
|
||||||
if (allow_mask) {
|
if (allow_mask) {
|
||||||
common::Span<RBitField8::value_type> s_mask;
|
common::Span<RBitField8::value_type> s_mask;
|
||||||
size_t n_bits = ArrayInterfaceHandler::ExtractMask(array, &s_mask);
|
size_t n_bits = ArrayInterfaceHandler::ExtractMask(array, &s_mask);
|
||||||
@ -512,9 +519,15 @@ class ArrayInterface {
|
|||||||
return func(reinterpret_cast<uint64_t const *>(data));
|
return func(reinterpret_cast<uint64_t const *>(data));
|
||||||
}
|
}
|
||||||
|
|
||||||
XGBOOST_DEVICE size_t ElementSize() {
|
XGBOOST_DEVICE std::size_t ElementSize() const {
|
||||||
return this->DispatchCall(
|
return this->DispatchCall([](auto *typed_data_ptr) {
|
||||||
[](auto *p_values) { return sizeof(std::remove_pointer_t<decltype(p_values)>); });
|
return sizeof(std::remove_pointer_t<decltype(typed_data_ptr)>);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
XGBOOST_DEVICE std::size_t ElementAlignment() const {
|
||||||
|
return this->DispatchCall([](auto *typed_data_ptr) {
|
||||||
|
return std::alignment_of<std::remove_pointer_t<decltype(typed_data_ptr)>>::value;
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T = float, typename... Index>
|
template <typename T = float, typename... Index>
|
||||||
|
|||||||
@ -1,10 +1,12 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2020-2021 by XGBoost Contributors
|
* Copyright 2020-2023 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/host_device_vector.h>
|
#include <xgboost/host_device_vector.h>
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
#include "../../../src/data/array_interface.h"
|
#include "../../../src/data/array_interface.h"
|
||||||
|
#include "dmlc/logging.h"
|
||||||
|
#include "xgboost/json.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
TEST(ArrayInterface, Initialize) {
|
TEST(ArrayInterface, Initialize) {
|
||||||
@ -71,6 +73,14 @@ TEST(ArrayInterface, Error) {
|
|||||||
column["mask"]["data"] = Null{};
|
column["mask"]["data"] = Null{};
|
||||||
common::Span<RBitField8::value_type> s_mask;
|
common::Span<RBitField8::value_type> s_mask;
|
||||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractMask(column_obj, &s_mask), dmlc::Error);
|
EXPECT_THROW(ArrayInterfaceHandler::ExtractMask(column_obj, &s_mask), dmlc::Error);
|
||||||
|
|
||||||
|
get<Object>(column).erase("mask");
|
||||||
|
// misaligned.
|
||||||
|
j_data = {Json(Integer(reinterpret_cast<Integer::Int>(
|
||||||
|
reinterpret_cast<char const*>(storage.ConstHostPointer()) + 1))),
|
||||||
|
Json(Boolean(false))};
|
||||||
|
column["data"] = j_data;
|
||||||
|
EXPECT_THROW({ ArrayInterface<1> arr{column}; }, dmlc::Error);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ArrayInterface, GetElement) {
|
TEST(ArrayInterface, GetElement) {
|
||||||
|
|||||||
@ -326,7 +326,7 @@ class TestDMatrix:
|
|||||||
nrow = 100
|
nrow = 100
|
||||||
ncol = 1000
|
ncol = 1000
|
||||||
x = rand(nrow, ncol, density=0.0005, format='csr', random_state=rng)
|
x = rand(nrow, ncol, density=0.0005, format='csr', random_state=rng)
|
||||||
assert x.indices.max() < ncol - 1
|
assert x.indices.max() < ncol
|
||||||
x.data[:] = 1
|
x.data[:] = 1
|
||||||
dtrain = xgb.DMatrix(x, label=rng.binomial(1, 0.3, nrow))
|
dtrain = xgb.DMatrix(x, label=rng.binomial(1, 0.3, nrow))
|
||||||
assert (dtrain.num_row(), dtrain.num_col()) == (nrow, ncol)
|
assert (dtrain.num_row(), dtrain.num_col()) == (nrow, ncol)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user