Obtain CSR matrix from DMatrix. (#8269)
This commit is contained in:
@@ -609,7 +609,7 @@ def _deprecate_positional_args(f: Callable[..., _T]) -> Callable[..., _T]:
|
||||
return inner_f
|
||||
|
||||
|
||||
class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods
|
||||
"""Data Matrix used in XGBoost.
|
||||
|
||||
DMatrix is an internal data structure that is used by XGBoost,
|
||||
@@ -1015,29 +1015,49 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
group_ptr = self.get_uint_info("group_ptr")
|
||||
return np.diff(group_ptr)
|
||||
|
||||
def num_row(self) -> int:
|
||||
"""Get the number of rows in the DMatrix.
|
||||
def get_data(self) -> scipy.sparse.csr_matrix:
    """Get the predictors from DMatrix as a CSR matrix.

    This getter is mostly for testing purposes. If this is a quantized DMatrix
    then quantized values are returned instead of input values.

        .. versionadded:: 2.0.0

    Returns
    -------
    data : scipy.sparse.csr_matrix
        Sparse matrix of shape ``(self.num_row(), self.num_col())`` built from
        the ``(data, indices, indptr)`` buffers passed to
        ``XGDMatrixGetDataAsCSR``.

    """
    # Query the sizes once; each of these getters is a call into the native
    # library.
    n_rows = self.num_row()
    n_nonmissing = self.num_nonmissing()

    # Uninitialised output buffers, handed to the native call below by pointer.
    indptr = np.empty(n_rows + 1, dtype=np.uint64)
    indices = np.empty(n_nonmissing, dtype=np.uint32)
    data = np.empty(n_nonmissing, dtype=np.float32)

    # NOTE(review): assumes c_bst_ulong is a 64-bit unsigned type so that it
    # matches the np.uint64 indptr buffer - confirm against its definition.
    c_indptr = indptr.ctypes.data_as(ctypes.POINTER(c_bst_ulong))
    c_indices = indices.ctypes.data_as(ctypes.POINTER(ctypes.c_uint32))
    c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
    # An empty JSON object is passed as the configuration.
    config = from_pystr_to_cstr(json.dumps({}))

    _check_call(
        _LIB.XGDMatrixGetDataAsCSR(self.handle, config, c_indptr, c_indices, c_data)
    )
    return scipy.sparse.csr_matrix(
        (data, indices, indptr), shape=(n_rows, self.num_col())
    )
|
||||
|
||||
def num_row(self) -> int:
    """Get the number of rows in the DMatrix.

    Returns
    -------
    int
        The value written by ``XGDMatrixNumRow`` for this matrix handle.
    """
    # Out-parameter passed by reference to the native call.
    ret = c_bst_ulong()
    # A single query is sufficient; a repeated identical call would only
    # overwrite ``ret`` again.
    _check_call(_LIB.XGDMatrixNumRow(self.handle, ctypes.byref(ret)))
    return ret.value
|
||||
|
||||
def num_col(self) -> int:
    """Get the number of columns (features) in the DMatrix.

    Returns
    -------
    int
        The value written by ``XGDMatrixNumCol`` for this matrix handle.
    """
    # Out-parameter passed by reference to the native call.
    ret = c_bst_ulong()
    _check_call(_LIB.XGDMatrixNumCol(self.handle, ctypes.byref(ret)))
    return ret.value
|
||||
|
||||
def num_nonmissing(self) -> int:
    """Return how many values in the DMatrix are not missing.

    Returns
    -------
    int
        The value written by ``XGDMatrixNumNonMissing`` for this matrix handle.
    """
    # The native call writes its answer into this out-parameter.
    n_present = c_bst_ulong()
    status = _LIB.XGDMatrixNumNonMissing(self.handle, ctypes.byref(n_present))
    _check_call(status)
    return n_present.value
|
||||
|
||||
def slice(
|
||||
self, rindex: Union[List[int], np.ndarray], allow_groups: bool = False
|
||||
) -> "DMatrix":
|
||||
|
||||
Reference in New Issue
Block a user