[backport] Enforce input data is not object. (#6927) (#6938)

* Check for object data type.
* Allow strided arrays with greater underlying buffer size.
This commit is contained in:
Jiaming Yuan
2021-05-04 16:10:16 +08:00
committed by GitHub
parent b78ad1e623
commit 6609211517
4 changed files with 31 additions and 4 deletions

View File

@@ -229,6 +229,9 @@ def _numpy2ctypes_type(dtype):
def _array_interface(data: np.ndarray) -> bytes:
assert (
data.dtype.hasobject is False
), "Input data contains `object` dtype. Expecting numeric data."
interface = data.__array_interface__
if "mask" in interface:
interface["mask"] = interface["mask"].__array_interface__
@@ -1841,8 +1844,8 @@ class Booster(object):
)
if isinstance(data, np.ndarray):
from .data import _maybe_np_slice
data = _maybe_np_slice(data, data.dtype)
from .data import _ensure_np_dtype
data, _ = _ensure_np_dtype(data, data.dtype)
_check_call(
_LIB.XGBoosterPredictFromDense(
self.handle,

View File

@@ -104,6 +104,13 @@ def _is_numpy_array(data):
return isinstance(data, (np.ndarray, np.matrix))
def _ensure_np_dtype(data, dtype):
if data.dtype.hasobject:
data = data.astype(np.float32, copy=False)
dtype = np.float32
return data, dtype
def _maybe_np_slice(data, dtype):
'''Handle numpy slice. This can be removed if we use __array_interface__.
'''
@@ -118,6 +125,7 @@ def _maybe_np_slice(data, dtype):
data = np.array(data, copy=False, dtype=dtype)
except AttributeError:
data = np.array(data, copy=False, dtype=dtype)
data, dtype = _ensure_np_dtype(data, dtype)
return data