Implement Python data handler. (#5689)
* Define data handlers for DMatrix. * Throw ValueError in scikit learn interface.
This commit is contained in:
@@ -22,6 +22,16 @@ def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
|
||||
dtrain = DMatrixT(X, missing=missing, label=y)
|
||||
assert dtrain.num_col() == kCols
|
||||
assert dtrain.num_row() == kRows
|
||||
|
||||
if DMatrixT is xgb.DeviceQuantileDMatrix:
|
||||
# Slice is not supported by DeviceQuantileDMatrix
|
||||
with pytest.raises(xgb.core.XGBoostError):
|
||||
dtrain.slice(rindex=[0, 1, 2])
|
||||
dtrain.slice(rindex=[0, 1, 2])
|
||||
else:
|
||||
dtrain.slice(rindex=[0, 1, 2])
|
||||
dtrain.slice(rindex=[0, 1, 2])
|
||||
|
||||
return dtrain
|
||||
|
||||
|
||||
@@ -41,7 +51,7 @@ def _test_from_cupy(DMatrixT):
|
||||
|
||||
with pytest.raises(Exception):
|
||||
X = cp.random.randn(2, 2, dtype="float32")
|
||||
dtrain = DMatrixT(X, label=X)
|
||||
DMatrixT(X, label=X)
|
||||
|
||||
|
||||
def _test_cupy_training(DMatrixT):
|
||||
@@ -88,11 +98,14 @@ def _test_cupy_metainfo(DMatrixT):
|
||||
dmat_cupy.set_interface_info('group', cupy_uints)
|
||||
|
||||
# Test setting info with cupy
|
||||
assert np.array_equal(dmat.get_float_info('weight'), dmat_cupy.get_float_info('weight'))
|
||||
assert np.array_equal(dmat.get_float_info('label'), dmat_cupy.get_float_info('label'))
|
||||
assert np.array_equal(dmat.get_float_info('weight'),
|
||||
dmat_cupy.get_float_info('weight'))
|
||||
assert np.array_equal(dmat.get_float_info('label'),
|
||||
dmat_cupy.get_float_info('label'))
|
||||
assert np.array_equal(dmat.get_float_info('base_margin'),
|
||||
dmat_cupy.get_float_info('base_margin'))
|
||||
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cupy.get_uint_info('group_ptr'))
|
||||
assert np.array_equal(dmat.get_uint_info('group_ptr'),
|
||||
dmat_cupy.get_uint_info('group_ptr'))
|
||||
|
||||
|
||||
class TestFromCupy:
|
||||
@@ -135,7 +148,9 @@ Arrow specification.'''
|
||||
import cupy as cp
|
||||
n = 100
|
||||
X = cp.random.random((n, 2))
|
||||
xgb.DeviceQuantileDMatrix(X.toDlpack())
|
||||
m = xgb.DeviceQuantileDMatrix(X.toDlpack())
|
||||
with pytest.raises(xgb.core.XGBoostError):
|
||||
m.slice(rindex=[0, 1, 2])
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
@pytest.mark.mgpu
|
||||
|
||||
Reference in New Issue
Block a user