Implement Python data handler. (#5689)

* Define data handlers for DMatrix.
* Raise ValueError in the scikit-learn interface.
This commit is contained in:
Jiaming Yuan
2020-05-22 11:53:55 +08:00
committed by GitHub
parent 646def51e0
commit 5af8161a1a
7 changed files with 746 additions and 405 deletions

View File

@@ -22,6 +22,16 @@ def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
dtrain = DMatrixT(X, missing=missing, label=y)
assert dtrain.num_col() == kCols
assert dtrain.num_row() == kRows
if DMatrixT is xgb.DeviceQuantileDMatrix:
# Slice is not supported by DeviceQuantileDMatrix
with pytest.raises(xgb.core.XGBoostError):
dtrain.slice(rindex=[0, 1, 2])
dtrain.slice(rindex=[0, 1, 2])
else:
dtrain.slice(rindex=[0, 1, 2])
dtrain.slice(rindex=[0, 1, 2])
return dtrain
@@ -41,7 +51,7 @@ def _test_from_cupy(DMatrixT):
with pytest.raises(Exception):
X = cp.random.randn(2, 2, dtype="float32")
dtrain = DMatrixT(X, label=X)
DMatrixT(X, label=X)
def _test_cupy_training(DMatrixT):
@@ -88,11 +98,14 @@ def _test_cupy_metainfo(DMatrixT):
dmat_cupy.set_interface_info('group', cupy_uints)
# Test setting info with cupy
assert np.array_equal(dmat.get_float_info('weight'), dmat_cupy.get_float_info('weight'))
assert np.array_equal(dmat.get_float_info('label'), dmat_cupy.get_float_info('label'))
assert np.array_equal(dmat.get_float_info('weight'),
dmat_cupy.get_float_info('weight'))
assert np.array_equal(dmat.get_float_info('label'),
dmat_cupy.get_float_info('label'))
assert np.array_equal(dmat.get_float_info('base_margin'),
dmat_cupy.get_float_info('base_margin'))
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cupy.get_uint_info('group_ptr'))
assert np.array_equal(dmat.get_uint_info('group_ptr'),
dmat_cupy.get_uint_info('group_ptr'))
class TestFromCupy:
@@ -135,7 +148,9 @@ Arrow specification.'''
import cupy as cp
n = 100
X = cp.random.random((n, 2))
xgb.DeviceQuantileDMatrix(X.toDlpack())
m = xgb.DeviceQuantileDMatrix(X.toDlpack())
with pytest.raises(xgb.core.XGBoostError):
m.slice(rindex=[0, 1, 2])
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.mgpu