Use __array_interface__ for creating DMatrix from CSR. (#6675)
* Use __array_interface__ for creating DMatrix from CSR.
* Add configuration.
This commit is contained in:
@@ -9,7 +9,7 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .core import c_array, _LIB, _check_call, c_str
|
||||
from .core import c_array, _LIB, _check_call, c_str, _array_interface
|
||||
from .core import DataIter, _ProxyDMatrix, DMatrix
|
||||
from .compat import lazy_isinstance
|
||||
|
||||
@@ -40,21 +40,28 @@ def _is_scipy_csr(data):
|
||||
return isinstance(data, scipy.sparse.csr_matrix)
|
||||
|
||||
|
||||
def _from_scipy_csr(data, missing, feature_names, feature_types):
|
||||
'''Initialize data from a CSR matrix.'''
|
||||
def _from_scipy_csr(data, missing, nthread, feature_names, feature_types):
|
||||
"""Initialize data from a CSR matrix."""
|
||||
if len(data.indices) != len(data.data):
|
||||
raise ValueError('length mismatch: {} vs {}'.format(
|
||||
len(data.indices), len(data.data)))
|
||||
_warn_unused_missing(data, missing)
|
||||
raise ValueError(
|
||||
"length mismatch: {} vs {}".format(len(data.indices), len(data.data))
|
||||
)
|
||||
handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGDMatrixCreateFromCSREx(
|
||||
c_array(ctypes.c_size_t, data.indptr),
|
||||
c_array(ctypes.c_uint, data.indices),
|
||||
c_array(ctypes.c_float, data.data),
|
||||
ctypes.c_size_t(len(data.indptr)),
|
||||
ctypes.c_size_t(len(data.data)),
|
||||
ctypes.c_size_t(data.shape[1]),
|
||||
ctypes.byref(handle)))
|
||||
args = {
|
||||
"missing": float(missing),
|
||||
"nthread": int(nthread),
|
||||
}
|
||||
config = bytes(json.dumps(args), "utf-8")
|
||||
_check_call(
|
||||
_LIB.XGDMatrixCreateFromCSR(
|
||||
_array_interface(data.indptr),
|
||||
_array_interface(data.indices),
|
||||
_array_interface(data.data),
|
||||
ctypes.c_size_t(data.shape[1]),
|
||||
config,
|
||||
ctypes.byref(handle),
|
||||
)
|
||||
)
|
||||
return handle, feature_names, feature_types
|
||||
|
||||
|
||||
@@ -541,11 +548,11 @@ def dispatch_data_backend(data, missing, threads,
|
||||
enable_categorical=False):
|
||||
'''Dispatch data for DMatrix.'''
|
||||
if _is_scipy_csr(data):
|
||||
return _from_scipy_csr(data, missing, feature_names, feature_types)
|
||||
return _from_scipy_csr(data, missing, threads, feature_names, feature_types)
|
||||
if _is_scipy_csc(data):
|
||||
return _from_scipy_csc(data, missing, feature_names, feature_types)
|
||||
if _is_scipy_coo(data):
|
||||
return _from_scipy_csr(data.tocsr(), missing, feature_names, feature_types)
|
||||
return _from_scipy_csr(data.tocsr(), missing, threads, feature_names, feature_types)
|
||||
if _is_numpy_array(data):
|
||||
return _from_numpy_array(data, missing, threads, feature_names,
|
||||
feature_types)
|
||||
@@ -592,7 +599,7 @@ def dispatch_data_backend(data, missing, threads,
|
||||
|
||||
converted = _convert_unknown_data(data)
|
||||
if converted:
|
||||
return _from_scipy_csr(data, missing, feature_names, feature_types)
|
||||
return _from_scipy_csr(data, missing, threads, feature_names, feature_types)
|
||||
|
||||
raise TypeError('Not supported type for data.' + str(type(data)))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user