Support numpy array interface (#6998)

This commit is contained in:
Jiaming Yuan
2021-05-27 16:08:22 +08:00
committed by GitHub
parent ab6fd304c4
commit 4cf95a6041
6 changed files with 59 additions and 38 deletions

View File

@@ -116,11 +116,6 @@ def _maybe_np_slice(data, dtype):
'''
try:
if not data.flags.c_contiguous:
warnings.warn(
"Use of np.ndarray subsets (sliced data) is not recommended " +
"because it will generate extra copies and increase " +
"memory consumption. Consider using np.ascontiguousarray to " +
"make the array contiguous.")
data = np.array(data, copy=True, dtype=dtype)
else:
data = np.array(data, copy=False, dtype=dtype)
@@ -130,44 +125,28 @@ def _maybe_np_slice(data, dtype):
return data
def _transform_np_array(data: np.ndarray) -> np.ndarray:
    """Flatten a 2-D array into a 1-D ``float32`` array.

    Objects that are not ``np.ndarray`` but expose ``__array__`` are converted
    to an ndarray first.  The array is then flattened with ``reshape`` (default
    C order, i.e. by rows) and cast to ``float32``; copies are made only when
    the layout or dtype makes them unavoidable.

    Parameters
    ----------
    data :
        A 2 dimensional ``np.ndarray`` or an object implementing ``__array__``.

    Returns
    -------
    flatten :
        1-D ``float32`` array holding the elements of ``data``.

    Raises
    ------
    ValueError
        If ``data`` is not 2 dimensional.
    """
    if not isinstance(data, np.ndarray) and hasattr(data, '__array__'):
        # np.asarray copies only when it has to.  np.array(..., copy=False)
        # raises ValueError on NumPy >= 2.0 whenever a copy is unavoidable.
        data = np.asarray(data)
    if len(data.shape) != 2:
        raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ',
                         data.shape)
    # Validate the input before the float32 cast so that rejected data does
    # not pay for a conversion first.  (_check_complex is defined elsewhere;
    # presumably it rejects complex dtypes -- confirm against its definition.)
    _check_complex(data)
    # reshape returns a view when possible; asarray avoids a second copy when
    # the data is already float32.
    flatten = np.asarray(data.reshape(data.size), dtype=np.float32)
    # _maybe_np_slice is only partly visible here; it re-creates the array
    # with a forced copy when the input is not C-contiguous.
    flatten = _maybe_np_slice(flatten, np.float32)
    return flatten
def _from_numpy_array(data, missing, nthread, feature_names, feature_types):
    """Initialize data from a 2-D numpy matrix.

    The array is handed to the native library through its array interface
    description, so it is not flattened into a temporary ``float32`` buffer
    here.  ``missing`` and ``nthread`` are passed as a JSON-encoded config.

    Parameters
    ----------
    data :
        2 dimensional array.
    missing :
        Value treated as missing; converted with ``float``.
    nthread :
        Number of threads; converted with ``int``.
    feature_names, feature_types :
        Returned unchanged alongside the handle.

    Returns
    -------
    handle, feature_names, feature_types
        ``handle`` is the ``ctypes.c_void_p`` filled in by
        ``XGDMatrixCreateFromArray``.

    Raises
    ------
    ValueError
        If ``data`` is not 2 dimensional.
    """
    if len(data.shape) != 2:
        raise ValueError(
            "Expecting 2 dimensional numpy.ndarray, got: ", data.shape
        )
    # _ensure_np_dtype is defined elsewhere; presumably it returns the array
    # in a dtype the native side accepts plus that dtype -- confirm there.
    data, _ = _ensure_np_dtype(data, data.dtype)
    handle = ctypes.c_void_p()
    args = {
        "missing": float(missing),
        "nthread": int(nthread),
    }
    config = bytes(json.dumps(args), "utf-8")
    # Create the matrix exactly once.  Calling the legacy
    # XGDMatrixCreateFromMat_omp entry point first and then this one would
    # overwrite ``handle`` and leak the first native object.
    _check_call(
        _LIB.XGDMatrixCreateFromArray(
            _array_interface(data),
            config,
            ctypes.byref(handle),
        )
    )
    return handle, feature_names, feature_types