Handle the new device parameter in dask and demos. (#9386)

* Handle the new `device` parameter in dask and demos.

- Check no ordinal is specified in the dask interface.
- Update demos.
- Update dask doc.
- Update the condition for QDM.
Jiaming Yuan 2023-07-15 19:11:20 +08:00 committed by GitHub
parent 9da5050643
commit 16eb41936d
31 changed files with 631 additions and 450 deletions
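
Across the demos, documentation, and tests below, the recurring change is replacing the deprecated ``gpu_hist`` tree method with ``tree_method="hist"`` plus the new ``device`` parameter. A minimal sketch of the pattern, not taken from any of the changed files (the random ``X``/``y`` are placeholders and a CUDA device is assumed):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X, y = rng.random((256, 8)), rng.random(256)
    Xy = xgb.QuantileDMatrix(X, y)

    # Previously: {"tree_method": "gpu_hist"}.
    # Now the algorithm and the hardware are selected separately.
    booster = xgb.train({"tree_method": "hist", "device": "cuda"}, Xy, num_boost_round=4)

    # The scikit-learn wrapper takes the same pair of parameters.
    reg = xgb.XGBRegressor(tree_method="hist", device="cuda")
    reg.fit(X, y)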

View File

@@ -18,43 +18,45 @@ def main(client):
    # The Veterans' Administration Lung Cancer Trial
    # The Statistical Analysis of Failure Time Data by Kalbfleisch J. and Prentice R (1980)
    CURRENT_DIR = os.path.dirname(__file__)
    df = dd.read_csv(
        os.path.join(CURRENT_DIR, os.pardir, "data", "veterans_lung_cancer.csv")
    )

    # DaskDMatrix acts like normal DMatrix, works as a proxy for local
    # DMatrix scatter around workers.
    # For AFT survival, you'd need to extract the lower and upper bounds for the label
    # and pass them as arguments to DaskDMatrix.
    y_lower_bound = df["Survival_label_lower_bound"]
    y_upper_bound = df["Survival_label_upper_bound"]
    X = df.drop(["Survival_label_lower_bound", "Survival_label_upper_bound"], axis=1)
    dtrain = DaskDMatrix(
        client, X, label_lower_bound=y_lower_bound, label_upper_bound=y_upper_bound
    )

    # Use train method from xgboost.dask instead of xgboost. This
    # distributed version of train returns a dictionary containing the
    # resulting booster and evaluation history obtained from
    # evaluation metrics.
    params = {
        "verbosity": 1,
        "objective": "survival:aft",
        "eval_metric": "aft-nloglik",
        "learning_rate": 0.05,
        "aft_loss_distribution_scale": 1.20,
        "aft_loss_distribution": "normal",
        "max_depth": 6,
        "lambda": 0.01,
        "alpha": 0.02,
    }
    output = xgb.dask.train(
        client, params, dtrain, num_boost_round=100, evals=[(dtrain, "train")]
    )
    bst = output["booster"]
    history = output["history"]

    # you can pass output directly into `predict` too.
    prediction = xgb.dask.predict(client, bst, dtrain)
    print("Evaluation history: ", history)
    # Uncomment the following line to save the model to the disk
    # bst.save_model('survival_model.json')

@@ -62,7 +64,7 @@ def main(client):
    return prediction


if __name__ == "__main__":
    # or use other clusters for scaling
    with LocalCluster(n_workers=7, threads_per_worker=4) as cluster:
        with Client(cluster) as client:

View File

@@ -25,21 +25,23 @@ def main(client):
    # distributed version of train returns a dictionary containing the
    # resulting booster and evaluation history obtained from
    # evaluation metrics.
    output = xgb.dask.train(
        client,
        {"verbosity": 1, "tree_method": "hist"},
        dtrain,
        num_boost_round=4,
        evals=[(dtrain, "train")],
    )
    bst = output["booster"]
    history = output["history"]

    # you can pass output directly into `predict` too.
    prediction = xgb.dask.predict(client, bst, dtrain)
    print("Evaluation history:", history)
    return prediction


if __name__ == "__main__":
    # or use other clusters for scaling
    with LocalCluster(n_workers=7, threads_per_worker=4) as cluster:
        with Client(cluster) as client:

View File

@@ -13,33 +13,38 @@ from xgboost import dask as dxgb
from xgboost.dask import DaskDMatrix


def using_dask_matrix(client: Client, X: da.Array, y: da.Array) -> da.Array:
    # DaskDMatrix acts like normal DMatrix, works as a proxy for local DMatrix scatter
    # around workers.
    dtrain = DaskDMatrix(client, X, y)

    # Use train method from xgboost.dask instead of xgboost. This distributed version
    # of train returns a dictionary containing the resulting booster and evaluation
    # history obtained from evaluation metrics.
    output = xgb.dask.train(
        client,
        {
            "verbosity": 2,
            "tree_method": "hist",
            # Golden line for GPU training
            "device": "cuda",
        },
        dtrain,
        num_boost_round=4,
        evals=[(dtrain, "train")],
    )
    bst = output["booster"]
    history = output["history"]

    # you can pass output directly into `predict` too.
    prediction = xgb.dask.predict(client, bst, dtrain)
    print("Evaluation history:", history)
    return prediction


def using_quantile_device_dmatrix(client: Client, X: da.Array, y: da.Array) -> da.Array:
    """`DaskQuantileDMatrix` is a data type specialized for `hist` tree methods for
    reducing memory usage.

    .. versionadded:: 1.2.0

@@ -52,17 +57,19 @@ def using_quantile_device_dmatrix(client: Client, X, y):
    # the `ref` argument of `DaskQuantileDMatrix`.
    dtrain = dxgb.DaskQuantileDMatrix(client, X, y)
    output = xgb.dask.train(
        client,
        {"verbosity": 2, "tree_method": "hist", "device": "cuda"},
        dtrain,
        num_boost_round=4,
    )
    prediction = xgb.dask.predict(client, output, X)
    return prediction


if __name__ == "__main__":
    # `LocalCUDACluster` is used for assigning GPU to XGBoost processes. Here
    # `n_workers` represents the number of GPUs since we use one GPU per worker process.
    with LocalCUDACluster(n_workers=2, threads_per_worker=4) as cluster:
        with Client(cluster) as client:
            # generate some random data for demonstration

@@ -71,7 +78,7 @@ if __name__ == '__main__':
            X = da.random.random(size=(m, n), chunks=10000)
            y = da.random.random(size=(m,), chunks=10000)

            print("Using DaskQuantileDMatrix")
            from_ddqdm = using_quantile_device_dmatrix(client, X, y)
            print("Using DMatrix")
            from_dmatrix = using_dask_matrix(client, X, y)

View File

@@ -21,7 +21,8 @@ def main(client):
    y = da.random.random(m, partition_size)

    regressor = xgboost.dask.DaskXGBRegressor(verbosity=1)
    # set the device to CUDA
    regressor.set_params(tree_method="hist", device="cuda")
    # assigning client here is optional
    regressor.client = client

@@ -31,13 +32,13 @@ def main(client):
    bst = regressor.get_booster()
    history = regressor.evals_result()

    print("Evaluation history:", history)
    # returned prediction is always a dask array.
    assert isinstance(prediction, da.Array)
    return bst  # returning the trained model


if __name__ == "__main__":
    # With dask cuda, one can scale up XGBoost to arbitrary GPU clusters.
    # `LocalCUDACluster` used here is only for demonstration purpose.
    with LocalCUDACluster() as cluster:

View File

@@ -71,7 +71,8 @@ def custom_callback():
        {
            'objective': 'binary:logistic',
            'eval_metric': ['error', 'rmse'],
            'tree_method': 'hist',
            "device": "cuda",
        },
        D_train,
        evals=[(D_train, 'Train'), (D_valid, 'Valid')],

View File

@@ -63,7 +63,8 @@ def load_cat_in_the_dat() -> tuple[pd.DataFrame, pd.Series]:
    params = {
        "tree_method": "hist",
        "device": "cuda",
        "n_estimators": 32,
        "colsample_bylevel": 0.7,
    }

View File

@@ -58,13 +58,13 @@ def main() -> None:
    # Specify `enable_categorical` to True, also we use onehot encoding based split
    # here for demonstration. For details see the document of `max_cat_to_onehot`.
    reg = xgb.XGBRegressor(
        tree_method="hist", enable_categorical=True, max_cat_to_onehot=5, device="cuda"
    )
    reg.fit(X, y, eval_set=[(X, y)])

    # Pass in already encoded data
    X_enc, y_enc = make_categorical(100, 10, 4, True)
    reg_enc = xgb.XGBRegressor(tree_method="hist", device="cuda")
    reg_enc.fit(X_enc, y_enc, eval_set=[(X_enc, y_enc)])

    reg_results = np.array(reg.evals_result()["validation_0"]["rmse"])

View File

@@ -82,8 +82,9 @@ def main(tmpdir: str) -> xgboost.Booster:
    missing = np.NaN
    Xy = xgboost.DMatrix(it, missing=missing, enable_categorical=False)

    # ``approx`` is also supported, but less efficient due to sketching. GPU behaves
    # differently than CPU tree methods as it uses a hybrid approach. See tutorial in
    # doc for details.
    booster = xgboost.train(
        {"tree_method": "hist", "max_depth": 4},
        Xy,

View File

@@ -104,7 +104,8 @@ def ranking_demo(args: argparse.Namespace) -> None:
    qid_test = qid_test[sorted_idx]

    ranker = xgb.XGBRanker(
        tree_method="hist",
        device="cuda",
        lambdarank_pair_method="topk",
        lambdarank_num_pair_per_sample=13,
        eval_metric=["ndcg@1", "ndcg@8"],

@@ -161,7 +162,8 @@ def click_data_demo(args: argparse.Namespace) -> None:
    ranker = xgb.XGBRanker(
        n_estimators=512,
        tree_method="hist",
        device="cuda",
        learning_rate=0.01,
        reg_lambda=1.5,
        subsample=0.8,

View File

@@ -28,17 +28,18 @@ BATCHES = 32


class IterForDMatrixDemo(xgboost.core.DataIter):
    """A data iterator for XGBoost DMatrix.

    `reset` and `next` are required for any data iterator, other functions here
    are utilites for demonstration's purpose.

    """

    def __init__(self):
        """Generate some random data for demostration.

        Actual data can be anything that is currently supported by XGBoost.
        """
        self.rows = ROWS_PER_BATCH
        self.cols = COLS
        rng = cupy.random.RandomState(1994)

@@ -59,27 +60,26 @@ class IterForDMatrixDemo(xgboost.core.DataIter):
        return cupy.concatenate(self._weights)

    def data(self):
        """Utility function for obtaining current batch of data."""
        return self._data[self.it]

    def labels(self):
        """Utility function for obtaining current batch of label."""
        return self._labels[self.it]

    def weights(self):
        return self._weights[self.it]

    def reset(self):
        """Reset the iterator"""
        self.it = 0

    def next(self, input_data):
        """Yield next batch of data."""
        if self.it == len(self._data):
            # Return 0 when there's no more batch.
            return 0
        input_data(data=self.data(), label=self.labels(), weight=self.weights())
        self.it += 1
        return 1

@@ -103,18 +103,19 @@ def main():
    assert m_with_it.num_col() == m.num_col()
    assert m_with_it.num_row() == m.num_row()

    # Tree method must be `hist`.
    reg_with_it = xgboost.train(
        {"tree_method": "hist", "device": "cuda"}, m_with_it, num_boost_round=rounds
    )
    predict_with_it = reg_with_it.predict(m_with_it)

    reg = xgboost.train(
        {"tree_method": "hist", "device": "cuda"}, m, num_boost_round=rounds
    )
    predict = reg.predict(m)

    numpy.testing.assert_allclose(predict_with_it, predict, rtol=1e6)


if __name__ == "__main__":
    main()

View File

@@ -24,7 +24,7 @@ def main():
    Xy = xgb.DMatrix(X_train, y_train)
    evals_result: xgb.callback.EvaluationMonitor.EvalsLog = {}
    booster = xgb.train(
        {"tree_method": "hist", "max_depth": 6, "device": "cuda"},
        Xy,
        num_boost_round=n_rounds,
        evals=[(Xy, "Train")],

@@ -87,7 +87,7 @@ def main():
    np.testing.assert_allclose(
        np.array(prune_result["Original"]["rmse"]),
        np.array(prune_result["Train"]["rmse"]),
        atol=1e-5,
    )

View File

@@ -14,30 +14,24 @@ Most of the algorithms in XGBoost including training, prediction and evaluation

Usage
=====

To enable GPU acceleration, specify the ``device`` parameter as ``cuda``. In addition, the device ordinal (which GPU to use if you have multiple devices in the same node) can be specified using the ``cuda:<ordinal>`` syntax, where ``<ordinal>`` is an integer that represents the device ordinal. XGBoost defaults to 0 (the first device reported by CUDA runtime). For details around the ``tree_method`` parameter, see :doc:`tree method </treemethod>`.

The GPU algorithms currently work with CLI, Python, R, and JVM packages. See :doc:`/install` for details.

.. code-block:: python
  :caption: Python example

  params = dict()
  params["device"] = "cuda:0"
  params["tree_method"] = "hist"
  Xy = xgboost.QuantileDMatrix(X, y)
  xgboost.train(params, Xy)

.. code-block:: python
  :caption: With Scikit-Learn interface

  XGBRegressor(tree_method="hist", device="cuda")


GPU-Accelerated SHAP values

@@ -46,12 +40,11 @@ XGBoost makes use of `GPUTreeShap <https://github.com/rapidsai/gputreeshap>`_ as

.. code-block:: python

  booster.set_param({"device": "cuda:0"})
  shap_values = booster.predict(dtrain, pred_contribs=True)
  shap_interaction_values = model.predict(dtrain, pred_interactions=True)

See examples `here <https://github.com/dmlc/xgboost/tree/master/demo/gpu_acceleration>`__.

Multi-node Multi-GPU Training
=============================

@@ -61,7 +54,7 @@ XGBoost supports fully distributed GPU training using `Dask <https://dask.org/>`

Memory usage
============

The following are some guidelines on the device memory usage of the ``hist`` tree method on GPU.

Memory inside xgboost training is generally allocated for two reasons - storing the dataset and working memory.

@@ -79,7 +72,7 @@ XGBoost models trained on GPUs can be used on CPU-only systems to generate predi

Developer notes
===============

The application may be profiled with annotations by specifying ``USE_NTVX`` to cmake. Regions covered by the 'Monitor' class in CUDA code will automatically appear in the nsight profiler when `verbosity` is set to 3.

**********
References

View File

@@ -55,10 +55,6 @@ General Parameters

  - Flag to disable default metric. Set to 1 or ``true`` to disable.

* ``device`` [default= ``cpu``]

  .. versionadded:: 2.0.0

@@ -164,7 +160,7 @@ Parameters for Tree Booster

  - ``grow_colmaker``: non-distributed column-based construction of trees.
  - ``grow_histmaker``: distributed tree construction with row-based data splitting based on global proposal of histogram counting.
  - ``grow_quantile_histmaker``: Grow tree using quantized histogram.
  - ``grow_gpu_hist``: Grow tree with GPU. Same as setting ``tree_method`` to ``hist`` and using ``device=cuda``.
  - ``sync``: synchronizes trees in all distributed nodes.
  - ``refresh``: refreshes tree's statistics and/or leaf values based on the current data. Note that no random subsampling of data rows is performed.
  - ``prune``: prunes the splits where loss < min_split_loss (or gamma) and nodes that have depth greater than ``max_depth``.

View File

@@ -310,8 +310,8 @@ for more info.

.. code-block:: python

  # Use "hist" for training the model.
  reg = xgb.XGBRegressor(tree_method="hist", device="cuda")
  # Fit the model using predictor X and response y.
  reg.fit(X, y)
  # Save model into JSON format.

View File

@@ -56,7 +56,6 @@ on a dask cluster:

    dtrain = xgb.dask.DaskDMatrix(client, X, y)
    # or
    # dtrain = xgb.dask.DaskQuantileDMatrix(client, X, y)

    output = xgb.dask.train(
        client,

@@ -149,7 +148,7 @@ Also for inplace prediction:

.. code-block:: python

    # where X is a dask DataFrame or dask Array backed by cupy or cuDF.
    booster.set_param({"device": "cuda"})
    prediction = xgb.dask.inplace_predict(client, booster, X)

When input is ``da.Array`` object, output is always ``da.Array``. However, if the input

@@ -225,6 +224,12 @@ collection.

        main(client)

****************
GPU acceleration
****************

For most of the use cases with GPUs, the `Dask-CUDA <https://docs.rapids.ai/api/dask-cuda/stable/quickstart.html>`__ project should be used to create the cluster, which automatically configures the correct device ordinal for worker processes. As a result, users should NOT specify the ordinal (good: ``device=cuda``, bad: ``device=cuda:1``). See :ref:`sphx_glr_python_dask-examples_gpu_training.py` and :ref:`sphx_glr_python_dask-examples_sklearn_gpu_training.py` for worked examples.

***************************
Working with other clusters
***************************

@@ -262,7 +267,7 @@ In the example below, a ``KubeCluster`` is used for `deploying Dask on Kubernete

        regressor = xgb.dask.DaskXGBRegressor(n_estimators=10, missing=0.0)
        regressor.client = client
        regressor.set_params(tree_method="hist", device="cuda")
        regressor.fit(X, y, eval_set=[(X, y)])
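
To make the recommendation above concrete, here is a small end-to-end sketch following the pattern of the updated ``gpu_training.py`` demo; the cluster size and array shapes are illustrative and assume a machine with two GPUs:

.. code-block:: python

    import dask.array as da
    from dask.distributed import Client
    from dask_cuda import LocalCUDACluster

    from xgboost import dask as dxgb

    if __name__ == "__main__":
        # Dask-CUDA assigns one GPU per worker and sets the ordinal for us.
        with LocalCUDACluster(n_workers=2, threads_per_worker=4) as cluster:
            with Client(cluster) as client:
                X = da.random.random(size=(100_000, 20), chunks=10_000)
                y = da.random.random(size=(100_000,), chunks=10_000)
                dtrain = dxgb.DaskQuantileDMatrix(client, X, y)
                output = dxgb.train(
                    client,
                    # Note: plain ``cuda``, never ``cuda:1``.
                    {"tree_method": "hist", "device": "cuda"},
                    dtrain,
                    num_boost_round=4,
                )
                prediction = dxgb.predict(client, output, dtrain)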

View File

@@ -1451,7 +1451,7 @@ class QuantileDMatrix(DMatrix):
        enable_categorical: bool = False,
        data_split_mode: DataSplitMode = DataSplitMode.ROW,
    ) -> None:
        self.max_bin = max_bin
        self.missing = missing if missing is not None else np.nan
        self.nthread = nthread if nthread is not None else -1
        self._silent = silent  # unused, kept for compatibility

View File

@@ -82,6 +82,7 @@ from .sklearn import (
    XGBRanker,
    XGBRankerMixIn,
    XGBRegressorBase,
    _can_use_qdm,
    _check_rf_callback,
    _cls_predict_proba,
    _objective_decorator,

@@ -617,14 +618,7 @@ class DaskPartitionIter(DataIter):  # pylint: disable=R0902
        if self._iter == len(self._data):
            # Return 0 when there's no more batch.
            return 0

        input_data(
            data=self.data(),
            label=self._get("_label"),

@@ -634,7 +628,7 @@ class DaskPartitionIter(DataIter):  # pylint: disable=R0902
            base_margin=self._get("_base_margin"),
            label_lower_bound=self._get("_label_lower_bound"),
            label_upper_bound=self._get("_label_upper_bound"),
            feature_names=self._feature_names,
            feature_types=self._feature_types,
            feature_weights=self._feature_weights,
        )

@@ -935,6 +929,12 @@ async def _train_async(
        raise NotImplementedError(
            f"booster `{params['booster']}` is not yet supported for dask."
        )
    device = params.get("device", None)
    if device and device.find(":") != -1:
        raise ValueError(
            "The dask interface for XGBoost doesn't support selecting specific device"
            " ordinal. Use `device=cpu` or `device=cuda` instead."
        )

    def dispatched_train(
        parameters: Dict,

@@ -1574,7 +1574,7 @@ async def _async_wrap_evaluation_matrices(
    """A switch function for async environment."""

    def _dispatch(ref: Optional[DaskDMatrix], **kwargs: Any) -> DaskDMatrix:
        if _can_use_qdm(tree_method):
            return DaskQuantileDMatrix(
                client=client, ref=ref, max_bin=max_bin, **kwargs
            )

View File

@@ -76,6 +76,10 @@ def _check_rf_callback(
        )


def _can_use_qdm(tree_method: Optional[str]) -> bool:
    return tree_method in ("hist", "gpu_hist", None, "auto")


SklObjective = Optional[
    Union[str, Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]]
]

@@ -939,7 +943,7 @@ class XGBModel(XGBModelBase):

    def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix:
        # Use `QuantileDMatrix` to save memory.
        if _can_use_qdm(self.tree_method) and self.booster != "gblinear":
            try:
                return QuantileDMatrix(
                    **kwargs, ref=ref, nthread=self.n_jobs, max_bin=self.max_bin
                )
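
For reference, a tiny sketch of what the new helper accepts; it imports the private name purely for illustration and simply mirrors the definition shown above:

.. code-block:: python

    from xgboost.sklearn import _can_use_qdm  # private helper, shown in the diff above

    # `hist`, the legacy `gpu_hist`, and the defaults (None / "auto") qualify for
    # QuantileDMatrix; `exact` and `approx` keep using a regular DMatrix.
    assert _can_use_qdm("hist") and _can_use_qdm("gpu_hist")
    assert _can_use_qdm(None) and _can_use_qdm("auto")
    assert not _can_use_qdm("exact") and not _can_use_qdm("approx")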

View File

@@ -61,7 +61,7 @@ import xgboost
from xgboost import XGBClassifier
from xgboost.compat import is_cudf_available
from xgboost.core import Booster
from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel, _can_use_qdm
from xgboost.training import train as worker_train

from .data import (

@@ -901,7 +901,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
            context = BarrierTaskContext.get()

            dev_ordinal = None
            use_qdm = _can_use_qdm(booster_params.get("tree_method", None))

            if use_gpu:
                dev_ordinal = (

@@ -912,9 +912,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                # because without cuDF, DMatrix performs better than QDM.
                # Note: Checking `is_cudf_available` in spark worker side because
                # spark worker might has different python environment with driver side.
                use_qdm = use_qdm and is_cudf_available()

            if use_qdm and (booster_params.get("max_bin", None) is not None):
                dmatrix_kwargs["max_bin"] = booster_params["max_bin"]

View File

@@ -81,13 +81,6 @@ void XGBBuildInfoDevice(Json *p_info) {
}  // namespace xgboost
#endif

XGB_DLL int XGBuildInfo(char const **out) {
  API_BEGIN();
  xgboost_CHECK_C_ARG_PTR(out);

@@ -328,7 +321,7 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatr
                                                      int nthread, int max_bin,
                                                      DMatrixHandle *out) {
  API_BEGIN();
  LOG(WARNING) << error::DeprecatedFunc(__func__, "1.7.0", "XGQuantileDMatrixCreateFromCallback");
  *out = new std::shared_ptr<xgboost::DMatrix>{
      xgboost::DMatrix::Create(iter, proxy, nullptr, reset, next, missing, nthread, max_bin)};
  API_END();

@@ -432,7 +425,7 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t *indptr, const unsigned *indic
                                     const bst_float *data, size_t nindptr, size_t nelem,
                                     size_t num_col, DMatrixHandle *out) {
  API_BEGIN();
  LOG(WARNING) << error::DeprecatedFunc(__func__, "2.0.0", "XGDMatrixCreateFromCSR");
  data::CSRAdapter adapter(indptr, indices, data, nindptr - 1, nelem, num_col);
  *out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, std::nan(""), 1));
  API_END();

@@ -496,7 +489,7 @@ XGB_DLL int XGDMatrixCreateFromCSCEx(const size_t *col_ptr, const unsigned *indi
                                     const bst_float *data, size_t nindptr, size_t, size_t num_row,
                                     DMatrixHandle *out) {
  API_BEGIN();
  LOG(WARNING) << error::DeprecatedFunc(__func__, "2.0.0", "XGDMatrixCreateFromCSC");
  data::CSCAdapter adapter(col_ptr, indices, data, nindptr - 1, num_row);
  xgboost_CHECK_C_ARG_PTR(out);
  *out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, std::nan(""), 1));

@@ -1347,7 +1340,7 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle, xgboost::bst_ulong *out_l
  raw_str.resize(0);
  common::MemoryBufferStream fo(&raw_str);

  LOG(WARNING) << error::DeprecatedFunc(__func__, "1.6.0", "XGBoosterSaveModelToBuffer");
  learner->Configure();
  learner->SaveModel(&fo);

View File

@@ -3,10 +3,18 @@
 */
#include "error_msg.h"

#include <sstream>  // for stringstream

#include "../collective/communicator-inl.h"  // for GetRank
#include "xgboost/logging.h"

namespace xgboost::error {
std::string DeprecatedFunc(StringView old, StringView since, StringView replacement) {
  std::stringstream ss;
  ss << "`" << old << "` is deprecated since" << since << ", use `" << replacement << "` instead.";
  return ss.str();
}

void WarnDeprecatedGPUHist() {
  auto msg =
      "The tree method `gpu_hist` is deprecated since 2.0.0. To use GPU training, set the `device` "

@@ -34,8 +42,9 @@ void WarnDeprecatedGPUId() {
  if (logged) {
    return;
  }
  auto msg = DeprecatedFunc("gpu_id", "2.0.0", "device");
  msg += " E.g. device=cpu/cuda/cuda:0";
  LOG(WARNING) << msg;
  logged = true;
}

View File

@@ -8,6 +8,7 @@

#include <cinttypes>  // for uint64_t
#include <limits>     // for numeric_limits
#include <string>     // for string

#include "xgboost/base.h"  // for bst_feature_t
#include "xgboost/logging.h"

@@ -86,5 +87,7 @@ void WarnManualUpdater();
void WarnDeprecatedGPUId();

void WarnEmptyDataset();

std::string DeprecatedFunc(StringView old, StringView since, StringView replacement);
}  // namespace xgboost::error
#endif  // XGBOOST_COMMON_ERROR_MSG_H_

View File

@@ -693,20 +693,21 @@ class LearnerConfiguration : public Learner {
      for (auto const& kv : obj) {
        if (is_parameter(kv.first)) {
          auto parameter = get<Object const>(kv.second);
          std::transform(
              parameter.begin(), parameter.end(), std::back_inserter(keys),
              [](std::pair<std::string const&, Json const&> const& kv) { return kv.first; });
        } else if (IsA<Object>(kv.second)) {
          stack.push(kv.second);
        } else if (IsA<Array>(kv.second)) {
          auto const& array = get<Array const>(kv.second);
          for (auto const& v : array) {
            if (IsA<Object>(v) || IsA<Array>(v)) {
              stack.push(v);
            }
          }
        }
      }
    }

    // FIXME(trivialfis): Make eval_metric a training parameter.
    keys.emplace_back(kEvalMetric);

View File

@@ -32,6 +32,7 @@ class LintersPaths:
        "tests/test_distributed/test_with_spark/",
        "tests/test_distributed/test_gpu_with_spark/",
        # demo
        "demo/dask/",
        "demo/json-model/json_parser.py",
        "demo/guide-python/cat_in_the_dat.py",
        "demo/guide-python/categorical.py",

@@ -42,6 +43,8 @@ class LintersPaths:
        "demo/guide-python/quantile_regression.py",
        "demo/guide-python/multioutput_regression.py",
        "demo/guide-python/learning_to_rank.py",
        "demo/guide-python/quantile_data_iterator.py",
        "demo/guide-python/update_process.py",
        "demo/aft_survival/aft_survival_viz_demo.py",
        # CI
        "tests/ci_build/lint_python.py",

View File

@@ -322,3 +322,15 @@ class TestQuantileDMatrix:
        X: np.ndarray = np.array(orig, dtype=dtype)
        with pytest.raises(ValueError):
            xgb.QuantileDMatrix(X)

    def test_changed_max_bin(self) -> None:
        n_samples = 128
        n_features = 16
        csr, y = make_sparse_regression(n_samples, n_features, 0.5, False)

        Xy = xgb.QuantileDMatrix(csr, y, max_bin=9)
        booster = xgb.train({"max_bin": 9}, Xy, num_boost_round=2)

        Xy = xgb.QuantileDMatrix(csr, y, max_bin=11)
        with pytest.raises(ValueError, match="consistent"):
            xgb.train({}, Xy, num_boost_round=2, xgb_model=booster)

View File

@@ -27,7 +27,7 @@ def train_result(param, dmat, num_rounds):
        param,
        dmat,
        num_rounds,
        evals=[(dmat, "train")],
        verbose_eval=False,
        evals_result=result,
    )

@@ -169,13 +169,21 @@ class TestTreeMethod:
        hist_res = {}
        exact_res = {}

        xgb.train(
            ag_param,
            ag_dtrain,
            10,
            evals=[(ag_dtrain, "train"), (ag_dtest, "test")],
            evals_result=hist_res,
        )
        ag_param["tree_method"] = "exact"
        xgb.train(
            ag_param,
            ag_dtrain,
            10,
            evals=[(ag_dtrain, "train"), (ag_dtest, "test")],
            evals_result=exact_res,
        )

        assert hist_res['train']['auc'] == exact_res['train']['auc']
        assert hist_res['test']['auc'] == exact_res['test']['auc']

View File

@@ -1349,10 +1349,11 @@ def test_multilabel_classification() -> None:
    np.testing.assert_allclose(clf.predict(X), predt)


def test_data_initialization() -> None:
    from sklearn.datasets import load_digits

    X, y = load_digits(return_X_y=True)
    validate_data_initialization(xgb.QuantileDMatrix, xgb.XGBClassifier, X, y)


@parametrize_with_checks([xgb.XGBRegressor()])

View File

@@ -1,10 +1,9 @@
"""Copyright 2019-2022 XGBoost contributors"""
import asyncio
import json
from collections import OrderedDict
from inspect import signature
from typing import Any, Dict, Type, TypeVar

import numpy as np
import pytest

@@ -64,7 +63,7 @@ def run_with_dask_dataframe(DMatrixT: Type, client: Client) -> None:
    dtrain = DMatrixT(client, X, y)
    out = dxgb.train(
        client,
        {"tree_method": "hist", "debug_synchronize": True, "device": "cuda"},
        dtrain=dtrain,
        evals=[(dtrain, "X")],
        num_boost_round=4,

@@ -116,12 +115,18 @@ def run_with_dask_array(DMatrixT: Type, client: Client) -> None:
    dtrain = DMatrixT(client, X, y)
    out = dxgb.train(
        client,
        {"tree_method": "hist", "debug_synchronize": True, "device": "cuda"},
        dtrain=dtrain,
        evals=[(dtrain, "X")],
        num_boost_round=2,
    )
    from_dmatrix = dxgb.predict(client, out, dtrain).compute()
    assert (
        json.loads(out["booster"].save_config())["learner"]["gradient_booster"][
            "updater"
        ][0]["name"]
        == "grow_gpu_hist"
    )
    inplace_predictions = dxgb.inplace_predict(client, out, X).compute()
    single_node = out["booster"].predict(xgb.DMatrix(X.compute()))
    np.testing.assert_allclose(single_node, from_dmatrix)

@@ -149,7 +154,8 @@ def run_gpu_hist(
    DMatrixT: Type,
    client: Client,
) -> None:
    params["tree_method"] = "hist"
    params["device"] = "cuda"
    params = dataset.set_params(params)
    # It doesn't make sense to distribute a completely
    # empty dataset.

@@ -196,11 +202,11 @@ def run_gpu_hist(
def test_tree_stats() -> None:
    with LocalCUDACluster(n_workers=1) as cluster:
        with Client(cluster) as client:
            local = run_tree_stats(client, "hist", "cuda")

    with LocalCUDACluster(n_workers=2) as cluster:
        with Client(cluster) as client:
            distributed = run_tree_stats(client, "hist", "cuda")

    assert local == distributed

@@ -214,12 +220,12 @@ class TestDistributedGPU:
        X_, y_ = load_breast_cancer(return_X_y=True)
        X = dd.from_array(X_, chunksize=100).map_partitions(cudf.from_pandas)
        y = dd.from_array(y_, chunksize=100).map_partitions(cudf.from_pandas)
        run_boost_from_prediction(X, y, "hist", "cuda", local_cuda_client)

        X_, y_ = load_iris(return_X_y=True)
        X = dd.from_array(X_, chunksize=50).map_partitions(cudf.from_pandas)
        y = dd.from_array(y_, chunksize=50).map_partitions(cudf.from_pandas)
        run_boost_from_prediction_multi_class(X, y, "hist", "cuda", local_cuda_client)

    def test_init_estimation(self, local_cuda_client: Client) -> None:
        check_init_estimation("gpu_hist", local_cuda_client)

@@ -282,7 +288,7 @@ class TestDistributedGPU:
        )
        result = xgb.dask.train(
            client,
            {"tree_method": "hist", "device": "cuda", "debug_synchronize": True},
            Xy,
            num_boost_round=10,
            evals=[(Xy_valid, "Valid")],

@@ -313,7 +319,8 @@ class TestDistributedGPU:
            {
                "objective": "binary:logistic",
                "eval_metric": "error",
                "tree_method": "hist",
                "device": "cuda",
            },
            m,
            evals=[(valid, "Valid")],

@@ -328,7 +335,8 @@ class TestDistributedGPU:
        valid_y = y
        cls = dxgb.DaskXGBClassifier(
            objective="binary:logistic",
            tree_method="hist",
            device="cuda",
            eval_metric="error",
            n_estimators=100,
        )

@@ -356,7 +364,11 @@ class TestDistributedGPU:
        run_dask_classifier(X, y, w, model, "gpu_hist", local_cuda_client, 10)

    def test_empty_dmatrix(self, local_cuda_client: Client) -> None:
        parameters = {
            "tree_method": "hist",
            "debug_synchronize": True,
            "device": "cuda",
        }
        run_empty_dmatrix_reg(local_cuda_client, parameters)
        run_empty_dmatrix_cls(local_cuda_client, parameters)

@@ -374,7 +386,11 @@ class TestDistributedGPU:
                "y": [10, 20, 30, 40.0, 50] * mult,
            }
        )
        parameters = {
            "tree_method": "hist",
            "debug_synchronize": True,
            "device": "cuda",
        }

        empty = df.iloc[:0]
        ddf = dask_cudf.concat(

@@ -432,13 +448,25 @@ class TestDistributedGPU:
    def test_empty_dmatrix_auc(self, local_cuda_client: Client) -> None:
        n_workers = len(tm.get_client_workers(local_cuda_client))
        run_empty_dmatrix_auc(local_cuda_client, "cuda", n_workers)

    def test_auc(self, local_cuda_client: Client) -> None:
        run_auc(local_cuda_client, "cuda")

    def test_invalid_ordinal(self, local_cuda_client: Client) -> None:
        """One should not specify the device ordinal with dask."""
        with pytest.raises(ValueError, match="device=cuda"):
            X, y, _ = generate_array()
            m = dxgb.DaskDMatrix(local_cuda_client, X, y)
            dxgb.train(local_cuda_client, {"device": "cuda:0"}, m)

        booster = dxgb.train(local_cuda_client, {"device": "cuda"}, m)["booster"]
        assert (
            json.loads(booster.save_config())["learner"]["generic_param"]["device"]
            == "cuda:0"
        )

    def test_data_initialization(self, local_cuda_client: Client) -> None:
        X, y, _ = generate_array()
        fw = da.random.random((random_cols,))
        fw = fw - fw.min()

@@ -531,7 +559,9 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> dxgb.TrainRetur
    y = y.map_blocks(cp.array)
    m = await xgb.dask.DaskQuantileDMatrix(client, X, y)
    output = await xgb.dask.train(
        client, {"tree_method": "hist", "device": "cuda"}, dtrain=m
    )
    with_m = await xgb.dask.predict(client, output, m)
    with_X = await xgb.dask.predict(client, output, X)

File diff suppressed because it is too large.

View File

@@ -1120,7 +1120,9 @@ class XgboostLocalTest(SparkTestCase):
        reg1 = SparkXGBRegressor(**self.reg_params)
        model = reg1.fit(self.reg_df_train)
        init_booster = model.get_booster()
        reg2 = SparkXGBRegressor(
            max_depth=2, n_estimators=2, xgb_model=init_booster, max_bin=21
        )
        model21 = reg2.fit(self.reg_df_train)
        pred_res21 = model21.transform(self.reg_df_test).collect()
        reg2.save(path)