Add support for dlpack, expose python docs for DeviceQuantileDMatrix (#5465)
This commit is contained in:
@@ -95,7 +95,7 @@ def _test_cupy_metainfo(DMatrixT):
|
||||
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cupy.get_uint_info('group_ptr'))
|
||||
|
||||
|
||||
class TestFromArrayInterface:
|
||||
class TestFromCupy:
|
||||
'''Tests for constructing DMatrix from data structure conforming Apache
|
||||
Arrow specification.'''
|
||||
|
||||
@@ -122,3 +122,17 @@ Arrow specification.'''
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_cupy_metainfo_device_dmat(self):
|
||||
_test_cupy_metainfo(xgb.DeviceQuantileDMatrix)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_dlpack_simple_dmat(self):
|
||||
import cupy as cp
|
||||
n = 100
|
||||
X = cp.random.random((n, 2))
|
||||
xgb.DMatrix(X.toDlpack())
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_dlpack_device_dmat(self):
|
||||
import cupy as cp
|
||||
n = 100
|
||||
X = cp.random.random((n, 2))
|
||||
xgb.DeviceQuantileDMatrix(X.toDlpack())
|
||||
|
||||
Reference in New Issue
Block a user