Add support for dlpack, expose python docs for DeviceQuantileDMatrix (#5465)
This commit is contained in:
parent
6601a641d7
commit
15f40e51e9
@ -14,6 +14,9 @@ Core Data Structure
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: xgboost.DeviceQuantileDMatrix
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: xgboost.Booster
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
@ -381,6 +381,17 @@ def _maybe_dt_data(data, feature_names, feature_types,
|
||||
|
||||
return data, feature_names, feature_types
|
||||
|
||||
def _is_dlpack(x):
    """Return True when ``x`` looks like a DLPack capsule.

    The check is purely textual: the name of ``type(x)`` must mention
    ``PyCapsule`` and the string form of ``x`` must mention ``dltensor``.
    """
    # Guard first so that str(x) is only computed for capsule-like objects.
    if 'PyCapsule' not in str(type(x)):
        return False
    return "dltensor" in str(x)
|
||||
|
||||
# Just convert dlpack into cupy (zero copy)
|
||||
def _maybe_dlpack_data(data, feature_names, feature_types):
    """Convert a DLPack capsule into a cupy array; pass anything else through.

    Parameters
    ----------
    data : object
        Candidate input. It is converted only when ``_is_dlpack`` recognises
        it as a DLPack capsule.
    feature_names, feature_types : object
        Returned unchanged in every case.

    Returns
    -------
    tuple
        ``(data, feature_names, feature_types)`` where ``data`` is the cupy
        array obtained from the capsule, or the original object when it is
        not a DLPack capsule.

    Raises
    ------
    ValueError
        If the capsule has already been consumed by an earlier conversion.
    """
    if not _is_dlpack(data):
        return data, feature_names, feature_types

    # A DLPack capsule may be consumed only once; the consumer renames it to
    # "used_dltensor".  That name still satisfies the substring test in
    # _is_dlpack, so reject it here with a readable message instead of
    # letting the conversion below fail on the capsule name.
    if 'used_dltensor' in str(data):
        raise ValueError(
            'The DLPack capsule has already been consumed; a capsule can '
            'only be used once to construct a DMatrix.')

    # Imported lazily so that cupy is only required when a capsule is given.
    from cupy import fromDlpack  # pylint: disable=E0401
    data = fromDlpack(data)
    return data, feature_names, feature_types
|
||||
|
||||
|
||||
def _convert_dataframes(data, feature_names, feature_types,
|
||||
meta=None, meta_type=None):
|
||||
@ -399,6 +410,9 @@ def _convert_dataframes(data, feature_names, feature_types,
|
||||
data, feature_names, feature_types = _maybe_cudf_dataframe(
|
||||
data, feature_names, feature_types)
|
||||
|
||||
data, feature_names, feature_types = _maybe_dlpack_data(
|
||||
data, feature_names, feature_types)
|
||||
|
||||
return data, feature_names, feature_types
|
||||
|
||||
|
||||
@ -439,7 +453,7 @@ class DMatrix(object):
|
||||
"""Parameters
|
||||
----------
|
||||
data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/
|
||||
dt.Frame/cudf.DataFrame/cupy.array
|
||||
dt.Frame/cudf.DataFrame/cupy.array/dlpack
|
||||
Data source of DMatrix.
|
||||
When data is string or os.PathLike type, it represents the path
|
||||
libsvm format txt file, csv file (by specifying uri parameter
|
||||
@ -1028,12 +1042,12 @@ class DMatrix(object):
|
||||
class DeviceQuantileDMatrix(DMatrix):
|
||||
"""Device memory Data Matrix used in XGBoost for training with tree_method='gpu_hist'. Do not
|
||||
use this for test/validation tasks as some information may be lost in quantisation. This
|
||||
DMatrix is primarily designed to save memory in training and avoids intermediate steps,
|
||||
directly creating a compressed representation for training without allocating additional
|
||||
memory. Implementation does not currently consider weights in quantisation process(unlike
|
||||
DMatrix).
|
||||
DMatrix is primarily designed to save memory in training from device memory inputs by
|
||||
avoiding intermediate storage. Implementation does not currently consider weights in
|
||||
quantisation process(unlike DMatrix). Set max_bin to control the number of bins during
|
||||
quantisation.
|
||||
|
||||
You can construct DeviceDMatrix from cupy/cudf
|
||||
You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack.
|
||||
"""
|
||||
|
||||
def __init__(self, data, label=None, weight=None, base_margin=None,
|
||||
@ -1044,8 +1058,8 @@ class DeviceQuantileDMatrix(DMatrix):
|
||||
nthread=None, max_bin=256):
|
||||
self.max_bin = max_bin
|
||||
if not (hasattr(data, "__cuda_array_interface__") or (
|
||||
CUDF_INSTALLED and isinstance(data, CUDF_DataFrame))):
|
||||
raise ValueError('Only cupy/cudf currently supported for DeviceDMatrix')
|
||||
CUDF_INSTALLED and isinstance(data, CUDF_DataFrame)) or _is_dlpack(data)):
|
||||
raise ValueError('Only cupy/cudf/dlpack currently supported for DeviceQuantileDMatrix')
|
||||
|
||||
super().__init__(data, label=label, weight=weight, base_margin=base_margin,
|
||||
missing=missing,
|
||||
|
||||
@ -95,7 +95,7 @@ def _test_cupy_metainfo(DMatrixT):
|
||||
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cupy.get_uint_info('group_ptr'))
|
||||
|
||||
|
||||
class TestFromArrayInterface:
|
||||
class TestFromCupy:
|
||||
'''Tests for constructing DMatrix from data structure conforming Apache
|
||||
Arrow specification.'''
|
||||
|
||||
@ -122,3 +122,17 @@ Arrow specification.'''
|
||||
@pytest.mark.skipif(**tm.no_cupy())
def test_cupy_metainfo_device_dmat(self):
    """Run the shared cupy meta-info checks against DeviceQuantileDMatrix."""
    matrix_type = xgb.DeviceQuantileDMatrix
    _test_cupy_metainfo(matrix_type)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
def test_dlpack_simple_dmat(self):
    """A DMatrix built from a DLPack capsule keeps the source array's shape."""
    import cupy as cp
    n = 100
    n_features = 2
    X = cp.random.random((n, n_features))
    # toDlpack() exports the cupy array as a DLPack capsule.
    dmat = xgb.DMatrix(X.toDlpack())
    # Construction succeeding is not enough: verify the data came across.
    assert dmat.num_row() == n
    assert dmat.num_col() == n_features
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
def test_dlpack_device_dmat(self):
    """A DeviceQuantileDMatrix built from a DLPack capsule keeps the shape."""
    import cupy as cp
    n = 100
    n_features = 2
    X = cp.random.random((n, n_features))
    # toDlpack() exports the cupy array as a DLPack capsule.
    dmat = xgb.DeviceQuantileDMatrix(X.toDlpack())
    # Construction succeeding is not enough: verify the data came across.
    assert dmat.num_row() == n
    assert dmat.num_col() == n_features
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user