Handle the new device parameter in dask and demos. (#9386)
* Handle the new `device` parameter in dask and demos.

  - Check no ordinal is specified in the dask interface.
  - Update demos.
  - Update the dask doc.
  - Update the condition for QDM.
parent 9da5050643
commit 16eb41936d
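For orientation, the migration this commit applies throughout the demos, docs, and tests replaces the deprecated ``gpu_hist`` tree method with ``tree_method="hist"`` plus the new ``device`` parameter. A minimal sketch of the target style (assumes XGBoost 2.0 and a visible CUDA device; the data here is made up)::

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(128, 4)
    y = np.random.rand(128)
    dtrain = xgb.DMatrix(X, y)

    # Before: xgb.train({"tree_method": "gpu_hist"}, dtrain)
    # After: the algorithm ("hist") and the processor ("cuda") are chosen separately.
    booster = xgb.train(
        {"tree_method": "hist", "device": "cuda"}, dtrain, num_boost_round=4
    )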
@@ -18,43 +18,45 @@ def main(client):
     # The Veterans' Administration Lung Cancer Trial
     # The Statistical Analysis of Failure Time Data by Kalbfleisch J. and Prentice R (1980)
     CURRENT_DIR = os.path.dirname(__file__)
-    df = dd.read_csv(os.path.join(CURRENT_DIR, os.pardir, 'data', 'veterans_lung_cancer.csv'))
+    df = dd.read_csv(
+        os.path.join(CURRENT_DIR, os.pardir, "data", "veterans_lung_cancer.csv")
+    )

     # DaskDMatrix acts like normal DMatrix, works as a proxy for local
     # DMatrix scatter around workers.
     # For AFT survival, you'd need to extract the lower and upper bounds for the label
     # and pass them as arguments to DaskDMatrix.
-    y_lower_bound = df['Survival_label_lower_bound']
-    y_upper_bound = df['Survival_label_upper_bound']
-    X = df.drop(['Survival_label_lower_bound',
-                 'Survival_label_upper_bound'], axis=1)
-    dtrain = DaskDMatrix(client, X, label_lower_bound=y_lower_bound,
-                         label_upper_bound=y_upper_bound)
+    y_lower_bound = df["Survival_label_lower_bound"]
+    y_upper_bound = df["Survival_label_upper_bound"]
+    X = df.drop(["Survival_label_lower_bound", "Survival_label_upper_bound"], axis=1)
+    dtrain = DaskDMatrix(
+        client, X, label_lower_bound=y_lower_bound, label_upper_bound=y_upper_bound
+    )

     # Use train method from xgboost.dask instead of xgboost. This
     # distributed version of train returns a dictionary containing the
     # resulting booster and evaluation history obtained from
     # evaluation metrics.
-    params = {'verbosity': 1,
-              'objective': 'survival:aft',
-              'eval_metric': 'aft-nloglik',
-              'learning_rate': 0.05,
-              'aft_loss_distribution_scale': 1.20,
-              'aft_loss_distribution': 'normal',
-              'max_depth': 6,
-              'lambda': 0.01,
-              'alpha': 0.02}
-    output = xgb.dask.train(client,
-                            params,
-                            dtrain,
-                            num_boost_round=100,
-                            evals=[(dtrain, 'train')])
-    bst = output['booster']
-    history = output['history']
+    params = {
+        "verbosity": 1,
+        "objective": "survival:aft",
+        "eval_metric": "aft-nloglik",
+        "learning_rate": 0.05,
+        "aft_loss_distribution_scale": 1.20,
+        "aft_loss_distribution": "normal",
+        "max_depth": 6,
+        "lambda": 0.01,
+        "alpha": 0.02,
+    }
+    output = xgb.dask.train(
+        client, params, dtrain, num_boost_round=100, evals=[(dtrain, "train")]
+    )
+    bst = output["booster"]
+    history = output["history"]

     # you can pass output directly into `predict` too.
     prediction = xgb.dask.predict(client, bst, dtrain)
-    print('Evaluation history: ', history)
+    print("Evaluation history: ", history)

     # Uncomment the following line to save the model to the disk
     # bst.save_model('survival_model.json')

@@ -62,7 +64,7 @@ def main(client):
     return prediction


-if __name__ == '__main__':
+if __name__ == "__main__":
     # or use other clusters for scaling
     with LocalCluster(n_workers=7, threads_per_worker=4) as cluster:
         with Client(cluster) as client:
@@ -25,21 +25,23 @@ def main(client):
     # distributed version of train returns a dictionary containing the
     # resulting booster and evaluation history obtained from
     # evaluation metrics.
-    output = xgb.dask.train(client,
-                            {'verbosity': 1,
-                             'tree_method': 'hist'},
-                            dtrain,
-                            num_boost_round=4, evals=[(dtrain, 'train')])
-    bst = output['booster']
-    history = output['history']
+    output = xgb.dask.train(
+        client,
+        {"verbosity": 1, "tree_method": "hist"},
+        dtrain,
+        num_boost_round=4,
+        evals=[(dtrain, "train")],
+    )
+    bst = output["booster"]
+    history = output["history"]

     # you can pass output directly into `predict` too.
     prediction = xgb.dask.predict(client, bst, dtrain)
-    print('Evaluation history:', history)
+    print("Evaluation history:", history)
     return prediction


-if __name__ == '__main__':
+if __name__ == "__main__":
     # or use other clusters for scaling
     with LocalCluster(n_workers=7, threads_per_worker=4) as cluster:
         with Client(cluster) as client:
@@ -13,33 +13,38 @@ from xgboost import dask as dxgb
 from xgboost.dask import DaskDMatrix


-def using_dask_matrix(client: Client, X, y):
-    # DaskDMatrix acts like normal DMatrix, works as a proxy for local
-    # DMatrix scatter around workers.
+def using_dask_matrix(client: Client, X: da.Array, y: da.Array) -> da.Array:
+    # DaskDMatrix acts like normal DMatrix, works as a proxy for local DMatrix scatter
+    # around workers.
     dtrain = DaskDMatrix(client, X, y)

-    # Use train method from xgboost.dask instead of xgboost. This
-    # distributed version of train returns a dictionary containing the
-    # resulting booster and evaluation history obtained from
-    # evaluation metrics.
-    output = xgb.dask.train(client,
-                            {'verbosity': 2,
-                             # Golden line for GPU training
-                             'tree_method': 'gpu_hist'},
-                            dtrain,
-                            num_boost_round=4, evals=[(dtrain, 'train')])
-    bst = output['booster']
-    history = output['history']
+    # Use train method from xgboost.dask instead of xgboost. This distributed version
+    # of train returns a dictionary containing the resulting booster and evaluation
+    # history obtained from evaluation metrics.
+    output = xgb.dask.train(
+        client,
+        {
+            "verbosity": 2,
+            "tree_method": "hist",
+            # Golden line for GPU training
+            "device": "cuda",
+        },
+        dtrain,
+        num_boost_round=4,
+        evals=[(dtrain, "train")],
+    )
+    bst = output["booster"]
+    history = output["history"]

     # you can pass output directly into `predict` too.
     prediction = xgb.dask.predict(client, bst, dtrain)
-    print('Evaluation history:', history)
+    print("Evaluation history:", history)
     return prediction


-def using_quantile_device_dmatrix(client: Client, X, y):
-    """`DaskQuantileDMatrix` is a data type specialized for `gpu_hist` and `hist` tree
-    methods for reducing memory usage.
+def using_quantile_device_dmatrix(client: Client, X: da.Array, y: da.Array) -> da.Array:
+    """`DaskQuantileDMatrix` is a data type specialized for `hist` tree methods for
+    reducing memory usage.

     .. versionadded:: 1.2.0

@@ -52,17 +57,19 @@ def using_quantile_device_dmatrix(client: Client, X, y):
     # the `ref` argument of `DaskQuantileDMatrix`.
     dtrain = dxgb.DaskQuantileDMatrix(client, X, y)
     output = xgb.dask.train(
-        client, {"verbosity": 2, "tree_method": "gpu_hist"}, dtrain, num_boost_round=4
+        client,
+        {"verbosity": 2, "tree_method": "hist", "device": "cuda"},
+        dtrain,
+        num_boost_round=4,
     )

     prediction = xgb.dask.predict(client, output, X)
     return prediction


-if __name__ == '__main__':
+if __name__ == "__main__":
     # `LocalCUDACluster` is used for assigning GPU to XGBoost processes. Here
-    # `n_workers` represents the number of GPUs since we use one GPU per worker
-    # process.
+    # `n_workers` represents the number of GPUs since we use one GPU per worker process.
     with LocalCUDACluster(n_workers=2, threads_per_worker=4) as cluster:
         with Client(cluster) as client:
             # generate some random data for demonstration

@@ -71,7 +78,7 @@ if __name__ == '__main__':
             X = da.random.random(size=(m, n), chunks=10000)
             y = da.random.random(size=(m,), chunks=10000)

-            print('Using DaskQuantileDMatrix')
+            print("Using DaskQuantileDMatrix")
             from_ddqdm = using_quantile_device_dmatrix(client, X, y)
-            print('Using DMatrix')
+            print("Using DMatrix")
             from_dmatrix = using_dask_matrix(client, X, y)
@@ -21,7 +21,8 @@ def main(client):
     y = da.random.random(m, partition_size)

     regressor = xgboost.dask.DaskXGBRegressor(verbosity=1)
-    regressor.set_params(tree_method='gpu_hist')
+    # set the device to CUDA
+    regressor.set_params(tree_method="hist", device="cuda")
     # assigning client here is optional
     regressor.client = client

@@ -31,13 +32,13 @@ def main(client):
     bst = regressor.get_booster()
     history = regressor.evals_result()

-    print('Evaluation history:', history)
+    print("Evaluation history:", history)
     # returned prediction is always a dask array.
     assert isinstance(prediction, da.Array)
     return bst  # returning the trained model


-if __name__ == '__main__':
+if __name__ == "__main__":
     # With dask cuda, one can scale up XGBoost to arbitrary GPU clusters.
     # `LocalCUDACluster` used here is only for demonstration purpose.
     with LocalCUDACluster() as cluster:
@@ -71,7 +71,8 @@ def custom_callback():
         {
             'objective': 'binary:logistic',
             'eval_metric': ['error', 'rmse'],
-            'tree_method': 'gpu_hist'
+            'tree_method': 'hist',
+            "device": "cuda",
         },
         D_train,
         evals=[(D_train, 'Train'), (D_valid, 'Valid')],
@@ -63,7 +63,8 @@ def load_cat_in_the_dat() -> tuple[pd.DataFrame, pd.Series]:


     params = {
-        "tree_method": "gpu_hist",
+        "tree_method": "hist",
+        "device": "cuda",
         "n_estimators": 32,
         "colsample_bylevel": 0.7,
     }
@@ -58,13 +58,13 @@ def main() -> None:
     # Specify `enable_categorical` to True, also we use onehot encoding based split
     # here for demonstration. For details see the document of `max_cat_to_onehot`.
     reg = xgb.XGBRegressor(
-        tree_method="gpu_hist", enable_categorical=True, max_cat_to_onehot=5
+        tree_method="hist", enable_categorical=True, max_cat_to_onehot=5, device="cuda"
     )
     reg.fit(X, y, eval_set=[(X, y)])

     # Pass in already encoded data
     X_enc, y_enc = make_categorical(100, 10, 4, True)
-    reg_enc = xgb.XGBRegressor(tree_method="gpu_hist")
+    reg_enc = xgb.XGBRegressor(tree_method="hist", device="cuda")
     reg_enc.fit(X_enc, y_enc, eval_set=[(X_enc, y_enc)])

     reg_results = np.array(reg.evals_result()["validation_0"]["rmse"])
@@ -82,8 +82,9 @@ def main(tmpdir: str) -> xgboost.Booster:
     missing = np.NaN
     Xy = xgboost.DMatrix(it, missing=missing, enable_categorical=False)

-    # Other tree methods including ``approx``, and ``gpu_hist`` are supported. GPU
-    # behaves differently than CPU tree methods. See tutorial in doc for details.
+    # ``approx`` is also supported, but less efficient due to sketching. GPU behaves
+    # differently than CPU tree methods as it uses a hybrid approach. See tutorial in
+    # doc for details.
     booster = xgboost.train(
         {"tree_method": "hist", "max_depth": 4},
         Xy,
@@ -104,7 +104,8 @@ def ranking_demo(args: argparse.Namespace) -> None:
     qid_test = qid_test[sorted_idx]

     ranker = xgb.XGBRanker(
-        tree_method="gpu_hist",
+        tree_method="hist",
+        device="cuda",
         lambdarank_pair_method="topk",
         lambdarank_num_pair_per_sample=13,
         eval_metric=["ndcg@1", "ndcg@8"],

@@ -161,7 +162,8 @@ def click_data_demo(args: argparse.Namespace) -> None:

     ranker = xgb.XGBRanker(
         n_estimators=512,
-        tree_method="gpu_hist",
+        tree_method="hist",
+        device="cuda",
         learning_rate=0.01,
         reg_lambda=1.5,
         subsample=0.8,
@@ -28,17 +28,18 @@ BATCHES = 32


 class IterForDMatrixDemo(xgboost.core.DataIter):
-    '''A data iterator for XGBoost DMatrix.
+    """A data iterator for XGBoost DMatrix.

     `reset` and `next` are required for any data iterator, other functions here
     are utilites for demonstration's purpose.

-    '''
+    """

     def __init__(self):
-        '''Generate some random data for demostration.
+        """Generate some random data for demostration.

         Actual data can be anything that is currently supported by XGBoost.
-        '''
+        """
         self.rows = ROWS_PER_BATCH
         self.cols = COLS
         rng = cupy.random.RandomState(1994)

@@ -59,27 +60,26 @@ class IterForDMatrixDemo(xgboost.core.DataIter):
         return cupy.concatenate(self._weights)

     def data(self):
-        '''Utility function for obtaining current batch of data.'''
+        """Utility function for obtaining current batch of data."""
         return self._data[self.it]

     def labels(self):
-        '''Utility function for obtaining current batch of label.'''
+        """Utility function for obtaining current batch of label."""
         return self._labels[self.it]

     def weights(self):
         return self._weights[self.it]

     def reset(self):
-        '''Reset the iterator'''
+        """Reset the iterator"""
         self.it = 0

     def next(self, input_data):
-        '''Yield next batch of data.'''
+        """Yield next batch of data."""
         if self.it == len(self._data):
             # Return 0 when there's no more batch.
             return 0
-        input_data(data=self.data(), label=self.labels(),
-                   weight=self.weights())
+        input_data(data=self.data(), label=self.labels(), weight=self.weights())
         self.it += 1
         return 1

@@ -103,18 +103,19 @@ def main():

     assert m_with_it.num_col() == m.num_col()
     assert m_with_it.num_row() == m.num_row()
-    # Tree meethod must be one of the `hist` or `gpu_hist`. We use `gpu_hist` for GPU
-    # input here.
+    # Tree meethod must be `hist`.
     reg_with_it = xgboost.train(
-        {"tree_method": "gpu_hist"}, m_with_it, num_boost_round=rounds
+        {"tree_method": "hist", "device": "cuda"}, m_with_it, num_boost_round=rounds
     )
     predict_with_it = reg_with_it.predict(m_with_it)

-    reg = xgboost.train({"tree_method": "gpu_hist"}, m, num_boost_round=rounds)
+    reg = xgboost.train(
+        {"tree_method": "hist", "device": "cuda"}, m, num_boost_round=rounds
+    )
     predict = reg.predict(m)

     numpy.testing.assert_allclose(predict_with_it, predict, rtol=1e6)


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
@@ -24,7 +24,7 @@ def main():
     Xy = xgb.DMatrix(X_train, y_train)
     evals_result: xgb.callback.EvaluationMonitor.EvalsLog = {}
     booster = xgb.train(
-        {"tree_method": "gpu_hist", "max_depth": 6},
+        {"tree_method": "hist", "max_depth": 6, "device": "cuda"},
         Xy,
         num_boost_round=n_rounds,
         evals=[(Xy, "Train")],

@@ -87,7 +87,7 @@ def main():
     np.testing.assert_allclose(
         np.array(prune_result["Original"]["rmse"]),
         np.array(prune_result["Train"]["rmse"]),
-        atol=1e-5
+        atol=1e-5,
     )

@@ -14,30 +14,24 @@ Most of the algorithms in XGBoost including training, prediction and evaluation

 Usage
 =====
-Specify the ``tree_method`` parameter as ``gpu_hist``. For details around the ``tree_method`` parameter, see :doc:`tree method </treemethod>`.
-
-Supported parameters
---------------------
-
-GPU accelerated prediction is enabled by default for the above mentioned ``tree_method`` parameters but can be switched to CPU prediction by setting ``predictor`` to ``cpu_predictor``. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting ``predictor`` to ``gpu_predictor``.
-
-The device ordinal (which GPU to use if you have many of them) can be selected using the
-``device`` parameter, which defaults to 0 when "CUDA" is specified(the first device reported by CUDA
-runtime).
+To enable GPU acceleration, specify the ``device`` parameter as ``cuda``. In addition, the device ordinal (which GPU to use if you have multiple devices in the same node) can be specified using the ``cuda:<ordinal>`` syntax, where ``<ordinal>`` is an integer that represents the device ordinal. XGBoost defaults to 0 (the first device reported by CUDA runtime).

 The GPU algorithms currently work with CLI, Python, R, and JVM packages. See :doc:`/install` for details.

 .. code-block:: python
   :caption: Python example

-  param["device"] = "cuda:0"
-  param['tree_method'] = 'gpu_hist'
+  params = dict()
+  params["device"] = "cuda:0"
+  params["tree_method"] = "hist"
+  Xy = xgboost.QuantileDMatrix(X, y)
+  xgboost.train(params, Xy)

 .. code-block:: python
   :caption: With Scikit-Learn interface

-  XGBRegressor(tree_method='gpu_hist', device="cuda")
+  XGBRegressor(tree_method="hist", device="cuda")


 GPU-Accelerated SHAP values
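As a small illustration of the ``cuda:<ordinal>`` syntax described in the updated paragraph above, a hedged sketch (assumes a machine with at least two GPUs; data is synthetic)::

    import numpy as np
    import xgboost as xgb

    X, y = np.random.rand(256, 8), np.random.rand(256)
    # Train on the second GPU; a bare "cuda" defaults to ordinal 0.
    reg = xgb.XGBRegressor(tree_method="hist", device="cuda:1")
    reg.fit(X, y)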
@@ -46,12 +40,11 @@ XGBoost makes use of `GPUTreeShap <https://github.com/rapidsai/gputreeshap>`_ as

 .. code-block:: python

-  model.set_param({"device": "cuda:0", "tree_method": "gpu_hist"})
-  shap_values = model.predict(dtrain, pred_contribs=True)
+  booster.set_param({"device": "cuda:0"})
+  shap_values = booster.predict(dtrain, pred_contribs=True)
   shap_interaction_values = model.predict(dtrain, pred_interactions=True)

-See examples `here
-<https://github.com/dmlc/xgboost/tree/master/demo/gpu_acceleration>`__.
+See examples `here <https://github.com/dmlc/xgboost/tree/master/demo/gpu_acceleration>`__.

 Multi-node Multi-GPU Training
 =============================

@@ -61,7 +54,7 @@ XGBoost supports fully distributed GPU training using `Dask <https://dask.org/>`

 Memory usage
 ============
-The following are some guidelines on the device memory usage of the `gpu_hist` tree method.
+The following are some guidelines on the device memory usage of the ``hist`` tree method on GPU.

 Memory inside xgboost training is generally allocated for two reasons - storing the dataset and working memory.

@@ -79,7 +72,7 @@ XGBoost models trained on GPUs can be used on CPU-only systems to generate predi

 Developer notes
 ===============
-The application may be profiled with annotations by specifying USE_NTVX to cmake. Regions covered by the 'Monitor' class in CUDA code will automatically appear in the nsight profiler when `verbosity` is set to 3.
+The application may be profiled with annotations by specifying ``USE_NTVX`` to cmake. Regions covered by the 'Monitor' class in CUDA code will automatically appear in the nsight profiler when `verbosity` is set to 3.

 **********
 References
@@ -55,10 +55,6 @@ General Parameters

   - Flag to disable default metric. Set to 1 or ``true`` to disable.

-* ``num_feature`` [set automatically by XGBoost, no need to be set by user]
-
-  - Feature dimension used in boosting, set to maximum dimension of the feature
-
 * ``device`` [default= ``cpu``]

   .. versionadded:: 2.0.0

@@ -164,7 +160,7 @@ Parameters for Tree Booster
   - ``grow_colmaker``: non-distributed column-based construction of trees.
   - ``grow_histmaker``: distributed tree construction with row-based data splitting based on global proposal of histogram counting.
   - ``grow_quantile_histmaker``: Grow tree using quantized histogram.
-  - ``grow_gpu_hist``: Grow tree with GPU. Same as setting tree method to ``hist`` and use ``device=cuda``.
+  - ``grow_gpu_hist``: Grow tree with GPU. Same as setting ``tree_method`` to ``hist`` and use ``device=cuda``.
   - ``sync``: synchronizes trees in all distributed nodes.
   - ``refresh``: refreshes tree's statistics and/or leaf values based on the current data. Note that no random subsampling of data rows is performed.
   - ``prune``: prunes the splits where loss < min_split_loss (or gamma) and nodes that have depth greater than ``max_depth``.
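The equivalence noted for ``grow_gpu_hist`` above can be sketched as two parameter dictionaries that configure the same GPU histogram updater (the first spelling is deprecated as of 2.0.0)::

    # Deprecated: one knob selects both the algorithm and the GPU.
    params_old = {"tree_method": "gpu_hist"}
    # Current: the algorithm and the device are orthogonal settings.
    params_new = {"tree_method": "hist", "device": "cuda"}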
@@ -310,8 +310,8 @@ for more info.

 .. code-block:: python

-  # Use "gpu_hist" for training the model.
-  reg = xgb.XGBRegressor(tree_method="gpu_hist")
+  # Use "hist" for training the model.
+  reg = xgb.XGBRegressor(tree_method="hist", device="cuda")
   # Fit the model using predictor X and response y.
   reg.fit(X, y)
   # Save model into JSON format.
@@ -56,7 +56,6 @@ on a dask cluster:
     dtrain = xgb.dask.DaskDMatrix(client, X, y)
     # or
     # dtrain = xgb.dask.DaskQuantileDMatrix(client, X, y)
-    # `DaskQuantileDMatrix` is available for the `hist` and `gpu_hist` tree method.

     output = xgb.dask.train(
         client,

@@ -149,7 +148,7 @@ Also for inplace prediction:
 .. code-block:: python

     # where X is a dask DataFrame or dask Array backed by cupy or cuDF.
-    booster.set_param({"device": "cuda:0"})
+    booster.set_param({"device": "cuda"})
     prediction = xgb.dask.inplace_predict(client, booster, X)

 When input is ``da.Array`` object, output is always ``da.Array``. However, if the input
@@ -225,6 +224,12 @@ collection.
             main(client)


+****************
+GPU acceleration
+****************
+
+For most of the use cases with GPUs, the `Dask-CUDA <https://docs.rapids.ai/api/dask-cuda/stable/quickstart.html>`__ project should be used to create the cluster, which automatically configures the correct device ordinal for worker processes. As a result, users should NOT specify the ordinal (good: ``device=cuda``, bad: ``device=cuda:1``). See :ref:`sphx_glr_python_dask-examples_gpu_training.py` and :ref:`sphx_glr_python_dask-examples_sklearn_gpu_training.py` for worked examples.
+
 ***************************
 Working with other clusters
 ***************************
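A minimal sketch of the Dask-CUDA pattern the new section describes, mirroring the demos updated in this commit (assumes ``dask-cuda`` is installed and at least one GPU is visible)::

    import dask.array as da
    import xgboost as xgb
    from dask.distributed import Client
    from dask_cuda import LocalCUDACluster

    with LocalCUDACluster() as cluster:
        with Client(cluster) as client:
            X = da.random.random((10000, 10), chunks=1000)
            y = da.random.random(10000, chunks=1000)
            dtrain = xgb.dask.DaskDMatrix(client, X, y)
            # No ordinal: dask-cuda binds each worker to its own GPU.
            output = xgb.dask.train(
                client, {"tree_method": "hist", "device": "cuda"}, dtrain
            )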
@@ -262,7 +267,7 @@ In the example below, a ``KubeCluster`` is used for `deploying Dask on Kubernete

     regressor = xgb.dask.DaskXGBRegressor(n_estimators=10, missing=0.0)
     regressor.client = client
-    regressor.set_params(tree_method='gpu_hist')
+    regressor.set_params(tree_method='hist', device="cuda")
     regressor.fit(X, y, eval_set=[(X, y)])


@@ -1451,7 +1451,7 @@ class QuantileDMatrix(DMatrix):
         enable_categorical: bool = False,
         data_split_mode: DataSplitMode = DataSplitMode.ROW,
     ) -> None:
-        self.max_bin: int = max_bin if max_bin is not None else 256
+        self.max_bin = max_bin
         self.missing = missing if missing is not None else np.nan
         self.nthread = nthread if nthread is not None else -1
         self._silent = silent  # unused, kept for compatibility
@@ -82,6 +82,7 @@ from .sklearn import (
     XGBRanker,
     XGBRankerMixIn,
     XGBRegressorBase,
+    _can_use_qdm,
     _check_rf_callback,
     _cls_predict_proba,
     _objective_decorator,

@@ -617,14 +618,7 @@ class DaskPartitionIter(DataIter):  # pylint: disable=R0902
         if self._iter == len(self._data):
             # Return 0 when there's no more batch.
             return 0
-        feature_names: Optional[FeatureNames] = None
-        if self._feature_names:
-            feature_names = self._feature_names
-        else:
-            if hasattr(self.data(), "columns"):
-                feature_names = self.data().columns.format()
-            else:
-                feature_names = None
         input_data(
             data=self.data(),
             label=self._get("_label"),

@@ -634,7 +628,7 @@ class DaskPartitionIter(DataIter):  # pylint: disable=R0902
             base_margin=self._get("_base_margin"),
             label_lower_bound=self._get("_label_lower_bound"),
             label_upper_bound=self._get("_label_upper_bound"),
-            feature_names=feature_names,
+            feature_names=self._feature_names,
             feature_types=self._feature_types,
             feature_weights=self._feature_weights,
         )
@@ -935,6 +929,12 @@ async def _train_async(
         raise NotImplementedError(
             f"booster `{params['booster']}` is not yet supported for dask."
         )
+    device = params.get("device", None)
+    if device and device.find(":") != -1:
+        raise ValueError(
+            "The dask interface for XGBoost doesn't support selecting specific device"
+            " ordinal. Use `device=cpu` or `device=cuda` instead."
+        )

     def dispatched_train(
         parameters: Dict,
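The behavior introduced by this check, as a runnable sketch (assumes ``dask-cuda`` and one visible GPU; it mirrors the ``test_invalid_ordinal`` case added later in this diff)::

    import dask.array as da
    import xgboost as xgb
    from dask.distributed import Client
    from dask_cuda import LocalCUDACluster

    with LocalCUDACluster(n_workers=1) as cluster, Client(cluster) as client:
        X = da.random.random((1000, 4), chunks=100)
        y = da.random.random(1000, chunks=100)
        dtrain = xgb.dask.DaskDMatrix(client, X, y)
        try:
            # Rejected: the ordinal is assigned per worker, not by the user.
            xgb.dask.train(client, {"device": "cuda:0"}, dtrain)
        except ValueError as err:
            print(err)  # ... Use `device=cpu` or `device=cuda` instead.
        # Accepted:
        xgb.dask.train(client, {"device": "cuda"}, dtrain)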
@@ -1574,7 +1574,7 @@ async def _async_wrap_evaluation_matrices(
     """A switch function for async environment."""

     def _dispatch(ref: Optional[DaskDMatrix], **kwargs: Any) -> DaskDMatrix:
-        if tree_method in ("hist", "gpu_hist"):
+        if _can_use_qdm(tree_method):
             return DaskQuantileDMatrix(
                 client=client, ref=ref, max_bin=max_bin, **kwargs
             )
@@ -76,6 +76,10 @@ def _check_rf_callback(
         )


+def _can_use_qdm(tree_method: Optional[str]) -> bool:
+    return tree_method in ("hist", "gpu_hist", None, "auto")
+
+
 SklObjective = Optional[
     Union[str, Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]]
 ]
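The new ``_can_use_qdm`` helper centralizes the "can this tree method consume a ``QuantileDMatrix``" condition that was previously duplicated as ``tree_method in ("hist", "gpu_hist")`` across the sklearn, dask, and spark code paths; it also covers ``None`` and ``"auto"``, which resolve to a hist-compatible method by default. A standalone copy for illustration::

    from typing import Optional

    def _can_use_qdm(tree_method: Optional[str]) -> bool:
        # hist-based methods can train directly on pre-quantized data.
        return tree_method in ("hist", "gpu_hist", None, "auto")

    assert _can_use_qdm("hist") and _can_use_qdm(None) and _can_use_qdm("auto")
    assert not _can_use_qdm("exact")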
@@ -939,7 +943,7 @@ class XGBModel(XGBModelBase):

     def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix:
         # Use `QuantileDMatrix` to save memory.
-        if self.tree_method in ("hist", "gpu_hist"):
+        if _can_use_qdm(self.tree_method) and self.booster != "gblinear":
             try:
                 return QuantileDMatrix(
                     **kwargs, ref=ref, nthread=self.n_jobs, max_bin=self.max_bin

@@ -61,7 +61,7 @@ import xgboost
 from xgboost import XGBClassifier
 from xgboost.compat import is_cudf_available
 from xgboost.core import Booster
-from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel
+from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel, _can_use_qdm
 from xgboost.training import train as worker_train

 from .data import (

@@ -901,7 +901,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
            context = BarrierTaskContext.get()

            dev_ordinal = None
-           use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
+           use_qdm = _can_use_qdm(booster_params.get("tree_method", None))

            if use_gpu:
                dev_ordinal = (

@@ -912,9 +912,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                # because without cuDF, DMatrix performs better than QDM.
                # Note: Checking `is_cudf_available` in spark worker side because
                # spark worker might has different python environment with driver side.
-               use_qdm = use_hist and is_cudf_available()
-           else:
-               use_qdm = use_hist
+               use_qdm = use_qdm and is_cudf_available()

            if use_qdm and (booster_params.get("max_bin", None) is not None):
                dmatrix_kwargs["max_bin"] = booster_params["max_bin"]
@@ -81,13 +81,6 @@ void XGBBuildInfoDevice(Json *p_info) {
 }  // namespace xgboost
 #endif

-namespace {
-void DeprecatedFunc(StringView old, StringView since, StringView replacement) {
-  LOG(WARNING) << "`" << old << "` is deprecated since" << since << ", use `" << replacement
-               << "` instead.";
-}
-}  // anonymous namespace
-
 XGB_DLL int XGBuildInfo(char const **out) {
   API_BEGIN();
   xgboost_CHECK_C_ARG_PTR(out);

@@ -328,7 +321,7 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatr
                                                       int nthread, int max_bin,
                                                       DMatrixHandle *out) {
   API_BEGIN();
-  DeprecatedFunc(__func__, "1.7.0", "XGQuantileDMatrixCreateFromCallback");
+  LOG(WARNING) << error::DeprecatedFunc(__func__, "1.7.0", "XGQuantileDMatrixCreateFromCallback");
   *out = new std::shared_ptr<xgboost::DMatrix>{
       xgboost::DMatrix::Create(iter, proxy, nullptr, reset, next, missing, nthread, max_bin)};
   API_END();

@@ -432,7 +425,7 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t *indptr, const unsigned *indic
                                      const bst_float *data, size_t nindptr, size_t nelem,
                                      size_t num_col, DMatrixHandle *out) {
   API_BEGIN();
-  DeprecatedFunc(__func__, "2.0.0", "XGDMatrixCreateFromCSR");
+  LOG(WARNING) << error::DeprecatedFunc(__func__, "2.0.0", "XGDMatrixCreateFromCSR");
   data::CSRAdapter adapter(indptr, indices, data, nindptr - 1, nelem, num_col);
   *out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, std::nan(""), 1));
   API_END();

@@ -496,7 +489,7 @@ XGB_DLL int XGDMatrixCreateFromCSCEx(const size_t *col_ptr, const unsigned *indi
                                      const bst_float *data, size_t nindptr, size_t, size_t num_row,
                                      DMatrixHandle *out) {
   API_BEGIN();
-  DeprecatedFunc(__func__, "2.0.0", "XGDMatrixCreateFromCSC");
+  LOG(WARNING) << error::DeprecatedFunc(__func__, "2.0.0", "XGDMatrixCreateFromCSC");
   data::CSCAdapter adapter(col_ptr, indices, data, nindptr - 1, num_row);
   xgboost_CHECK_C_ARG_PTR(out);
   *out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, std::nan(""), 1));

@@ -1347,7 +1340,7 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle, xgboost::bst_ulong *out_l
   raw_str.resize(0);

   common::MemoryBufferStream fo(&raw_str);
-  DeprecatedFunc(__func__, "1.6.0", "XGBoosterSaveModelToBuffer");
+  LOG(WARNING) << error::DeprecatedFunc(__func__, "1.6.0", "XGBoosterSaveModelToBuffer");

   learner->Configure();
   learner->SaveModel(&fo);
@@ -3,10 +3,18 @@
  */
 #include "error_msg.h"

+#include <sstream>  // for stringstream
+
 #include "../collective/communicator-inl.h"  // for GetRank
 #include "xgboost/logging.h"

 namespace xgboost::error {
+std::string DeprecatedFunc(StringView old, StringView since, StringView replacement) {
+  std::stringstream ss;
+  ss << "`" << old << "` is deprecated since" << since << ", use `" << replacement << "` instead.";
+  return ss.str();
+}
+
 void WarnDeprecatedGPUHist() {
   auto msg =
       "The tree method `gpu_hist` is deprecated since 2.0.0. To use GPU training, set the `device` "

@@ -34,8 +42,9 @@ void WarnDeprecatedGPUId() {
   if (logged) {
     return;
   }
-  LOG(WARNING) << "`gpu_id` is deprecated in favor of the new `device` parameter: "
-               << "device = cpu/cuda/cuda:0";
+  auto msg = DeprecatedFunc("gpu_id", "2.0.0", "device");
+  msg += " E.g. device=cpu/cuda/cuda:0";
+  LOG(WARNING) << msg;
   logged = true;
 }

@@ -8,6 +8,7 @@

 #include <cinttypes>  // for uint64_t
 #include <limits>     // for numeric_limits
+#include <string>     // for string

 #include "xgboost/base.h"  // for bst_feature_t
 #include "xgboost/logging.h"

@@ -86,5 +87,7 @@ void WarnManualUpdater();
 void WarnDeprecatedGPUId();

 void WarnEmptyDataset();
+
+std::string DeprecatedFunc(StringView old, StringView since, StringView replacement);
 }  // namespace xgboost::error
 #endif  // XGBOOST_COMMON_ERROR_MSG_H_
@@ -693,20 +693,21 @@ class LearnerConfiguration : public Learner {
       for (auto const& kv : obj) {
         if (is_parameter(kv.first)) {
           auto parameter = get<Object const>(kv.second);
-          std::transform(parameter.begin(), parameter.end(), std::back_inserter(keys),
-                         [](std::pair<std::string const&, Json const&> const& kv) {
-                           return kv.first;
-                         });
+          std::transform(
+              parameter.begin(), parameter.end(), std::back_inserter(keys),
+              [](std::pair<std::string const&, Json const&> const& kv) { return kv.first; });
         } else if (IsA<Object>(kv.second)) {
           stack.push(kv.second);
-        } else if (kv.first == "metrics") {
+        } else if (IsA<Array>(kv.second)) {
           auto const& array = get<Array const>(kv.second);
           for (auto const& v : array) {
+            if (IsA<Object>(v) || IsA<Array>(v)) {
               stack.push(v);
             }
           }
         }
       }
+    }

     // FIXME(trivialfis): Make eval_metric a training parameter.
     keys.emplace_back(kEvalMetric);
@@ -32,6 +32,7 @@ class LintersPaths:
         "tests/test_distributed/test_with_spark/",
         "tests/test_distributed/test_gpu_with_spark/",
         # demo
+        "demo/dask/",
         "demo/json-model/json_parser.py",
         "demo/guide-python/cat_in_the_dat.py",
         "demo/guide-python/categorical.py",

@@ -42,6 +43,8 @@ class LintersPaths:
         "demo/guide-python/quantile_regression.py",
         "demo/guide-python/multioutput_regression.py",
         "demo/guide-python/learning_to_rank.py",
+        "demo/guide-python/quantile_data_iterator.py",
+        "demo/guide-python/update_process.py",
         "demo/aft_survival/aft_survival_viz_demo.py",
         # CI
         "tests/ci_build/lint_python.py",

@@ -322,3 +322,15 @@ class TestQuantileDMatrix:
         X: np.ndarray = np.array(orig, dtype=dtype)
         with pytest.raises(ValueError):
             xgb.QuantileDMatrix(X)
+
+    def test_changed_max_bin(self) -> None:
+        n_samples = 128
+        n_features = 16
+        csr, y = make_sparse_regression(n_samples, n_features, 0.5, False)
+        Xy = xgb.QuantileDMatrix(csr, y, max_bin=9)
+        booster = xgb.train({"max_bin": 9}, Xy, num_boost_round=2)
+
+        Xy = xgb.QuantileDMatrix(csr, y, max_bin=11)
+
+        with pytest.raises(ValueError, match="consistent"):
+            xgb.train({}, Xy, num_boost_round=2, xgb_model=booster)
@@ -27,7 +27,7 @@ def train_result(param, dmat, num_rounds):
         param,
         dmat,
         num_rounds,
-        [(dmat, "train")],
+        evals=[(dmat, "train")],
         verbose_eval=False,
         evals_result=result,
     )

@@ -169,13 +169,21 @@ class TestTreeMethod:
         hist_res = {}
         exact_res = {}

-        xgb.train(ag_param, ag_dtrain, 10,
-                  [(ag_dtrain, 'train'), (ag_dtest, 'test')],
-                  evals_result=hist_res)
+        xgb.train(
+            ag_param,
+            ag_dtrain,
+            10,
+            evals=[(ag_dtrain, "train"), (ag_dtest, "test")],
+            evals_result=hist_res
+        )
         ag_param["tree_method"] = "exact"
-        xgb.train(ag_param, ag_dtrain, 10,
-                  [(ag_dtrain, 'train'), (ag_dtest, 'test')],
-                  evals_result=exact_res)
+        xgb.train(
+            ag_param,
+            ag_dtrain,
+            10,
+            evals=[(ag_dtrain, "train"), (ag_dtest, "test")],
+            evals_result=exact_res
+        )
         assert hist_res['train']['auc'] == exact_res['train']['auc']
         assert hist_res['test']['auc'] == exact_res['test']['auc']

@@ -1349,10 +1349,11 @@ def test_multilabel_classification() -> None:
     np.testing.assert_allclose(clf.predict(X), predt)


-def test_data_initialization():
+def test_data_initialization() -> None:
     from sklearn.datasets import load_digits

     X, y = load_digits(return_X_y=True)
-    validate_data_initialization(xgb.DMatrix, xgb.XGBClassifier, X, y)
+    validate_data_initialization(xgb.QuantileDMatrix, xgb.XGBClassifier, X, y)


 @parametrize_with_checks([xgb.XGBRegressor()])
@@ -1,10 +1,9 @@
 """Copyright 2019-2022 XGBoost contributors"""
 import asyncio
-import os
-import subprocess
+import json
 from collections import OrderedDict
 from inspect import signature
-from typing import Any, Dict, Type, TypeVar, Union
+from typing import Any, Dict, Type, TypeVar

 import numpy as np
 import pytest

@@ -64,7 +63,7 @@ def run_with_dask_dataframe(DMatrixT: Type, client: Client) -> None:
     dtrain = DMatrixT(client, X, y)
     out = dxgb.train(
         client,
-        {"tree_method": "gpu_hist", "debug_synchronize": True},
+        {"tree_method": "hist", "debug_synchronize": True, "device": "cuda"},
         dtrain=dtrain,
         evals=[(dtrain, "X")],
         num_boost_round=4,

@@ -116,12 +115,18 @@ def run_with_dask_array(DMatrixT: Type, client: Client) -> None:
     dtrain = DMatrixT(client, X, y)
     out = dxgb.train(
         client,
-        {"tree_method": "gpu_hist", "debug_synchronize": True},
+        {"tree_method": "hist", "debug_synchronize": True, "device": "cuda"},
         dtrain=dtrain,
         evals=[(dtrain, "X")],
         num_boost_round=2,
     )
     from_dmatrix = dxgb.predict(client, out, dtrain).compute()
+    assert (
+        json.loads(out["booster"].save_config())["learner"]["gradient_booster"][
+            "updater"
+        ][0]["name"]
+        == "grow_gpu_hist"
+    )
     inplace_predictions = dxgb.inplace_predict(client, out, X).compute()
     single_node = out["booster"].predict(xgb.DMatrix(X.compute()))
     np.testing.assert_allclose(single_node, from_dmatrix)
@@ -149,7 +154,8 @@ def run_gpu_hist(
     DMatrixT: Type,
     client: Client,
 ) -> None:
-    params["tree_method"] = "gpu_hist"
+    params["tree_method"] = "hist"
+    params["device"] = "cuda"
     params = dataset.set_params(params)
     # It doesn't make sense to distribute a completely
     # empty dataset.

@@ -196,11 +202,11 @@ def run_gpu_hist(
 def test_tree_stats() -> None:
     with LocalCUDACluster(n_workers=1) as cluster:
         with Client(cluster) as client:
-            local = run_tree_stats(client, "gpu_hist")
+            local = run_tree_stats(client, "hist", "cuda")

     with LocalCUDACluster(n_workers=2) as cluster:
         with Client(cluster) as client:
-            distributed = run_tree_stats(client, "gpu_hist")
+            distributed = run_tree_stats(client, "hist", "cuda")

     assert local == distributed

@@ -214,12 +220,12 @@ class TestDistributedGPU:
         X_, y_ = load_breast_cancer(return_X_y=True)
         X = dd.from_array(X_, chunksize=100).map_partitions(cudf.from_pandas)
         y = dd.from_array(y_, chunksize=100).map_partitions(cudf.from_pandas)
-        run_boost_from_prediction(X, y, "gpu_hist", local_cuda_client)
+        run_boost_from_prediction(X, y, "hist", "cuda", local_cuda_client)

         X_, y_ = load_iris(return_X_y=True)
         X = dd.from_array(X_, chunksize=50).map_partitions(cudf.from_pandas)
         y = dd.from_array(y_, chunksize=50).map_partitions(cudf.from_pandas)
-        run_boost_from_prediction_multi_class(X, y, "gpu_hist", local_cuda_client)
+        run_boost_from_prediction_multi_class(X, y, "hist", "cuda", local_cuda_client)

     def test_init_estimation(self, local_cuda_client: Client) -> None:
         check_init_estimation("gpu_hist", local_cuda_client)
@@ -282,7 +288,7 @@ class TestDistributedGPU:
         )
         result = xgb.dask.train(
             client,
-            {"tree_method": "gpu_hist"},
+            {"tree_method": "hist", "device": "cuda", "debug_synchronize": True},
             Xy,
             num_boost_round=10,
             evals=[(Xy_valid, "Valid")],

@@ -313,7 +319,8 @@ class TestDistributedGPU:
             {
                 "objective": "binary:logistic",
                 "eval_metric": "error",
-                "tree_method": "gpu_hist",
+                "tree_method": "hist",
+                "device": "cuda",
             },
             m,
             evals=[(valid, "Valid")],

@@ -328,7 +335,8 @@ class TestDistributedGPU:
         valid_y = y
         cls = dxgb.DaskXGBClassifier(
             objective="binary:logistic",
-            tree_method="gpu_hist",
+            tree_method="hist",
+            device="cuda",
             eval_metric="error",
             n_estimators=100,
         )

@@ -356,7 +364,11 @@ class TestDistributedGPU:
         run_dask_classifier(X, y, w, model, "gpu_hist", local_cuda_client, 10)

     def test_empty_dmatrix(self, local_cuda_client: Client) -> None:
-        parameters = {"tree_method": "gpu_hist", "debug_synchronize": True}
+        parameters = {
+            "tree_method": "hist",
+            "debug_synchronize": True,
+            "device": "cuda",
+        }
         run_empty_dmatrix_reg(local_cuda_client, parameters)
         run_empty_dmatrix_cls(local_cuda_client, parameters)

@@ -374,7 +386,11 @@ class TestDistributedGPU:
                 "y": [10, 20, 30, 40.0, 50] * mult,
             }
         )
-        parameters = {"tree_method": "gpu_hist", "debug_synchronize": True}
+        parameters = {
+            "tree_method": "hist",
+            "debug_synchronize": True,
+            "device": "cuda",
+        }

         empty = df.iloc[:0]
         ddf = dask_cudf.concat(

@@ -432,13 +448,25 @@ class TestDistributedGPU:

     def test_empty_dmatrix_auc(self, local_cuda_client: Client) -> None:
         n_workers = len(tm.get_client_workers(local_cuda_client))
-        run_empty_dmatrix_auc(local_cuda_client, "gpu_hist", n_workers)
+        run_empty_dmatrix_auc(local_cuda_client, "cuda", n_workers)

     def test_auc(self, local_cuda_client: Client) -> None:
-        run_auc(local_cuda_client, "gpu_hist")
+        run_auc(local_cuda_client, "cuda")
+
+    def test_invalid_ordinal(self, local_cuda_client: Client) -> None:
+        """One should not specify the device ordinal with dask."""
+        with pytest.raises(ValueError, match="device=cuda"):
+            X, y, _ = generate_array()
+            m = dxgb.DaskDMatrix(local_cuda_client, X, y)
+            dxgb.train(local_cuda_client, {"device": "cuda:0"}, m)
+
+        booster = dxgb.train(local_cuda_client, {"device": "cuda"}, m)["booster"]
+        assert (
+            json.loads(booster.save_config())["learner"]["generic_param"]["device"]
+            == "cuda:0"
+        )

     def test_data_initialization(self, local_cuda_client: Client) -> None:

         X, y, _ = generate_array()
         fw = da.random.random((random_cols,))
         fw = fw - fw.min()

@@ -531,7 +559,9 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> dxgb.TrainRetur
     y = y.map_blocks(cp.array)

     m = await xgb.dask.DaskQuantileDMatrix(client, X, y)
-    output = await xgb.dask.train(client, {"tree_method": "gpu_hist"}, dtrain=m)
+    output = await xgb.dask.train(
+        client, {"tree_method": "hist", "device": "cuda"}, dtrain=m
+    )

     with_m = await xgb.dask.predict(client, output, m)
     with_X = await xgb.dask.predict(client, output, X)
File diff suppressed because it is too large.
@@ -1120,7 +1120,9 @@ class XgboostLocalTest(SparkTestCase):
         reg1 = SparkXGBRegressor(**self.reg_params)
         model = reg1.fit(self.reg_df_train)
         init_booster = model.get_booster()
-        reg2 = SparkXGBRegressor(max_depth=2, n_estimators=2, xgb_model=init_booster)
+        reg2 = SparkXGBRegressor(
+            max_depth=2, n_estimators=2, xgb_model=init_booster, max_bin=21
+        )
         model21 = reg2.fit(self.reg_df_train)
         pred_res21 = model21.transform(self.reg_df_test).collect()
         reg2.save(path)