Remove all use of DeviceQuantileDMatrix. (#8665)
This commit is contained in:
parent
0ae8df9a65
commit
d6018eb4b9
@ -38,25 +38,23 @@ def using_dask_matrix(client: Client, X, y):
|
|||||||
|
|
||||||
|
|
||||||
def using_quantile_device_dmatrix(client: Client, X, y):
|
def using_quantile_device_dmatrix(client: Client, X, y):
|
||||||
'''`DaskDeviceQuantileDMatrix` is a data type specialized for `gpu_hist`, tree
|
"""`DaskQuantileDMatrix` is a data type specialized for `gpu_hist`, tree
|
||||||
method that reduces memory overhead. When training on GPU pipeline, it's
|
method that reduces memory overhead. When training on GPU pipeline, it's
|
||||||
preferred over `DaskDMatrix`.
|
preferred over `DaskDMatrix`.
|
||||||
|
|
||||||
.. versionadded:: 1.2.0
|
.. versionadded:: 1.2.0
|
||||||
|
|
||||||
'''
|
"""
|
||||||
# Input must be on GPU for `DaskDeviceQuantileDMatrix`.
|
# Input must be on GPU for `DaskQuantileDMatrix`.
|
||||||
X = dask_cudf.from_dask_dataframe(dd.from_dask_array(X))
|
X = dask_cudf.from_dask_dataframe(dd.from_dask_array(X))
|
||||||
y = dask_cudf.from_dask_dataframe(dd.from_dask_array(y))
|
y = dask_cudf.from_dask_dataframe(dd.from_dask_array(y))
|
||||||
|
|
||||||
# `DaskDeviceQuantileDMatrix` is used instead of `DaskDMatrix`, be careful
|
# `DaskQuantileDMatrix` is used instead of `DaskDMatrix`, be careful
|
||||||
# that it can not be used for anything else other than training.
|
# that it can not be used for anything else other than training.
|
||||||
dtrain = dxgb.DaskQuantileDMatrix(client, X, y)
|
dtrain = dxgb.DaskQuantileDMatrix(client, X, y)
|
||||||
output = xgb.dask.train(client,
|
output = xgb.dask.train(
|
||||||
{'verbosity': 2,
|
client, {"verbosity": 2, "tree_method": "gpu_hist"}, dtrain, num_boost_round=4
|
||||||
'tree_method': 'gpu_hist'},
|
)
|
||||||
dtrain,
|
|
||||||
num_boost_round=4)
|
|
||||||
|
|
||||||
prediction = xgb.dask.predict(client, output, X)
|
prediction = xgb.dask.predict(client, output, X)
|
||||||
return prediction
|
return prediction
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
'''
|
"""
|
||||||
Demo for using data iterator with Quantile DMatrix
|
Demo for using data iterator with Quantile DMatrix
|
||||||
==================================================
|
==================================================
|
||||||
|
|
||||||
.. versionadded:: 1.2.0
|
.. versionadded:: 1.2.0
|
||||||
|
|
||||||
The demo that defines a customized iterator for passing batches of data into
|
The demo that defines a customized iterator for passing batches of data into
|
||||||
`xgboost.DeviceQuantileDMatrix` and use this `DeviceQuantileDMatrix` for
|
:py:class:`xgboost.QuantileDMatrix` and use this ``QuantileDMatrix`` for
|
||||||
training. The feature is used primarily designed to reduce the required GPU
|
training. The feature is used primarily designed to reduce the required GPU
|
||||||
memory for training on distributed environment.
|
memory for training on distributed environment.
|
||||||
|
|
||||||
@ -15,7 +15,7 @@ using `itertools.tee` might incur significant memory usage according to:
|
|||||||
|
|
||||||
https://docs.python.org/3/library/itertools.html#itertools.tee.
|
https://docs.python.org/3/library/itertools.html#itertools.tee.
|
||||||
|
|
||||||
'''
|
"""
|
||||||
|
|
||||||
import cupy
|
import cupy
|
||||||
import numpy
|
import numpy
|
||||||
@ -88,26 +88,32 @@ def main():
|
|||||||
rounds = 100
|
rounds = 100
|
||||||
it = IterForDMatrixDemo()
|
it = IterForDMatrixDemo()
|
||||||
|
|
||||||
# Use iterator, must be `DeviceQuantileDMatrix` for quantile DMatrix.
|
# Use iterator, must be `QuantileDMatrix`.
|
||||||
m_with_it = xgboost.DeviceQuantileDMatrix(it)
|
|
||||||
|
# In this demo, the input batches are created using cupy, and the data processing
|
||||||
|
# (quantile sketching) will be performed on GPU. If data is loaded with CPU based
|
||||||
|
# data structures like numpy or pandas, then the processing step will be performed
|
||||||
|
# on CPU instead.
|
||||||
|
m_with_it = xgboost.QuantileDMatrix(it)
|
||||||
|
|
||||||
# Use regular DMatrix.
|
# Use regular DMatrix.
|
||||||
m = xgboost.DMatrix(it.as_array(), it.as_array_labels(),
|
m = xgboost.DMatrix(
|
||||||
weight=it.as_array_weights())
|
it.as_array(), it.as_array_labels(), weight=it.as_array_weights()
|
||||||
|
)
|
||||||
|
|
||||||
assert m_with_it.num_col() == m.num_col()
|
assert m_with_it.num_col() == m.num_col()
|
||||||
assert m_with_it.num_row() == m.num_row()
|
assert m_with_it.num_row() == m.num_row()
|
||||||
|
# Tree meethod must be one of the `hist` or `gpu_hist`. We use `gpu_hist` for GPU
|
||||||
reg_with_it = xgboost.train({'tree_method': 'gpu_hist'}, m_with_it,
|
# input here.
|
||||||
num_boost_round=rounds)
|
reg_with_it = xgboost.train(
|
||||||
|
{"tree_method": "gpu_hist"}, m_with_it, num_boost_round=rounds
|
||||||
|
)
|
||||||
predict_with_it = reg_with_it.predict(m_with_it)
|
predict_with_it = reg_with_it.predict(m_with_it)
|
||||||
|
|
||||||
reg = xgboost.train({'tree_method': 'gpu_hist'}, m,
|
reg = xgboost.train({"tree_method": "gpu_hist"}, m, num_boost_round=rounds)
|
||||||
num_boost_round=rounds)
|
|
||||||
predict = reg.predict(m)
|
predict = reg.predict(m)
|
||||||
|
|
||||||
numpy.testing.assert_allclose(predict_with_it, predict,
|
numpy.testing.assert_allclose(predict_with_it, predict, rtol=1e6)
|
||||||
rtol=1e6)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@ -113,7 +113,7 @@ TrainReturnT = TypedDict(
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"CommunicatorContext",
|
"CommunicatorContext",
|
||||||
"DaskDMatrix",
|
"DaskDMatrix",
|
||||||
"DaskDeviceQuantileDMatrix",
|
"DaskQuantileDMatrix",
|
||||||
"DaskXGBRegressor",
|
"DaskXGBRegressor",
|
||||||
"DaskXGBClassifier",
|
"DaskXGBClassifier",
|
||||||
"DaskXGBRanker",
|
"DaskXGBRanker",
|
||||||
@ -559,7 +559,7 @@ def _get_worker_parts(list_of_parts: _DataParts) -> Dict[str, List[Any]]:
|
|||||||
|
|
||||||
|
|
||||||
class DaskPartitionIter(DataIter): # pylint: disable=R0902
|
class DaskPartitionIter(DataIter): # pylint: disable=R0902
|
||||||
"""A data iterator for `DaskDeviceQuantileDMatrix`."""
|
"""A data iterator for `DaskQuantileDMatrix`."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -1229,7 +1229,7 @@ def dispatch_proxy_set_data(
|
|||||||
cat_codes: Optional[list],
|
cat_codes: Optional[list],
|
||||||
allow_host: bool,
|
allow_host: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Dispatch for DeviceQuantileDMatrix."""
|
"""Dispatch for QuantileDMatrix."""
|
||||||
if not _is_cudf_ser(data) and not _is_pandas_series(data):
|
if not _is_cudf_ser(data) and not _is_pandas_series(data):
|
||||||
_check_data_shape(data)
|
_check_data_shape(data)
|
||||||
|
|
||||||
|
|||||||
@ -11,7 +11,7 @@ sys.path.append("tests/python")
|
|||||||
import test_quantile_dmatrix as tqd
|
import test_quantile_dmatrix as tqd
|
||||||
|
|
||||||
|
|
||||||
class TestDeviceQuantileDMatrix:
|
class TestQuantileDMatrix:
|
||||||
cputest = tqd.TestQuantileDMatrix()
|
cputest = tqd.TestQuantileDMatrix()
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
@ -32,7 +32,7 @@ class TestDeviceQuantileDMatrix:
|
|||||||
def test_dmatrix_cupy_init(self) -> None:
|
def test_dmatrix_cupy_init(self) -> None:
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
data = cp.random.randn(5, 5)
|
data = cp.random.randn(5, 5)
|
||||||
xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
|
xgb.QuantileDMatrix(data, cp.ones(5, dtype=np.float64))
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -85,7 +85,7 @@ class TestDeviceQuantileDMatrix:
|
|||||||
fw = rng.randn(rows)
|
fw = rng.randn(rows)
|
||||||
fw -= fw.min()
|
fw -= fw.min()
|
||||||
|
|
||||||
m = xgb.DeviceQuantileDMatrix(data=data, label=labels, feature_weights=fw)
|
m = xgb.QuantileDMatrix(data=data, label=labels, feature_weights=fw)
|
||||||
|
|
||||||
got_fw = m.get_float_info("feature_weights")
|
got_fw = m.get_float_info("feature_weights")
|
||||||
got_labels = m.get_label()
|
got_labels = m.get_label()
|
||||||
|
|||||||
@ -160,7 +160,7 @@ Arrow specification.'''
|
|||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cudf())
|
@pytest.mark.skipif(**tm.no_cudf())
|
||||||
def test_device_dmatrix_from_cudf(self):
|
def test_device_dmatrix_from_cudf(self):
|
||||||
_test_from_cudf(xgb.DeviceQuantileDMatrix)
|
_test_from_cudf(xgb.QuantileDMatrix)
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cudf())
|
@pytest.mark.skipif(**tm.no_cudf())
|
||||||
def test_cudf_training_simple_dmatrix(self):
|
def test_cudf_training_simple_dmatrix(self):
|
||||||
@ -168,7 +168,7 @@ Arrow specification.'''
|
|||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cudf())
|
@pytest.mark.skipif(**tm.no_cudf())
|
||||||
def test_cudf_training_device_dmatrix(self):
|
def test_cudf_training_device_dmatrix(self):
|
||||||
_test_cudf_training(xgb.DeviceQuantileDMatrix)
|
_test_cudf_training(xgb.QuantileDMatrix)
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cudf())
|
@pytest.mark.skipif(**tm.no_cudf())
|
||||||
def test_cudf_metainfo_simple_dmatrix(self):
|
def test_cudf_metainfo_simple_dmatrix(self):
|
||||||
@ -176,7 +176,7 @@ Arrow specification.'''
|
|||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cudf())
|
@pytest.mark.skipif(**tm.no_cudf())
|
||||||
def test_cudf_metainfo_device_dmatrix(self):
|
def test_cudf_metainfo_device_dmatrix(self):
|
||||||
_test_cudf_metainfo(xgb.DeviceQuantileDMatrix)
|
_test_cudf_metainfo(xgb.QuantileDMatrix)
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cudf())
|
@pytest.mark.skipif(**tm.no_cudf())
|
||||||
def test_cudf_categorical(self) -> None:
|
def test_cudf_categorical(self) -> None:
|
||||||
@ -191,7 +191,7 @@ Arrow specification.'''
|
|||||||
assert len(Xy.feature_types) == X.shape[1]
|
assert len(Xy.feature_types) == X.shape[1]
|
||||||
assert all(t == "c" for t in Xy.feature_types)
|
assert all(t == "c" for t in Xy.feature_types)
|
||||||
|
|
||||||
Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True)
|
Xy = xgb.QuantileDMatrix(X, y, enable_categorical=True)
|
||||||
assert Xy.feature_types is not None
|
assert Xy.feature_types is not None
|
||||||
assert len(Xy.feature_types) == X.shape[1]
|
assert len(Xy.feature_types) == X.shape[1]
|
||||||
assert all(t == "c" for t in Xy.feature_types)
|
assert all(t == "c" for t in Xy.feature_types)
|
||||||
@ -228,9 +228,9 @@ Arrow specification.'''
|
|||||||
assert Xy.num_col() == 1
|
assert Xy.num_col() == 1
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="enable_categorical"):
|
with pytest.raises(ValueError, match="enable_categorical"):
|
||||||
xgb.DeviceQuantileDMatrix(X, y)
|
xgb.QuantileDMatrix(X, y)
|
||||||
|
|
||||||
Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True)
|
Xy = xgb.QuantileDMatrix(X, y, enable_categorical=True)
|
||||||
assert Xy.num_row() == 3
|
assert Xy.num_row() == 3
|
||||||
assert Xy.num_col() == 1
|
assert Xy.num_col() == 1
|
||||||
|
|
||||||
@ -344,7 +344,7 @@ def test_from_cudf_iter(enable_categorical):
|
|||||||
params = {"tree_method": "gpu_hist"}
|
params = {"tree_method": "gpu_hist"}
|
||||||
|
|
||||||
# Use iterator
|
# Use iterator
|
||||||
m_it = xgb.DeviceQuantileDMatrix(it, enable_categorical=enable_categorical)
|
m_it = xgb.QuantileDMatrix(it, enable_categorical=enable_categorical)
|
||||||
reg_with_it = xgb.train(params, m_it, num_boost_round=rounds)
|
reg_with_it = xgb.train(params, m_it, num_boost_round=rounds)
|
||||||
|
|
||||||
X = it.as_array()
|
X = it.as_array()
|
||||||
|
|||||||
@ -27,8 +27,8 @@ def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
|
|||||||
assert dtrain.num_col() == kCols
|
assert dtrain.num_col() == kCols
|
||||||
assert dtrain.num_row() == kRows
|
assert dtrain.num_row() == kRows
|
||||||
|
|
||||||
if DMatrixT is xgb.DeviceQuantileDMatrix:
|
if DMatrixT is xgb.QuantileDMatrix:
|
||||||
# Slice is not supported by DeviceQuantileDMatrix
|
# Slice is not supported by QuantileDMatrix
|
||||||
with pytest.raises(xgb.core.XGBoostError):
|
with pytest.raises(xgb.core.XGBoostError):
|
||||||
dtrain.slice(rindex=[0, 1, 2])
|
dtrain.slice(rindex=[0, 1, 2])
|
||||||
dtrain.slice(rindex=[0, 1, 2])
|
dtrain.slice(rindex=[0, 1, 2])
|
||||||
@ -153,11 +153,11 @@ Arrow specification.'''
|
|||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
def test_device_dmat_from_cupy(self):
|
def test_device_dmat_from_cupy(self):
|
||||||
_test_from_cupy(xgb.DeviceQuantileDMatrix)
|
_test_from_cupy(xgb.QuantileDMatrix)
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
def test_cupy_training_device_dmat(self):
|
def test_cupy_training_device_dmat(self):
|
||||||
_test_cupy_training(xgb.DeviceQuantileDMatrix)
|
_test_cupy_training(xgb.QuantileDMatrix)
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
def test_cupy_training_simple_dmat(self):
|
def test_cupy_training_simple_dmat(self):
|
||||||
@ -169,7 +169,7 @@ Arrow specification.'''
|
|||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
def test_cupy_metainfo_device_dmat(self):
|
def test_cupy_metainfo_device_dmat(self):
|
||||||
_test_cupy_metainfo(xgb.DeviceQuantileDMatrix)
|
_test_cupy_metainfo(xgb.QuantileDMatrix)
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
def test_dlpack_simple_dmat(self):
|
def test_dlpack_simple_dmat(self):
|
||||||
@ -196,7 +196,7 @@ Arrow specification.'''
|
|||||||
import cupy as cp
|
import cupy as cp
|
||||||
n = 100
|
n = 100
|
||||||
X = cp.random.random((n, 2))
|
X = cp.random.random((n, 2))
|
||||||
m = xgb.DeviceQuantileDMatrix(X.toDlpack())
|
m = xgb.QuantileDMatrix(X.toDlpack())
|
||||||
with pytest.raises(xgb.core.XGBoostError):
|
with pytest.raises(xgb.core.XGBoostError):
|
||||||
m.slice(rindex=[0, 1, 2])
|
m.slice(rindex=[0, 1, 2])
|
||||||
|
|
||||||
@ -222,7 +222,7 @@ Arrow specification.'''
|
|||||||
import cupy as cp
|
import cupy as cp
|
||||||
cp.cuda.runtime.setDevice(0)
|
cp.cuda.runtime.setDevice(0)
|
||||||
dtrain = dmatrix_from_cupy(
|
dtrain = dmatrix_from_cupy(
|
||||||
np.float32, xgb.DeviceQuantileDMatrix, np.nan)
|
np.float32, xgb.QuantileDMatrix, np.nan)
|
||||||
with pytest.raises(xgb.core.XGBoostError):
|
with pytest.raises(xgb.core.XGBoostError):
|
||||||
xgb.train(
|
xgb.train(
|
||||||
{'tree_method': 'gpu_hist', 'gpu_id': 1}, dtrain, num_boost_round=10
|
{'tree_method': 'gpu_hist', 'gpu_id': 1}, dtrain, num_boost_round=10
|
||||||
|
|||||||
@ -17,7 +17,7 @@ def test_large_input():
|
|||||||
assert (np.log2(m * n) > 31)
|
assert (np.log2(m * n) > 31)
|
||||||
X = cp.ones((m, n), dtype=np.float32)
|
X = cp.ones((m, n), dtype=np.float32)
|
||||||
y = cp.ones(m)
|
y = cp.ones(m)
|
||||||
dmat = xgb.DeviceQuantileDMatrix(X, y)
|
dmat = xgb.QuantileDMatrix(X, y)
|
||||||
booster = xgb.train({"tree_method": "gpu_hist", "max_depth": 1}, dmat, 1)
|
booster = xgb.train({"tree_method": "gpu_hist", "max_depth": 1}, dmat, 1)
|
||||||
del y
|
del y
|
||||||
booster.inplace_predict(X)
|
booster.inplace_predict(X)
|
||||||
|
|||||||
@ -173,7 +173,7 @@ class TestTreeMethod:
|
|||||||
|
|
||||||
X, y = cp.array(X), cp.array(y)
|
X, y = cp.array(X), cp.array(y)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
Xy = xgb.DeviceQuantileDMatrix(X, y, feature_types=["c"] * 10)
|
Xy = xgb.QuantileDMatrix(X, y, feature_types=["c"] * 10)
|
||||||
|
|
||||||
def test_invalid_category(self) -> None:
|
def test_invalid_category(self) -> None:
|
||||||
self.run_invalid_category("approx")
|
self.run_invalid_category("approx")
|
||||||
|
|||||||
@ -135,7 +135,7 @@ def run_with_dask_array(DMatrixT: Type, client: Client) -> None:
|
|||||||
def to_cp(x: Any, DMatrixT: Type) -> Any:
|
def to_cp(x: Any, DMatrixT: Type) -> Any:
|
||||||
import cupy
|
import cupy
|
||||||
|
|
||||||
if isinstance(x, np.ndarray) and DMatrixT is dxgb.DaskDeviceQuantileDMatrix:
|
if isinstance(x, np.ndarray) and DMatrixT is dxgb.DaskQuantileDMatrix:
|
||||||
X = cupy.array(x)
|
X = cupy.array(x)
|
||||||
else:
|
else:
|
||||||
X = x
|
X = x
|
||||||
@ -169,7 +169,7 @@ def run_gpu_hist(
|
|||||||
else:
|
else:
|
||||||
w = None
|
w = None
|
||||||
|
|
||||||
if DMatrixT is dxgb.DaskDeviceQuantileDMatrix:
|
if DMatrixT is dxgb.DaskQuantileDMatrix:
|
||||||
m = DMatrixT(
|
m = DMatrixT(
|
||||||
client, data=X, label=y, weight=w, max_bin=params.get("max_bin", 256)
|
client, data=X, label=y, weight=w, max_bin=params.get("max_bin", 256)
|
||||||
)
|
)
|
||||||
@ -227,7 +227,7 @@ class TestDistributedGPU:
|
|||||||
@pytest.mark.skipif(**tm.no_dask_cudf())
|
@pytest.mark.skipif(**tm.no_dask_cudf())
|
||||||
def test_dask_dataframe(self, local_cuda_client: Client) -> None:
|
def test_dask_dataframe(self, local_cuda_client: Client) -> None:
|
||||||
run_with_dask_dataframe(dxgb.DaskDMatrix, local_cuda_client)
|
run_with_dask_dataframe(dxgb.DaskDMatrix, local_cuda_client)
|
||||||
run_with_dask_dataframe(dxgb.DaskDeviceQuantileDMatrix, local_cuda_client)
|
run_with_dask_dataframe(dxgb.DaskQuantileDMatrix, local_cuda_client)
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_dask_cudf())
|
@pytest.mark.skipif(**tm.no_dask_cudf())
|
||||||
def test_categorical(self, local_cuda_client: Client) -> None:
|
def test_categorical(self, local_cuda_client: Client) -> None:
|
||||||
@ -245,7 +245,7 @@ class TestDistributedGPU:
|
|||||||
num_rounds=strategies.integers(1, 20),
|
num_rounds=strategies.integers(1, 20),
|
||||||
dataset=tm.dataset_strategy,
|
dataset=tm.dataset_strategy,
|
||||||
dmatrix_type=strategies.sampled_from(
|
dmatrix_type=strategies.sampled_from(
|
||||||
[dxgb.DaskDMatrix, dxgb.DaskDeviceQuantileDMatrix]
|
[dxgb.DaskDMatrix, dxgb.DaskQuantileDMatrix]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@settings(
|
@settings(
|
||||||
@ -268,7 +268,7 @@ class TestDistributedGPU:
|
|||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
def test_dask_array(self, local_cuda_client: Client) -> None:
|
def test_dask_array(self, local_cuda_client: Client) -> None:
|
||||||
run_with_dask_array(dxgb.DaskDMatrix, local_cuda_client)
|
run_with_dask_array(dxgb.DaskDMatrix, local_cuda_client)
|
||||||
run_with_dask_array(dxgb.DaskDeviceQuantileDMatrix, local_cuda_client)
|
run_with_dask_array(dxgb.DaskQuantileDMatrix, local_cuda_client)
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
def test_early_stopping(self, local_cuda_client: Client) -> None:
|
def test_early_stopping(self, local_cuda_client: Client) -> None:
|
||||||
@ -357,7 +357,7 @@ class TestDistributedGPU:
|
|||||||
)
|
)
|
||||||
X = ddf[ddf.columns.difference(["y"])]
|
X = ddf[ddf.columns.difference(["y"])]
|
||||||
y = ddf[["y"]]
|
y = ddf[["y"]]
|
||||||
dtrain = dxgb.DaskDeviceQuantileDMatrix(local_cuda_client, X, y)
|
dtrain = dxgb.DaskQuantileDMatrix(local_cuda_client, X, y)
|
||||||
bst_empty = xgb.dask.train(
|
bst_empty = xgb.dask.train(
|
||||||
local_cuda_client, parameters, dtrain, evals=[(dtrain, "train")]
|
local_cuda_client, parameters, dtrain, evals=[(dtrain, "train")]
|
||||||
)
|
)
|
||||||
@ -369,7 +369,7 @@ class TestDistributedGPU:
|
|||||||
)
|
)
|
||||||
X = ddf[ddf.columns.difference(["y"])]
|
X = ddf[ddf.columns.difference(["y"])]
|
||||||
y = ddf[["y"]]
|
y = ddf[["y"]]
|
||||||
dtrain = dxgb.DaskDeviceQuantileDMatrix(local_cuda_client, X, y)
|
dtrain = dxgb.DaskQuantileDMatrix(local_cuda_client, X, y)
|
||||||
bst = xgb.dask.train(
|
bst = xgb.dask.train(
|
||||||
local_cuda_client, parameters, dtrain, evals=[(dtrain, "train")]
|
local_cuda_client, parameters, dtrain, evals=[(dtrain, "train")]
|
||||||
)
|
)
|
||||||
@ -546,7 +546,7 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> dxgb.TrainRetur
|
|||||||
X = X.map_blocks(cp.array)
|
X = X.map_blocks(cp.array)
|
||||||
y = y.map_blocks(cp.array)
|
y = y.map_blocks(cp.array)
|
||||||
|
|
||||||
m = await xgb.dask.DaskDeviceQuantileDMatrix(client, X, y)
|
m = await xgb.dask.DaskQuantileDMatrix(client, X, y)
|
||||||
output = await xgb.dask.train(client, {"tree_method": "gpu_hist"}, dtrain=m)
|
output = await xgb.dask.train(client, {"tree_method": "gpu_hist"}, dtrain=m)
|
||||||
|
|
||||||
with_m = await xgb.dask.predict(client, output, m)
|
with_m = await xgb.dask.predict(client, output, m)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user