Obtain CSR matrix from DMatrix. (#8269)

This commit is contained in:
Jiaming Yuan
2022-09-29 20:41:43 +08:00
committed by GitHub
parent b14c44ee5e
commit 55cf24cc32
22 changed files with 400 additions and 74 deletions

View File

@@ -609,7 +609,7 @@ def _deprecate_positional_args(f: Callable[..., _T]) -> Callable[..., _T]:
return inner_f
class DMatrix: # pylint: disable=too-many-instance-attributes
class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods
"""Data Matrix used in XGBoost.
DMatrix is an internal data structure that is used by XGBoost,
@@ -1015,29 +1015,49 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
group_ptr = self.get_uint_info("group_ptr")
return np.diff(group_ptr)
def num_row(self) -> int:
"""Get the number of rows in the DMatrix.
def get_data(self) -> scipy.sparse.csr_matrix:
    """Get the predictors from DMatrix as a CSR matrix.

    This getter is mostly for testing purposes. If this is a quantized DMatrix
    then quantized values are returned instead of input values.

        .. versionadded:: 2.0.0

    Returns
    -------
    scipy.sparse.csr_matrix
        Matrix of shape ``(num_row(), num_col())`` assembled from the
        ``float32`` values, ``uint32`` column indices and ``uint64`` row
        pointers handed to ``XGDMatrixGetDataAsCSR``.

    """
    # Query the sizes once; each getter is a separate call into the native
    # library.
    n_rows = self.num_row()
    n_nonmissing = self.num_nonmissing()

    # Uninitialised output buffers; the native call below receives raw
    # pointers to them and is expected to fill them in place.
    indptr = np.empty(n_rows + 1, dtype=np.uint64)
    indices = np.empty(n_nonmissing, dtype=np.uint32)
    data = np.empty(n_nonmissing, dtype=np.float32)

    # NOTE(review): assumes c_bst_ulong has the same width as np.uint64 -- it
    # is defined elsewhere in this module; confirm.
    c_indptr = indptr.ctypes.data_as(ctypes.POINTER(c_bst_ulong))
    c_indices = indices.ctypes.data_as(ctypes.POINTER(ctypes.c_uint32))
    c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))

    # An empty JSON object is passed as the configuration argument.
    config = from_pystr_to_cstr(json.dumps({}))

    _check_call(
        _LIB.XGDMatrixGetDataAsCSR(self.handle, config, c_indptr, c_indices, c_data)
    )
    return scipy.sparse.csr_matrix(
        (data, indices, indptr), shape=(n_rows, self.num_col())
    )
def num_row(self) -> int:
    """Get the number of rows in the DMatrix.

    Returns
    -------
    number of rows : int
    """
    ret = c_bst_ulong()
    # One native call is enough: it writes the row count into ``ret`` through
    # the by-reference argument.
    _check_call(_LIB.XGDMatrixNumRow(self.handle, ctypes.byref(ret)))
    return ret.value
def num_col(self) -> int:
    """Get the number of columns (features) in the DMatrix.

    Returns
    -------
    number of columns : int
    """
    ret = c_bst_ulong()
    # The native call writes the column count into ``ret`` through the
    # by-reference argument.
    _check_call(_LIB.XGDMatrixNumCol(self.handle, ctypes.byref(ret)))
    return ret.value
def num_nonmissing(self) -> int:
    """Return how many values stored in the DMatrix are not missing.

    Returns
    -------
    number of non-missing values : int
    """
    count = c_bst_ulong()
    # The native library reports the count through the by-reference argument;
    # its return status is validated separately by ``_check_call``.
    status = _LIB.XGDMatrixNumNonMissing(self.handle, ctypes.byref(count))
    _check_call(status)
    return count.value
def slice(
self, rindex: Union[List[int], np.ndarray], allow_groups: bool = False
) -> "DMatrix":