Use __array_interface__ for creating DMatrix from CSR. (#6675)

* Use __array_interface__ for creating DMatrix from CSR.
* Add configuration.
This commit is contained in:
Jiaming Yuan
2021-02-05 21:09:47 +08:00
committed by GitHub
parent 1e949110da
commit dbb5208a0a
7 changed files with 97 additions and 18 deletions

View File

@@ -9,7 +9,7 @@ from typing import Any
import numpy as np
from .core import c_array, _LIB, _check_call, c_str
from .core import c_array, _LIB, _check_call, c_str, _array_interface
from .core import DataIter, _ProxyDMatrix, DMatrix
from .compat import lazy_isinstance
@@ -40,21 +40,28 @@ def _is_scipy_csr(data):
return isinstance(data, scipy.sparse.csr_matrix)
# NOTE(review): this span is a rendered diff hunk whose indentation and +/- markers were
# stripped, so pre-change and post-change lines appear back to back. The [removed] /
# [added] / [context] labels below are inferred from the hunk header (-40,21 +40,28) and
# from the 5-argument call sites further down the page -- confirm against the real commit.
# [removed] previous signature and docstring (no nthread parameter).
def _from_scipy_csr(data, missing, feature_names, feature_types):
'''Initialize data from a CSR matrix.'''
# [added] new signature: nthread is inserted between missing and feature_names.
def _from_scipy_csr(data, missing, nthread, feature_names, feature_types):
"""Initialize data from a CSR matrix."""
# [context] data.indices and data.data must have the same length.
if len(data.indices) != len(data.data):
# [removed] old layout of the length-mismatch error.
raise ValueError('length mismatch: {} vs {}'.format(
len(data.indices), len(data.data)))
# [removed] presumably dropped because missing is now forwarded in the config below -- TODO confirm.
_warn_unused_missing(data, missing)
# [added] same error message, reformatted.
raise ValueError(
"length mismatch: {} vs {}".format(len(data.indices), len(data.data))
)
# [context] handle object that is passed by reference to the native call.
handle = ctypes.c_void_p()
# [removed] old call: indptr/indices/data converted with c_array, with explicit lengths.
_check_call(_LIB.XGDMatrixCreateFromCSREx(
c_array(ctypes.c_size_t, data.indptr),
c_array(ctypes.c_uint, data.indices),
c_array(ctypes.c_float, data.data),
ctypes.c_size_t(len(data.indptr)),
ctypes.c_size_t(len(data.data)),
ctypes.c_size_t(data.shape[1]),
ctypes.byref(handle)))
# [added] missing and nthread are coerced to float/int and serialized as UTF-8 JSON bytes.
args = {
"missing": float(missing),
"nthread": int(nthread),
}
config = bytes(json.dumps(args), "utf-8")
# [added] new call: the three CSR arrays go through _array_interface, followed by
# data.shape[1], the JSON config, and the handle by reference.
_check_call(
_LIB.XGDMatrixCreateFromCSR(
_array_interface(data.indptr),
_array_interface(data.indices),
_array_interface(data.data),
ctypes.c_size_t(data.shape[1]),
config,
ctypes.byref(handle),
)
)
# [context] the handle is returned together with the unchanged feature metadata.
return handle, feature_names, feature_types
@@ -541,11 +548,11 @@ def dispatch_data_backend(data, missing, threads,
enable_categorical=False):
'''Dispatch data for DMatrix.'''
if _is_scipy_csr(data):
return _from_scipy_csr(data, missing, feature_names, feature_types)
return _from_scipy_csr(data, missing, threads, feature_names, feature_types)
if _is_scipy_csc(data):
return _from_scipy_csc(data, missing, feature_names, feature_types)
if _is_scipy_coo(data):
return _from_scipy_csr(data.tocsr(), missing, feature_names, feature_types)
return _from_scipy_csr(data.tocsr(), missing, threads, feature_names, feature_types)
if _is_numpy_array(data):
return _from_numpy_array(data, missing, threads, feature_names,
feature_types)
@@ -592,7 +599,7 @@ def dispatch_data_backend(data, missing, threads,
converted = _convert_unknown_data(data)
if converted:
return _from_scipy_csr(data, missing, feature_names, feature_types)
return _from_scipy_csr(data, missing, threads, feature_names, feature_types)
raise TypeError('Not supported type for data.' + str(type(data)))