[dask] Change document to avoid using default import. (#9742)

This aligns the dask interface with pyspark; users now need to make the import explicit:

```python
from xgboost.dask import DaskXGBClassifier
from xgboost import dask as dxgb
```

In future releases, we might stop using the default import and remove the lazy loader.
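
For illustration, here is a minimal end-to-end sketch of the new style, mirroring the updated demos below; the cluster setup, data shapes, and parameters are placeholders:

```python
"""Minimal sketch of the explicit-import style for the functional dask interface."""
from dask import array as da
from dask.distributed import Client, LocalCluster

from xgboost import dask as dxgb


def main(client: Client) -> None:
    # Synthetic data for illustration; any dask collection works.
    X = da.random.random(size=(1000, 10), chunks=(100, 10))
    y = da.random.random(size=(1000,), chunks=(100,))

    # Everything is accessed through the explicitly imported module.
    dtrain = dxgb.DaskDMatrix(client, X, y)
    output = dxgb.train(
        client,
        {"tree_method": "hist", "objective": "reg:squarederror"},
        dtrain,
        num_boost_round=10,
        evals=[(dtrain, "train")],
    )
    prediction = dxgb.predict(client, output["booster"], dtrain)
    print(prediction.compute()[:5])


if __name__ == "__main__":
    with LocalCluster(n_workers=2, threads_per_worker=1) as cluster:
        with Client(cluster) as client:
            main(client)
```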
Author: Jiaming Yuan, 2023-11-07 02:44:39 +08:00 (committed by GitHub)
Parent: 093b675838
Commit: 98238d63fa
7 changed files with 51 additions and 48 deletions

```diff
@@ -9,7 +9,7 @@ import os
 import dask.dataframe as dd
 from dask.distributed import Client, LocalCluster
-import xgboost as xgb
+from xgboost import dask as dxgb
 from xgboost.dask import DaskDMatrix
@@ -48,14 +48,14 @@ def main(client):
 "lambda": 0.01,
 "alpha": 0.02,
 }
-output = xgb.dask.train(
+output = dxgb.train(
 client, params, dtrain, num_boost_round=100, evals=[(dtrain, "train")]
 )
 bst = output["booster"]
 history = output["history"]
 # you can pass output directly into `predict` too.
-prediction = xgb.dask.predict(client, bst, dtrain)
+prediction = dxgb.predict(client, bst, dtrain)
 print("Evaluation history: ", history)
 # Uncomment the following line to save the model to the disk
```

```diff
@@ -6,7 +6,7 @@ Example of training with Dask on CPU
 from dask import array as da
 from dask.distributed import Client, LocalCluster
-import xgboost as xgb
+from xgboost import dask as dxgb
 from xgboost.dask import DaskDMatrix
@@ -25,7 +25,7 @@ def main(client):
 # distributed version of train returns a dictionary containing the
 # resulting booster and evaluation history obtained from
 # evaluation metrics.
-output = xgb.dask.train(
+output = dxgb.train(
 client,
 {"verbosity": 1, "tree_method": "hist"},
 dtrain,
@@ -36,7 +36,7 @@ def main(client):
 history = output["history"]
 # you can pass output directly into `predict` too.
-prediction = xgb.dask.predict(client, bst, dtrain)
+prediction = dxgb.predict(client, bst, dtrain)
 print("Evaluation history:", history)
 return prediction
```

```diff
@@ -8,6 +8,7 @@ from dask_ml.datasets import make_regression
 from dask_ml.model_selection import train_test_split
 import xgboost as xgb
+import xgboost.dask as dxgb
 from xgboost.dask import DaskDMatrix
@@ -61,7 +62,7 @@ def main(client):
 dtrain = DaskDMatrix(client, X_train, y_train)
 dtest = DaskDMatrix(client, X_test, y_test)
-output = xgb.dask.train(
+output = dxgb.train(
 client,
 {
 "verbosity": 1,
```

```diff
@@ -8,7 +8,6 @@ from dask import dataframe as dd
 from dask.distributed import Client
 from dask_cuda import LocalCUDACluster
-import xgboost as xgb
 from xgboost import dask as dxgb
 from xgboost.dask import DaskDMatrix
@@ -21,7 +20,7 @@ def using_dask_matrix(client: Client, X: da.Array, y: da.Array) -> da.Array:
 # Use train method from xgboost.dask instead of xgboost. This distributed version
 # of train returns a dictionary containing the resulting booster and evaluation
 # history obtained from evaluation metrics.
-output = xgb.dask.train(
+output = dxgb.train(
 client,
 {
 "verbosity": 2,
@@ -37,7 +36,7 @@ def using_dask_matrix(client: Client, X: da.Array, y: da.Array) -> da.Array:
 history = output["history"]
 # you can pass output directly into `predict` too.
-prediction = xgb.dask.predict(client, bst, dtrain)
+prediction = dxgb.predict(client, bst, dtrain)
 print("Evaluation history:", history)
 return prediction
@@ -56,14 +55,14 @@ def using_quantile_device_dmatrix(client: Client, X: da.Array, y: da.Array) -> d
 # be used for anything else other than training unless a reference is specified. See
 # the `ref` argument of `DaskQuantileDMatrix`.
 dtrain = dxgb.DaskQuantileDMatrix(client, X, y)
-output = xgb.dask.train(
+output = dxgb.train(
 client,
 {"verbosity": 2, "tree_method": "hist", "device": "cuda"},
 dtrain,
 num_boost_round=4,
 )
-prediction = xgb.dask.predict(client, output, X)
+prediction = dxgb.predict(client, output, X)
 return prediction
```

```diff
@@ -5,7 +5,7 @@ Use scikit-learn regressor interface with CPU histogram tree method
 from dask import array as da
 from dask.distributed import Client, LocalCluster
-import xgboost
+from xgboost import dask as dxgb
 def main(client):
@@ -16,7 +16,7 @@ def main(client):
 X = da.random.random((m, n), partition_size)
 y = da.random.random(m, partition_size)
-regressor = xgboost.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
+regressor = dxgb.DaskXGBRegressor(verbosity=1, n_estimators=2)
 regressor.set_params(tree_method="hist")
 # assigning client here is optional
 regressor.client = client
```

```diff
@@ -9,7 +9,7 @@ from dask.distributed import Client
 # It's recommended to use dask_cuda for GPU assignment
 from dask_cuda import LocalCUDACluster
-import xgboost
+from xgboost import dask as dxgb
 def main(client):
@@ -20,7 +20,7 @@ def main(client):
 X = da.random.random((m, n), partition_size)
 y = da.random.random(m, partition_size)
-regressor = xgboost.dask.DaskXGBRegressor(verbosity=1)
+regressor = dxgb.DaskXGBRegressor(verbosity=1)
 # set the device to CUDA
 regressor.set_params(tree_method="hist", device="cuda")
 # assigning client here is optional
```
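
The two estimator demos above now go through the explicit module import as well; below is a minimal CPU sketch of that pattern, with illustrative data shapes and parameters:

```python
"""Minimal sketch of the scikit-learn style interface with the explicit import."""
from dask import array as da
from dask.distributed import Client, LocalCluster

from xgboost import dask as dxgb


def main(client: Client) -> None:
    # Illustrative synthetic data.
    m, n = 10_000, 20
    X = da.random.random((m, n), chunks=(1_000, n))
    y = da.random.random(m, chunks=1_000)

    regressor = dxgb.DaskXGBRegressor(verbosity=1, n_estimators=2)
    regressor.set_params(tree_method="hist")
    # Assigning the client is optional; the current client is used otherwise.
    regressor.client = client

    regressor.fit(X, y, eval_set=[(X, y)])
    prediction = regressor.predict(X)
    print(prediction.compute()[:5])


if __name__ == "__main__":
    with LocalCluster(n_workers=2, threads_per_worker=1) as cluster:
        with Client(cluster) as client:
            main(client)
```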

```diff
@@ -39,7 +39,8 @@ on a dask cluster:
 .. code-block:: python
-import xgboost as xgb
+from xgboost import dask as dxgb
 import dask.array as da
 import dask.distributed
@@ -53,11 +54,11 @@ on a dask cluster:
 X = da.random.random(size=(num_obs, num_features), chunks=(1000, num_features))
 y = da.random.random(size=(num_obs, 1), chunks=(1000, 1))
-dtrain = xgb.dask.DaskDMatrix(client, X, y)
+dtrain = dxgb.DaskDMatrix(client, X, y)
 # or
-# dtrain = xgb.dask.DaskQuantileDMatrix(client, X, y)
-output = xgb.dask.train(
+# dtrain = dxgb.DaskQuantileDMatrix(client, X, y)
+output = dxgb.train(
 client,
 {"verbosity": 2, "tree_method": "hist", "objective": "reg:squarederror"},
 dtrain,
@@ -87,25 +88,27 @@ returns a model and the computation history as a Python dictionary:
 .. code-block:: python
-{'booster': Booster,
-'history': dict}
+{
+"booster": Booster,
+"history": dict,
+}
 For prediction, pass the ``output`` returned by ``train`` into :py:func:`xgboost.dask.predict`:
 .. code-block:: python
-prediction = xgb.dask.predict(client, output, dtrain)
+prediction = dxgb.predict(client, output, dtrain)
 # Or equivalently, pass ``output['booster']``:
-prediction = xgb.dask.predict(client, output['booster'], dtrain)
+prediction = dxgb.predict(client, output['booster'], dtrain)
 Eliminating the construction of DaskDMatrix is also possible, this can make the
 computation a bit faster when meta information like ``base_margin`` is not needed:
 .. code-block:: python
-prediction = xgb.dask.predict(client, output, X)
+prediction = dxgb.predict(client, output, X)
 # Use inplace version.
-prediction = xgb.dask.inplace_predict(client, output, X)
+prediction = dxgb.inplace_predict(client, output, X)
 Here ``prediction`` is a dask ``Array`` object containing predictions from model if input
 is a ``DaskDMatrix`` or ``da.Array``. When putting dask collection directly into the
@@ -134,14 +137,14 @@ both memory usage and prediction time.
 .. code-block:: python
 # dtrain is the DaskDMatrix defined above.
-prediction = xgb.dask.predict(client, booster, dtrain)
+prediction = dxgb.predict(client, booster, dtrain)
 or equivalently:
 .. code-block:: python
 # where X is a dask DataFrame or dask Array.
-prediction = xgb.dask.predict(client, booster, X)
+prediction = dxgb.predict(client, booster, X)
 Also for inplace prediction:
@@ -149,7 +152,7 @@ Also for inplace prediction:
 # where X is a dask DataFrame or dask Array backed by cupy or cuDF.
 booster.set_param({"device": "cuda"})
-prediction = xgb.dask.inplace_predict(client, booster, X)
+prediction = dxgb.inplace_predict(client, booster, X)
 When input is ``da.Array`` object, output is always ``da.Array``. However, if the input
 type is ``dd.DataFrame``, output can be ``dd.Series``, ``dd.DataFrame`` or ``da.Array``,
@@ -174,7 +177,7 @@ One simple optimization for running consecutive predictions is using
 futures = []
 for X in dataset:
 # Here we pass in a future instead of concrete booster
-shap_f = xgb.dask.predict(client, booster_f, X, pred_contribs=True)
+shap_f = dxgb.predict(client, booster_f, X, pred_contribs=True)
 futures.append(shap_f)
 results = client.gather(futures)
@@ -186,7 +189,7 @@ Scikit-Learn wrapper object:
 .. code-block:: python
-cls = xgb.dask.DaskXGBClassifier()
+cls = dxgb.DaskXGBClassifier()
 cls.fit(X, y)
 booster = cls.get_booster()
@@ -207,12 +210,12 @@ collection.
 .. code-block:: python
 from distributed import LocalCluster, Client
-import xgboost as xgb
+from xgboost import dask as dxgb
 def main(client: Client) -> None:
 X, y = load_data()
-clf = xgb.dask.DaskXGBClassifier(n_estimators=100, tree_method="hist")
+clf = dxgb.DaskXGBClassifier(n_estimators=100, tree_method="hist")
 clf.client = client # assign the client
 clf.fit(X, y, eval_set=[(X, y)])
 proba = clf.predict_proba(X)
@@ -242,7 +245,7 @@ In the example below, a ``KubeCluster`` is used for `deploying Dask on Kubernete
 from dask_kubernetes import KubeCluster # Need to install the ``dask-kubernetes`` package
 from dask.distributed import Client
-import xgboost as xgb
+from xgboost import dask as dxgb
 import dask
 import dask.array as da
@@ -265,7 +268,7 @@ In the example below, a ``KubeCluster`` is used for `deploying Dask on Kubernete
 X = da.random.random(size=(m, n), chunks=100)
 y = da.random.random(size=(m, ), chunks=100)
-regressor = xgb.dask.DaskXGBRegressor(n_estimators=10, missing=0.0)
+regressor = dxgb.DaskXGBRegressor(n_estimators=10, missing=0.0)
 regressor.client = client
 regressor.set_params(tree_method='hist', device="cuda")
 regressor.fit(X, y, eval_set=[(X, y)])
@@ -298,7 +301,7 @@ threads in each process for training. But if ``nthread`` parameter is set:
 .. code-block:: python
-output = xgb.dask.train(
+output = dxgb.train(
 client,
 {"verbosity": 1, "nthread": 8, "tree_method": "hist"},
 dtrain,
@@ -330,12 +333,12 @@ Functional interface:
 async with dask.distributed.Client(scheduler_address, asynchronous=True) as client:
 X, y = generate_array()
-m = await xgb.dask.DaskDMatrix(client, X, y)
-output = await xgb.dask.train(client, {}, dtrain=m)
-with_m = await xgb.dask.predict(client, output, m)
-with_X = await xgb.dask.predict(client, output, X)
-inplace = await xgb.dask.inplace_predict(client, output, X)
+m = await dxgb.DaskDMatrix(client, X, y)
+output = await dxgb.train(client, {}, dtrain=m)
+with_m = await dxgb.predict(client, output, m)
+with_X = await dxgb.predict(client, output, X)
+inplace = await dxgb.inplace_predict(client, output, X)
 # Use ``client.compute`` instead of the ``compute`` method from dask collection
 print(await client.compute(with_m))
@@ -349,7 +352,7 @@ actual computation will return a coroutine and hence require awaiting:
 async with dask.distributed.Client(scheduler_address, asynchronous=True) as client:
 X, y = generate_array()
-regressor = await xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
+regressor = await dxgb.DaskXGBRegressor(verbosity=1, n_estimators=2)
 regressor.set_params(tree_method='hist') # trivial method, synchronous operation
 regressor.client = client # accessing attribute, synchronous operation
 regressor = await regressor.fit(X, y, eval_set=[(X, y)])
@@ -371,7 +374,7 @@ To enable early stopping, pass one or more validation sets containing ``DaskDMat
 .. code-block:: python
 import dask.array as da
-import xgboost as xgb
+from xgboost import dask as dxgb
 num_rows = 1e6
 num_features = 100
@@ -398,19 +401,19 @@ To enable early stopping, pass one or more validation sets containing ``DaskDMat
 chunks=(rows_per_chunk, 1)
 )
-dtrain = xgb.dask.DaskDMatrix(
+dtrain = dxgb.DaskDMatrix(
 client=client,
 data=data,
 label=labels
 )
-dvalid = xgb.dask.DaskDMatrix(
+dvalid = dxgb.DaskDMatrix(
 client=client,
 data=X_eval,
 label=y_eval
 )
-result = xgb.dask.train(
+result = dxgb.train(
 client=client,
 params={
 "objective": "reg:squarederror",
@@ -421,7 +424,7 @@ To enable early stopping, pass one or more validation sets containing ``DaskDMat
 early_stopping_rounds=3
 )
-When validation sets are provided to ``xgb.dask.train()`` in this way, the model object returned by ``xgb.dask.train()`` contains a history of evaluation metrics for each validation set, across all boosting rounds.
+When validation sets are provided to :py:func:`xgboost.dask.train` in this way, the model object returned by :py:func:`xgboost.dask.train` contains a history of evaluation metrics for each validation set, across all boosting rounds.
 .. code-block:: python
@@ -463,7 +466,7 @@ interface, including callback functions, custom evaluation metric and objective:
 save_best=True,
 )
-booster = xgb.dask.train(
+booster = dxgb.train(
 client,
 params={
 "objective": "binary:logistic",
```