[dask] Change document to avoid using default import. (#9742)
This aligns dask with pyspark: users need to import the dask module explicitly:

```
from xgboost.dask import DaskXGBClassifier
from xgboost import dask as dxgb
```

In future releases, we might stop supporting the default import and remove the lazy loader.
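For context, a minimal runnable sketch of the explicit-import style this commit documents; the local cluster setup, data shapes, and parameters below are illustrative, not taken from the demos being changed:

```python
import dask.array as da
from dask.distributed import Client, LocalCluster

# Explicit import instead of relying on the lazily loaded ``xgb.dask``.
from xgboost import dask as dxgb


def main(client: Client) -> None:
    # Synthetic regression data, chunked so work is spread across workers.
    X = da.random.random((10_000, 20), chunks=(1_000, 20))
    y = da.random.random(10_000, chunks=(1_000,))

    dtrain = dxgb.DaskDMatrix(client, X, y)
    # Train through the dask module; ``output`` holds the booster and the
    # evaluation history.
    output = dxgb.train(
        client,
        {"tree_method": "hist"},
        dtrain,
        num_boost_round=10,
        evals=[(dtrain, "train")],
    )
    prediction = dxgb.predict(client, output["booster"], dtrain)
    print(prediction.compute()[:5])


if __name__ == "__main__":
    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        main(client)
```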
parent 093b675838
commit 98238d63fa
@@ -9,7 +9,7 @@ import os
 import dask.dataframe as dd
 from dask.distributed import Client, LocalCluster

-import xgboost as xgb
+from xgboost import dask as dxgb
 from xgboost.dask import DaskDMatrix


@@ -48,14 +48,14 @@ def main(client):
         "lambda": 0.01,
         "alpha": 0.02,
     }
-    output = xgb.dask.train(
+    output = dxgb.train(
         client, params, dtrain, num_boost_round=100, evals=[(dtrain, "train")]
     )
     bst = output["booster"]
     history = output["history"]

     # you can pass output directly into `predict` too.
-    prediction = xgb.dask.predict(client, bst, dtrain)
+    prediction = dxgb.predict(client, bst, dtrain)
     print("Evaluation history: ", history)

     # Uncomment the following line to save the model to the disk
@@ -6,7 +6,7 @@ Example of training with Dask on CPU
 from dask import array as da
 from dask.distributed import Client, LocalCluster

-import xgboost as xgb
+from xgboost import dask as dxgb
 from xgboost.dask import DaskDMatrix


@@ -25,7 +25,7 @@ def main(client):
     # distributed version of train returns a dictionary containing the
     # resulting booster and evaluation history obtained from
     # evaluation metrics.
-    output = xgb.dask.train(
+    output = dxgb.train(
         client,
         {"verbosity": 1, "tree_method": "hist"},
         dtrain,
@@ -36,7 +36,7 @@ def main(client):
     history = output["history"]

     # you can pass output directly into `predict` too.
-    prediction = xgb.dask.predict(client, bst, dtrain)
+    prediction = dxgb.predict(client, bst, dtrain)
     print("Evaluation history:", history)
     return prediction

@@ -8,6 +8,7 @@ from dask_ml.datasets import make_regression
 from dask_ml.model_selection import train_test_split

 import xgboost as xgb
+import xgboost.dask as dxgb
 from xgboost.dask import DaskDMatrix


@@ -61,7 +62,7 @@ def main(client):
     dtrain = DaskDMatrix(client, X_train, y_train)
     dtest = DaskDMatrix(client, X_test, y_test)

-    output = xgb.dask.train(
+    output = dxgb.train(
         client,
         {
             "verbosity": 1,
@@ -8,7 +8,6 @@ from dask import dataframe as dd
 from dask.distributed import Client
 from dask_cuda import LocalCUDACluster

-import xgboost as xgb
 from xgboost import dask as dxgb
 from xgboost.dask import DaskDMatrix

@@ -21,7 +20,7 @@ def using_dask_matrix(client: Client, X: da.Array, y: da.Array) -> da.Array:
     # Use train method from xgboost.dask instead of xgboost. This distributed version
     # of train returns a dictionary containing the resulting booster and evaluation
     # history obtained from evaluation metrics.
-    output = xgb.dask.train(
+    output = dxgb.train(
         client,
         {
             "verbosity": 2,
@@ -37,7 +36,7 @@ def using_dask_matrix(client: Client, X: da.Array, y: da.Array) -> da.Array:
     history = output["history"]

     # you can pass output directly into `predict` too.
-    prediction = xgb.dask.predict(client, bst, dtrain)
+    prediction = dxgb.predict(client, bst, dtrain)
     print("Evaluation history:", history)
     return prediction

@@ -56,14 +55,14 @@ def using_quantile_device_dmatrix(client: Client, X: da.Array, y: da.Array) -> d
     # be used for anything else other than training unless a reference is specified. See
     # the `ref` argument of `DaskQuantileDMatrix`.
     dtrain = dxgb.DaskQuantileDMatrix(client, X, y)
-    output = xgb.dask.train(
+    output = dxgb.train(
         client,
         {"verbosity": 2, "tree_method": "hist", "device": "cuda"},
         dtrain,
         num_boost_round=4,
     )

-    prediction = xgb.dask.predict(client, output, X)
+    prediction = dxgb.predict(client, output, X)
     return prediction

@@ -5,7 +5,7 @@ Use scikit-learn regressor interface with CPU histogram tree method
 from dask import array as da
 from dask.distributed import Client, LocalCluster

-import xgboost
+from xgboost import dask as dxgb


 def main(client):
@@ -16,7 +16,7 @@ def main(client):
     X = da.random.random((m, n), partition_size)
     y = da.random.random(m, partition_size)

-    regressor = xgboost.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
+    regressor = dxgb.DaskXGBRegressor(verbosity=1, n_estimators=2)
     regressor.set_params(tree_method="hist")
     # assigning client here is optional
     regressor.client = client
@@ -9,7 +9,7 @@ from dask.distributed import Client
 # It's recommended to use dask_cuda for GPU assignment
 from dask_cuda import LocalCUDACluster

-import xgboost
+from xgboost import dask as dxgb


 def main(client):
@@ -20,7 +20,7 @@ def main(client):
     X = da.random.random((m, n), partition_size)
     y = da.random.random(m, partition_size)

-    regressor = xgboost.dask.DaskXGBRegressor(verbosity=1)
+    regressor = dxgb.DaskXGBRegressor(verbosity=1)
     # set the device to CUDA
     regressor.set_params(tree_method="hist", device="cuda")
     # assigning client here is optional
@@ -39,7 +39,8 @@ on a dask cluster:

 .. code-block:: python

     import xgboost as xgb
+    from xgboost import dask as dxgb

     import dask.array as da
     import dask.distributed

@@ -53,11 +54,11 @@ on a dask cluster:
     X = da.random.random(size=(num_obs, num_features), chunks=(1000, num_features))
     y = da.random.random(size=(num_obs, 1), chunks=(1000, 1))

-    dtrain = xgb.dask.DaskDMatrix(client, X, y)
+    dtrain = dxgb.DaskDMatrix(client, X, y)
     # or
-    # dtrain = xgb.dask.DaskQuantileDMatrix(client, X, y)
+    # dtrain = dxgb.DaskQuantileDMatrix(client, X, y)

-    output = xgb.dask.train(
+    output = dxgb.train(
         client,
         {"verbosity": 2, "tree_method": "hist", "objective": "reg:squarederror"},
         dtrain,
@@ -87,25 +88,27 @@ returns a model and the computation history as a Python dictionary:

 .. code-block:: python

-    {'booster': Booster,
-     'history': dict}
+    {
+        "booster": Booster,
+        "history": dict,
+    }

 For prediction, pass the ``output`` returned by ``train`` into :py:func:`xgboost.dask.predict`:

 .. code-block:: python

-    prediction = xgb.dask.predict(client, output, dtrain)
+    prediction = dxgb.predict(client, output, dtrain)
     # Or equivalently, pass ``output['booster']``:
-    prediction = xgb.dask.predict(client, output['booster'], dtrain)
+    prediction = dxgb.predict(client, output['booster'], dtrain)

 Eliminating the construction of DaskDMatrix is also possible, this can make the
 computation a bit faster when meta information like ``base_margin`` is not needed:

 .. code-block:: python

-    prediction = xgb.dask.predict(client, output, X)
+    prediction = dxgb.predict(client, output, X)
     # Use inplace version.
-    prediction = xgb.dask.inplace_predict(client, output, X)
+    prediction = dxgb.inplace_predict(client, output, X)

 Here ``prediction`` is a dask ``Array`` object containing predictions from model if input
 is a ``DaskDMatrix`` or ``da.Array``. When putting dask collection directly into the
@@ -134,14 +137,14 @@ both memory usage and prediction time.
 .. code-block:: python

     # dtrain is the DaskDMatrix defined above.
-    prediction = xgb.dask.predict(client, booster, dtrain)
+    prediction = dxgb.predict(client, booster, dtrain)

 or equivalently:

 .. code-block:: python

     # where X is a dask DataFrame or dask Array.
-    prediction = xgb.dask.predict(client, booster, X)
+    prediction = dxgb.predict(client, booster, X)

 Also for inplace prediction:

@@ -149,7 +152,7 @@ Also for inplace prediction:

     # where X is a dask DataFrame or dask Array backed by cupy or cuDF.
     booster.set_param({"device": "cuda"})
-    prediction = xgb.dask.inplace_predict(client, booster, X)
+    prediction = dxgb.inplace_predict(client, booster, X)

 When input is ``da.Array`` object, output is always ``da.Array``. However, if the input
 type is ``dd.DataFrame``, output can be ``dd.Series``, ``dd.DataFrame`` or ``da.Array``,
@@ -174,7 +177,7 @@ One simple optimization for running consecutive predictions is using
     futures = []
     for X in dataset:
         # Here we pass in a future instead of concrete booster
-        shap_f = xgb.dask.predict(client, booster_f, X, pred_contribs=True)
+        shap_f = dxgb.predict(client, booster_f, X, pred_contribs=True)
         futures.append(shap_f)

     results = client.gather(futures)
@@ -186,7 +189,7 @@ Scikit-Learn wrapper object:

 .. code-block:: python

-    cls = xgb.dask.DaskXGBClassifier()
+    cls = dxgb.DaskXGBClassifier()
     cls.fit(X, y)

     booster = cls.get_booster()
@@ -207,12 +210,12 @@ collection.
 .. code-block:: python

     from distributed import LocalCluster, Client
-    import xgboost as xgb
+    from xgboost import dask as dxgb


     def main(client: Client) -> None:
         X, y = load_data()
-        clf = xgb.dask.DaskXGBClassifier(n_estimators=100, tree_method="hist")
+        clf = dxgb.DaskXGBClassifier(n_estimators=100, tree_method="hist")
         clf.client = client  # assign the client
         clf.fit(X, y, eval_set=[(X, y)])
         proba = clf.predict_proba(X)
@@ -242,7 +245,7 @@ In the example below, a ``KubeCluster`` is used for `deploying Dask on Kubernete

     from dask_kubernetes import KubeCluster  # Need to install the ``dask-kubernetes`` package
     from dask.distributed import Client
-    import xgboost as xgb
+    from xgboost import dask as dxgb
     import dask
     import dask.array as da

@@ -265,7 +268,7 @@ In the example below, a ``KubeCluster`` is used for `deploying Dask on Kubernete
         X = da.random.random(size=(m, n), chunks=100)
         y = da.random.random(size=(m, ), chunks=100)

-        regressor = xgb.dask.DaskXGBRegressor(n_estimators=10, missing=0.0)
+        regressor = dxgb.DaskXGBRegressor(n_estimators=10, missing=0.0)
         regressor.client = client
         regressor.set_params(tree_method='hist', device="cuda")
         regressor.fit(X, y, eval_set=[(X, y)])
@@ -298,7 +301,7 @@ threads in each process for training. But if ``nthread`` parameter is set:

 .. code-block:: python

-    output = xgb.dask.train(
+    output = dxgb.train(
         client,
         {"verbosity": 1, "nthread": 8, "tree_method": "hist"},
         dtrain,
@@ -330,12 +333,12 @@ Functional interface:

     async with dask.distributed.Client(scheduler_address, asynchronous=True) as client:
         X, y = generate_array()
-        m = await xgb.dask.DaskDMatrix(client, X, y)
-        output = await xgb.dask.train(client, {}, dtrain=m)
+        m = await dxgb.DaskDMatrix(client, X, y)
+        output = await dxgb.train(client, {}, dtrain=m)

-        with_m = await xgb.dask.predict(client, output, m)
-        with_X = await xgb.dask.predict(client, output, X)
-        inplace = await xgb.dask.inplace_predict(client, output, X)
+        with_m = await dxgb.predict(client, output, m)
+        with_X = await dxgb.predict(client, output, X)
+        inplace = await dxgb.inplace_predict(client, output, X)

         # Use ``client.compute`` instead of the ``compute`` method from dask collection
         print(await client.compute(with_m))
@@ -349,7 +352,7 @@ actual computation will return a coroutine and hence require awaiting:

     async with dask.distributed.Client(scheduler_address, asynchronous=True) as client:
         X, y = generate_array()
-        regressor = await xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
+        regressor = await dxgb.DaskXGBRegressor(verbosity=1, n_estimators=2)
         regressor.set_params(tree_method='hist')  # trivial method, synchronous operation
         regressor.client = client  # accessing attribute, synchronous operation
         regressor = await regressor.fit(X, y, eval_set=[(X, y)])
@@ -371,7 +374,7 @@ To enable early stopping, pass one or more validation sets containing ``DaskDMat
 .. code-block:: python

     import dask.array as da
-    import xgboost as xgb
+    from xgboost import dask as dxgb

     num_rows = 1e6
     num_features = 100
@@ -398,19 +401,19 @@ To enable early stopping, pass one or more validation sets containing ``DaskDMat
         chunks=(rows_per_chunk, 1)
     )

-    dtrain = xgb.dask.DaskDMatrix(
+    dtrain = dxgb.DaskDMatrix(
         client=client,
         data=data,
         label=labels
     )

-    dvalid = xgb.dask.DaskDMatrix(
+    dvalid = dxgb.DaskDMatrix(
         client=client,
         data=X_eval,
         label=y_eval
     )

-    result = xgb.dask.train(
+    result = dxgb.train(
         client=client,
         params={
             "objective": "reg:squarederror",
@@ -421,7 +424,7 @@ To enable early stopping, pass one or more validation sets containing ``DaskDMat
         early_stopping_rounds=3
     )

-When validation sets are provided to ``xgb.dask.train()`` in this way, the model object returned by ``xgb.dask.train()`` contains a history of evaluation metrics for each validation set, across all boosting rounds.
+When validation sets are provided to :py:func:`xgboost.dask.train` in this way, the model object returned by :py:func:`xgboost.dask.train` contains a history of evaluation metrics for each validation set, across all boosting rounds.

 .. code-block:: python

@@ -463,7 +466,7 @@ interface, including callback functions, custom evaluation metric and objective:
         save_best=True,
     )

-    booster = xgb.dask.train(
+    booster = dxgb.train(
         client,
         params={
             "objective": "binary:logistic",