Enforce input data is not object. (#6927)
* Check for object data type. * Allow strided arrays with greater underlying buffer size.
This commit is contained in:
@@ -233,6 +233,9 @@ def _numpy2ctypes_type(dtype):
|
||||
|
||||
|
||||
def _array_interface(data: np.ndarray) -> bytes:
|
||||
assert (
|
||||
data.dtype.hasobject is False
|
||||
), "Input data contains `object` dtype. Expecting numeric data."
|
||||
interface = data.__array_interface__
|
||||
if "mask" in interface:
|
||||
interface["mask"] = interface["mask"].__array_interface__
|
||||
@@ -1908,8 +1911,8 @@ class Booster(object):
|
||||
)
|
||||
|
||||
if isinstance(data, np.ndarray):
|
||||
from .data import _maybe_np_slice
|
||||
data = _maybe_np_slice(data, data.dtype)
|
||||
from .data import _ensure_np_dtype
|
||||
data, _ = _ensure_np_dtype(data, data.dtype)
|
||||
_check_call(
|
||||
_LIB.XGBoosterPredictFromDense(
|
||||
self.handle,
|
||||
|
||||
@@ -104,6 +104,13 @@ def _is_numpy_array(data):
|
||||
return isinstance(data, (np.ndarray, np.matrix))
|
||||
|
||||
|
||||
def _ensure_np_dtype(data, dtype):
|
||||
if data.dtype.hasobject:
|
||||
data = data.astype(np.float32, copy=False)
|
||||
dtype = np.float32
|
||||
return data, dtype
|
||||
|
||||
|
||||
def _maybe_np_slice(data, dtype):
|
||||
'''Handle numpy slice. This can be removed if we use __array_interface__.
|
||||
'''
|
||||
@@ -118,6 +125,7 @@ def _maybe_np_slice(data, dtype):
|
||||
data = np.array(data, copy=False, dtype=dtype)
|
||||
except AttributeError:
|
||||
data = np.array(data, copy=False, dtype=dtype)
|
||||
data, dtype = _ensure_np_dtype(data, dtype)
|
||||
return data
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user