Add support for dlpack, expose python docs for DeviceQuantileDMatrix (#5465)

This commit is contained in:
Rory Mitchell
2020-04-01 23:34:32 +13:00
committed by GitHub
parent 6601a641d7
commit 15f40e51e9
3 changed files with 40 additions and 9 deletions

View File

@@ -95,7 +95,7 @@ def _test_cupy_metainfo(DMatrixT):
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cupy.get_uint_info('group_ptr'))
class TestFromArrayInterface:
class TestFromCupy:
'''Tests for constructing DMatrix from data structure conforming Apache
Arrow specification.'''
@@ -122,3 +122,17 @@ Arrow specification.'''
@pytest.mark.skipif(**tm.no_cupy())
def test_cupy_metainfo_device_dmat(self):
_test_cupy_metainfo(xgb.DeviceQuantileDMatrix)
@pytest.mark.skipif(**tm.no_cupy())
def test_dlpack_simple_dmat(self):
import cupy as cp
n = 100
X = cp.random.random((n, 2))
xgb.DMatrix(X.toDlpack())
@pytest.mark.skipif(**tm.no_cupy())
def test_dlpack_device_dmat(self):
import cupy as cp
n = 100
X = cp.random.random((n, 2))
xgb.DeviceQuantileDMatrix(X.toDlpack())