Remove all use of DeviceQuantileDMatrix. (#8665)

Jiaming Yuan 2023-01-17 00:04:10 +08:00 committed by GitHub
parent 0ae8df9a65
commit d6018eb4b9
10 changed files with 57 additions and 53 deletions

View File

@@ -38,25 +38,23 @@ def using_dask_matrix(client: Client, X, y):
 def using_quantile_device_dmatrix(client: Client, X, y):
-    '''`DaskDeviceQuantileDMatrix` is a data type specialized for `gpu_hist`, tree
+    """`DaskQuantileDMatrix` is a data type specialized for `gpu_hist`, tree
     method that reduces memory overhead. When training on GPU pipeline, it's
     preferred over `DaskDMatrix`.
 
     .. versionadded:: 1.2.0
 
-    '''
-    # Input must be on GPU for `DaskDeviceQuantileDMatrix`.
+    """
+    # Input must be on GPU for `DaskQuantileDMatrix`.
     X = dask_cudf.from_dask_dataframe(dd.from_dask_array(X))
     y = dask_cudf.from_dask_dataframe(dd.from_dask_array(y))
 
-    # `DaskDeviceQuantileDMatrix` is used instead of `DaskDMatrix`, be careful
+    # `DaskQuantileDMatrix` is used instead of `DaskDMatrix`, be careful
     # that it can not be used for anything else other than training.
     dtrain = dxgb.DaskQuantileDMatrix(client, X, y)
-    output = xgb.dask.train(client,
-                            {'verbosity': 2,
-                             'tree_method': 'gpu_hist'},
-                            dtrain,
-                            num_boost_round=4)
+    output = xgb.dask.train(
+        client, {"verbosity": 2, "tree_method": "gpu_hist"}, dtrain, num_boost_round=4
+    )
     prediction = xgb.dask.predict(client, output, X)
     return prediction
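
For reference, the updated demo reads roughly as the following complete program; the `LocalCUDACluster` setup and the random input data are illustrative assumptions, not part of this commit.

from dask import array as da
from dask import dataframe as dd
from dask.distributed import Client
from dask_cuda import LocalCUDACluster
import dask_cudf
import xgboost as xgb
from xgboost import dask as dxgb


def using_quantile_device_dmatrix(client: Client, X, y):
    # Move the input to GPU memory; `DaskQuantileDMatrix` expects device data.
    X = dask_cudf.from_dask_dataframe(dd.from_dask_array(X))
    y = dask_cudf.from_dask_dataframe(dd.from_dask_array(y))
    # `DaskQuantileDMatrix` is used instead of `DaskDMatrix`; it is meant for
    # training only.
    dtrain = dxgb.DaskQuantileDMatrix(client, X, y)
    output = xgb.dask.train(
        client, {"verbosity": 2, "tree_method": "gpu_hist"}, dtrain, num_boost_round=4
    )
    return xgb.dask.predict(client, output, X)


if __name__ == "__main__":
    with LocalCUDACluster() as cluster:
        with Client(cluster) as client:
            X = da.random.random((100_000, 20), chunks=(10_000, 20))
            y = da.random.random(100_000, chunks=10_000)
            print(using_quantile_device_dmatrix(client, X, y).compute())

As the surrounding docstring notes, `DaskQuantileDMatrix` is preferred over `DaskDMatrix` for `gpu_hist` because it reduces memory overhead, at the cost of being usable for training only.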

View File

@@ -1,11 +1,11 @@
-'''
+"""
 Demo for using data iterator with Quantile DMatrix
 ==================================================
 
 .. versionadded:: 1.2.0
 
 The demo that defines a customized iterator for passing batches of data into
-`xgboost.DeviceQuantileDMatrix` and use this `DeviceQuantileDMatrix` for
+:py:class:`xgboost.QuantileDMatrix` and use this ``QuantileDMatrix`` for
 training. The feature is used primarily designed to reduce the required GPU
 memory for training on distributed environment.
@@ -15,7 +15,7 @@ using `itertools.tee` might incur significant memory usage according to:
 
 https://docs.python.org/3/library/itertools.html#itertools.tee.
 
-'''
+"""
 
 import cupy
 import numpy
@@ -88,26 +88,32 @@ def main():
     rounds = 100
     it = IterForDMatrixDemo()
 
-    # Use iterator, must be `DeviceQuantileDMatrix` for quantile DMatrix.
-    m_with_it = xgboost.DeviceQuantileDMatrix(it)
+    # Use iterator, must be `QuantileDMatrix`.
+    # In this demo, the input batches are created using cupy, and the data processing
+    # (quantile sketching) will be performed on GPU. If data is loaded with CPU based
+    # data structures like numpy or pandas, then the processing step will be performed
+    # on CPU instead.
+    m_with_it = xgboost.QuantileDMatrix(it)
 
     # Use regular DMatrix.
-    m = xgboost.DMatrix(it.as_array(), it.as_array_labels(),
-                        weight=it.as_array_weights())
+    m = xgboost.DMatrix(
+        it.as_array(), it.as_array_labels(), weight=it.as_array_weights()
+    )
 
     assert m_with_it.num_col() == m.num_col()
     assert m_with_it.num_row() == m.num_row()
 
-    reg_with_it = xgboost.train({'tree_method': 'gpu_hist'}, m_with_it,
-                                num_boost_round=rounds)
+    # Tree meethod must be one of the `hist` or `gpu_hist`. We use `gpu_hist` for GPU
+    # input here.
+    reg_with_it = xgboost.train(
+        {"tree_method": "gpu_hist"}, m_with_it, num_boost_round=rounds
+    )
     predict_with_it = reg_with_it.predict(m_with_it)
 
-    reg = xgboost.train({'tree_method': 'gpu_hist'}, m,
-                        num_boost_round=rounds)
+    reg = xgboost.train({"tree_method": "gpu_hist"}, m, num_boost_round=rounds)
     predict = reg.predict(m)
 
-    numpy.testing.assert_allclose(predict_with_it, predict,
-                                  rtol=1e6)
+    numpy.testing.assert_allclose(predict_with_it, predict, rtol=1e6)
 
 
 if __name__ == '__main__':
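
The hunk references `IterForDMatrixDemo` without showing its definition. A minimal `xgboost.DataIter` implementation in the same spirit looks roughly like the sketch below; the class name, batch count, and shapes are assumptions, not the demo's values.

import cupy
import xgboost


class CupyBatchIter(xgboost.DataIter):
    """Yield a fixed number of cupy batches to QuantileDMatrix."""

    def __init__(self, n_batches: int = 4, rows: int = 1024, cols: int = 16) -> None:
        self._batches = [
            (cupy.random.randn(rows, cols), cupy.random.randn(rows))
            for _ in range(n_batches)
        ]
        self._it = 0
        super().__init__()

    def next(self, input_data) -> int:
        if self._it == len(self._batches):
            return 0  # signal that the iterator is exhausted
        X, y = self._batches[self._it]
        input_data(data=X, label=y)  # hand the current batch to XGBoost
        self._it += 1
        return 1  # more batches to come

    def reset(self) -> None:
        self._it = 0


it = CupyBatchIter()
# Quantile sketching runs on GPU because the batches are cupy arrays.
m_with_it = xgboost.QuantileDMatrix(it)
booster = xgboost.train({"tree_method": "gpu_hist"}, m_with_it, num_boost_round=10)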

View File

@@ -113,7 +113,7 @@ TrainReturnT = TypedDict(
 __all__ = [
     "CommunicatorContext",
     "DaskDMatrix",
-    "DaskDeviceQuantileDMatrix",
+    "DaskQuantileDMatrix",
     "DaskXGBRegressor",
     "DaskXGBClassifier",
     "DaskXGBRanker",
@@ -559,7 +559,7 @@ def _get_worker_parts(list_of_parts: _DataParts) -> Dict[str, List[Any]]:
 
 
 class DaskPartitionIter(DataIter):  # pylint: disable=R0902
-    """A data iterator for `DaskDeviceQuantileDMatrix`."""
+    """A data iterator for `DaskQuantileDMatrix`."""
 
     def __init__(
         self,

View File

@@ -1229,7 +1229,7 @@ def dispatch_proxy_set_data(
     cat_codes: Optional[list],
     allow_host: bool,
 ) -> None:
-    """Dispatch for DeviceQuantileDMatrix."""
+    """Dispatch for QuantileDMatrix."""
     if not _is_cudf_ser(data) and not _is_pandas_series(data):
         _check_data_shape(data)

View File

@@ -11,7 +11,7 @@ sys.path.append("tests/python")
 import test_quantile_dmatrix as tqd
 
 
-class TestDeviceQuantileDMatrix:
+class TestQuantileDMatrix:
     cputest = tqd.TestQuantileDMatrix()
 
     @pytest.mark.skipif(**tm.no_cupy())
@@ -32,7 +32,7 @@ class TestDeviceQuantileDMatrix:
     def test_dmatrix_cupy_init(self) -> None:
        import cupy as cp

         data = cp.random.randn(5, 5)
-        xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
+        xgb.QuantileDMatrix(data, cp.ones(5, dtype=np.float64))
 
     @pytest.mark.skipif(**tm.no_cupy())
     @pytest.mark.parametrize(
@@ -85,7 +85,7 @@ class TestDeviceQuantileDMatrix:
         fw = rng.randn(rows)
         fw -= fw.min()
-        m = xgb.DeviceQuantileDMatrix(data=data, label=labels, feature_weights=fw)
+        m = xgb.QuantileDMatrix(data=data, label=labels, feature_weights=fw)
         got_fw = m.get_float_info("feature_weights")
         got_labels = m.get_label()

View File

@@ -160,7 +160,7 @@ Arrow specification.'''
 
     @pytest.mark.skipif(**tm.no_cudf())
     def test_device_dmatrix_from_cudf(self):
-        _test_from_cudf(xgb.DeviceQuantileDMatrix)
+        _test_from_cudf(xgb.QuantileDMatrix)
 
     @pytest.mark.skipif(**tm.no_cudf())
     def test_cudf_training_simple_dmatrix(self):
@@ -168,7 +168,7 @@ Arrow specification.'''
 
     @pytest.mark.skipif(**tm.no_cudf())
     def test_cudf_training_device_dmatrix(self):
-        _test_cudf_training(xgb.DeviceQuantileDMatrix)
+        _test_cudf_training(xgb.QuantileDMatrix)
 
     @pytest.mark.skipif(**tm.no_cudf())
     def test_cudf_metainfo_simple_dmatrix(self):
@@ -176,7 +176,7 @@ Arrow specification.'''
 
     @pytest.mark.skipif(**tm.no_cudf())
     def test_cudf_metainfo_device_dmatrix(self):
-        _test_cudf_metainfo(xgb.DeviceQuantileDMatrix)
+        _test_cudf_metainfo(xgb.QuantileDMatrix)
 
     @pytest.mark.skipif(**tm.no_cudf())
     def test_cudf_categorical(self) -> None:
@@ -191,7 +191,7 @@ Arrow specification.'''
         assert len(Xy.feature_types) == X.shape[1]
         assert all(t == "c" for t in Xy.feature_types)
 
-        Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True)
+        Xy = xgb.QuantileDMatrix(X, y, enable_categorical=True)
         assert Xy.feature_types is not None
         assert len(Xy.feature_types) == X.shape[1]
         assert all(t == "c" for t in Xy.feature_types)
@@ -228,9 +228,9 @@ Arrow specification.'''
         assert Xy.num_col() == 1
 
         with pytest.raises(ValueError, match="enable_categorical"):
-            xgb.DeviceQuantileDMatrix(X, y)
+            xgb.QuantileDMatrix(X, y)
 
-        Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True)
+        Xy = xgb.QuantileDMatrix(X, y, enable_categorical=True)
         assert Xy.num_row() == 3
         assert Xy.num_col() == 1
@@ -344,7 +344,7 @@ def test_from_cudf_iter(enable_categorical):
     params = {"tree_method": "gpu_hist"}
 
     # Use iterator
-    m_it = xgb.DeviceQuantileDMatrix(it, enable_categorical=enable_categorical)
+    m_it = xgb.QuantileDMatrix(it, enable_categorical=enable_categorical)
     reg_with_it = xgb.train(params, m_it, num_boost_round=rounds)
 
     X = it.as_array()
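
For readers unfamiliar with the categorical path exercised by these tests, a small sketch of `QuantileDMatrix` with a cudf categorical column follows; the toy data is illustrative only, not taken from the tests.

import cudf
import xgboost as xgb

X = cudf.DataFrame({"f0": ["a", "b", "a", "c"]})
X["f0"] = X["f0"].astype("category")  # cudf categorical column
y = cudf.Series([0.0, 1.0, 0.0, 1.0])

# Without enable_categorical, constructing from categorical data raises ValueError.
Xy = xgb.QuantileDMatrix(X, y, enable_categorical=True)
assert Xy.feature_types == ["c"]  # categorical features are typed "c"
booster = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4)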

View File

@@ -27,8 +27,8 @@ def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
     assert dtrain.num_col() == kCols
     assert dtrain.num_row() == kRows
-    if DMatrixT is xgb.DeviceQuantileDMatrix:
-        # Slice is not supported by DeviceQuantileDMatrix
+    if DMatrixT is xgb.QuantileDMatrix:
+        # Slice is not supported by QuantileDMatrix
         with pytest.raises(xgb.core.XGBoostError):
             dtrain.slice(rindex=[0, 1, 2])
     else:
         dtrain.slice(rindex=[0, 1, 2])
@@ -153,11 +153,11 @@ Arrow specification.'''
 
     @pytest.mark.skipif(**tm.no_cupy())
     def test_device_dmat_from_cupy(self):
-        _test_from_cupy(xgb.DeviceQuantileDMatrix)
+        _test_from_cupy(xgb.QuantileDMatrix)
 
     @pytest.mark.skipif(**tm.no_cupy())
     def test_cupy_training_device_dmat(self):
-        _test_cupy_training(xgb.DeviceQuantileDMatrix)
+        _test_cupy_training(xgb.QuantileDMatrix)
 
     @pytest.mark.skipif(**tm.no_cupy())
     def test_cupy_training_simple_dmat(self):
@@ -169,7 +169,7 @@ Arrow specification.'''
 
     @pytest.mark.skipif(**tm.no_cupy())
     def test_cupy_metainfo_device_dmat(self):
-        _test_cupy_metainfo(xgb.DeviceQuantileDMatrix)
+        _test_cupy_metainfo(xgb.QuantileDMatrix)
 
     @pytest.mark.skipif(**tm.no_cupy())
     def test_dlpack_simple_dmat(self):
@@ -196,7 +196,7 @@ Arrow specification.'''
         import cupy as cp
 
         n = 100
         X = cp.random.random((n, 2))
-        m = xgb.DeviceQuantileDMatrix(X.toDlpack())
+        m = xgb.QuantileDMatrix(X.toDlpack())
         with pytest.raises(xgb.core.XGBoostError):
             m.slice(rindex=[0, 1, 2])
@@ -222,7 +222,7 @@ Arrow specification.'''
         import cupy as cp
 
         cp.cuda.runtime.setDevice(0)
         dtrain = dmatrix_from_cupy(
-            np.float32, xgb.DeviceQuantileDMatrix, np.nan)
+            np.float32, xgb.QuantileDMatrix, np.nan)
         with pytest.raises(xgb.core.XGBoostError):
             xgb.train(
                 {'tree_method': 'gpu_hist', 'gpu_id': 1}, dtrain, num_boost_round=10

View File

@@ -17,7 +17,7 @@ def test_large_input():
     assert (np.log2(m * n) > 31)
     X = cp.ones((m, n), dtype=np.float32)
     y = cp.ones(m)
-    dmat = xgb.DeviceQuantileDMatrix(X, y)
+    dmat = xgb.QuantileDMatrix(X, y)
     booster = xgb.train({"tree_method": "gpu_hist", "max_depth": 1}, dmat, 1)
     del y
     booster.inplace_predict(X)

View File

@@ -173,7 +173,7 @@ class TestTreeMethod:
         X, y = cp.array(X), cp.array(y)
 
         with pytest.raises(ValueError):
-            Xy = xgb.DeviceQuantileDMatrix(X, y, feature_types=["c"] * 10)
+            Xy = xgb.QuantileDMatrix(X, y, feature_types=["c"] * 10)
 
     def test_invalid_category(self) -> None:
         self.run_invalid_category("approx")

View File

@@ -135,7 +135,7 @@ def run_with_dask_array(DMatrixT: Type, client: Client) -> None:
     def to_cp(x: Any, DMatrixT: Type) -> Any:
         import cupy
 
-        if isinstance(x, np.ndarray) and DMatrixT is dxgb.DaskDeviceQuantileDMatrix:
+        if isinstance(x, np.ndarray) and DMatrixT is dxgb.DaskQuantileDMatrix:
             X = cupy.array(x)
         else:
             X = x
@@ -169,7 +169,7 @@ def run_gpu_hist(
     else:
         w = None
 
-    if DMatrixT is dxgb.DaskDeviceQuantileDMatrix:
+    if DMatrixT is dxgb.DaskQuantileDMatrix:
         m = DMatrixT(
             client, data=X, label=y, weight=w, max_bin=params.get("max_bin", 256)
         )
@@ -227,7 +227,7 @@ class TestDistributedGPU:
     @pytest.mark.skipif(**tm.no_dask_cudf())
     def test_dask_dataframe(self, local_cuda_client: Client) -> None:
         run_with_dask_dataframe(dxgb.DaskDMatrix, local_cuda_client)
-        run_with_dask_dataframe(dxgb.DaskDeviceQuantileDMatrix, local_cuda_client)
+        run_with_dask_dataframe(dxgb.DaskQuantileDMatrix, local_cuda_client)
 
     @pytest.mark.skipif(**tm.no_dask_cudf())
     def test_categorical(self, local_cuda_client: Client) -> None:
@@ -245,7 +245,7 @@ class TestDistributedGPU:
         num_rounds=strategies.integers(1, 20),
         dataset=tm.dataset_strategy,
         dmatrix_type=strategies.sampled_from(
-            [dxgb.DaskDMatrix, dxgb.DaskDeviceQuantileDMatrix]
+            [dxgb.DaskDMatrix, dxgb.DaskQuantileDMatrix]
         ),
     )
     @settings(
@@ -268,7 +268,7 @@ class TestDistributedGPU:
     @pytest.mark.skipif(**tm.no_cupy())
     def test_dask_array(self, local_cuda_client: Client) -> None:
         run_with_dask_array(dxgb.DaskDMatrix, local_cuda_client)
-        run_with_dask_array(dxgb.DaskDeviceQuantileDMatrix, local_cuda_client)
+        run_with_dask_array(dxgb.DaskQuantileDMatrix, local_cuda_client)
 
     @pytest.mark.skipif(**tm.no_cupy())
     def test_early_stopping(self, local_cuda_client: Client) -> None:
@@ -357,7 +357,7 @@ class TestDistributedGPU:
         )
         X = ddf[ddf.columns.difference(["y"])]
         y = ddf[["y"]]
-        dtrain = dxgb.DaskDeviceQuantileDMatrix(local_cuda_client, X, y)
+        dtrain = dxgb.DaskQuantileDMatrix(local_cuda_client, X, y)
         bst_empty = xgb.dask.train(
             local_cuda_client, parameters, dtrain, evals=[(dtrain, "train")]
         )
@@ -369,7 +369,7 @@ class TestDistributedGPU:
         )
         X = ddf[ddf.columns.difference(["y"])]
         y = ddf[["y"]]
-        dtrain = dxgb.DaskDeviceQuantileDMatrix(local_cuda_client, X, y)
+        dtrain = dxgb.DaskQuantileDMatrix(local_cuda_client, X, y)
         bst = xgb.dask.train(
             local_cuda_client, parameters, dtrain, evals=[(dtrain, "train")]
         )
@@ -546,7 +546,7 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> dxgb.TrainReturnT:
     X = X.map_blocks(cp.array)
     y = y.map_blocks(cp.array)
 
-    m = await xgb.dask.DaskDeviceQuantileDMatrix(client, X, y)
+    m = await xgb.dask.DaskQuantileDMatrix(client, X, y)
     output = await xgb.dask.train(client, {"tree_method": "gpu_hist"}, dtrain=m)
     with_m = await xgb.dask.predict(client, output, m)
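
The asynchronous pattern in the last hunk can be exercised end to end roughly as sketched below; the `LocalCUDACluster` context, the random arrays, and the function name are assumptions made for illustration.

import asyncio

import cupy as cp
from dask import array as da
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

import xgboost as xgb
from xgboost import dask as dxgb


async def train_async() -> None:
    async with LocalCUDACluster(asynchronous=True) as cluster:
        async with Client(cluster, asynchronous=True) as client:
            X = da.random.random((10_000, 10), chunks=(1_000, 10)).map_blocks(cp.array)
            y = da.random.random(10_000, chunks=1_000).map_blocks(cp.array)
            # The constructor is awaitable when the client is asynchronous.
            m = await dxgb.DaskQuantileDMatrix(client, X, y)
            output = await xgb.dask.train(client, {"tree_method": "gpu_hist"}, dtrain=m)
            # Same pattern as run_from_dask_array_asyncio above: predict lazily.
            predictions = await xgb.dask.predict(client, output, m)
            print(predictions)


if __name__ == "__main__":
    asyncio.run(train_async())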