[dask] Improve documents. (#6687)
* Add tag for versions. * use autoclass in sphinx build. Made some class methods to be private to avoid exporting documents.
This commit is contained in:
parent
5d48d40d9a
commit
1335db6113
@ -93,9 +93,15 @@ Dask API
|
||||
--------
|
||||
.. automodule:: xgboost.dask
|
||||
|
||||
.. autofunction:: xgboost.dask.DaskDMatrix
|
||||
.. autoclass:: xgboost.dask.DaskDMatrix
|
||||
:members:
|
||||
:inherited-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autofunction:: xgboost.dask.DaskDeviceQuantileDMatrix
|
||||
.. autoclass:: xgboost.dask.DaskDeviceQuantileDMatrix
|
||||
:members:
|
||||
:inherited-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autofunction:: xgboost.dask.train
|
||||
|
||||
@ -103,12 +109,27 @@ Dask API
|
||||
|
||||
.. autofunction:: xgboost.dask.inplace_predict
|
||||
|
||||
.. autofunction:: xgboost.dask.DaskXGBClassifier
|
||||
.. autoclass:: xgboost.dask.DaskXGBClassifier
|
||||
:members:
|
||||
:inherited-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autofunction:: xgboost.dask.DaskXGBRegressor
|
||||
.. autoclass:: xgboost.dask.DaskXGBRegressor
|
||||
:members:
|
||||
:inherited-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autofunction:: xgboost.dask.DaskXGBRanker
|
||||
.. autoclass:: xgboost.dask.DaskXGBRanker
|
||||
:members:
|
||||
:inherited-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autofunction:: xgboost.dask.DaskXGBRFRegressor
|
||||
.. autoclass:: xgboost.dask.DaskXGBRFRegressor
|
||||
:members:
|
||||
:inherited-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autofunction:: xgboost.dask.DaskXGBRFClassifier
|
||||
.. autoclass:: xgboost.dask.DaskXGBRFClassifier
|
||||
:members:
|
||||
:inherited-members:
|
||||
:show-inheritance:
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module
|
||||
# pylint: disable=missing-class-docstring, invalid-name
|
||||
# pylint: disable=too-many-lines, fixme
|
||||
# pylint: disable=too-few-public-methods
|
||||
# pylint: disable=import-error
|
||||
"""Dask extensions for distributed training. See
|
||||
https://xgboost.readthedocs.io/en/latest/tutorials/dask.html for simple
|
||||
@ -259,7 +260,7 @@ class DaskDMatrix:
|
||||
self.is_quantile: bool = False
|
||||
|
||||
self._init = client.sync(
|
||||
self.map_local_data,
|
||||
self._map_local_data,
|
||||
client,
|
||||
data,
|
||||
label=label,
|
||||
@ -274,7 +275,7 @@ class DaskDMatrix:
|
||||
def __await__(self) -> Generator:
|
||||
return self._init.__await__()
|
||||
|
||||
async def map_local_data(
|
||||
async def _map_local_data(
|
||||
self,
|
||||
client: "distributed.Client",
|
||||
data: _DaskCollection,
|
||||
@ -393,7 +394,7 @@ class DaskDMatrix:
|
||||
|
||||
return self
|
||||
|
||||
def create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
|
||||
def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
|
||||
'''Create a dictionary of objects that can be pickled for function
|
||||
arguments.
|
||||
|
||||
@ -627,8 +628,8 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
||||
self.max_bin = max_bin
|
||||
self.is_quantile = True
|
||||
|
||||
def create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
|
||||
args = super().create_fn_args(worker_addr)
|
||||
def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
|
||||
args = super()._create_fn_args(worker_addr)
|
||||
args["max_bin"] = self.max_bin
|
||||
return args
|
||||
|
||||
@ -864,18 +865,22 @@ async def _train_async(
|
||||
futures = []
|
||||
for i, worker_addr in enumerate(workers):
|
||||
if evals:
|
||||
evals_per_worker = [(e.create_fn_args(worker_addr), name, id(e))
|
||||
# pylint: disable=protected-access
|
||||
evals_per_worker = [(e._create_fn_args(worker_addr), name, id(e))
|
||||
for e, name in evals]
|
||||
else:
|
||||
evals_per_worker = []
|
||||
f = client.submit(dispatched_train,
|
||||
f = client.submit(
|
||||
dispatched_train,
|
||||
worker_addr,
|
||||
_rabit_args,
|
||||
dtrain.create_fn_args(workers[i]),
|
||||
# pylint: disable=protected-access
|
||||
dtrain._create_fn_args(workers[i]),
|
||||
id(dtrain),
|
||||
evals_per_worker,
|
||||
pure=False,
|
||||
workers=[worker_addr])
|
||||
workers=[worker_addr]
|
||||
)
|
||||
futures.append(f)
|
||||
|
||||
results = await client.gather(futures)
|
||||
@ -1755,7 +1760,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
"Implementation of the Scikit-Learn API for XGBoost Ranking.",
|
||||
"""Implementation of the Scikit-Learn API for XGBoost Ranking.
|
||||
|
||||
.. versionadded:: 1.4.0
|
||||
|
||||
""",
|
||||
["estimators", "model"],
|
||||
end_note="""
|
||||
Note
|
||||
@ -1868,7 +1877,11 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
"Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor.",
|
||||
"""Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor.
|
||||
|
||||
.. versionadded:: 1.4.0
|
||||
|
||||
""",
|
||||
["model", "objective"],
|
||||
extra_parameters="""
|
||||
n_estimators : int
|
||||
@ -1904,7 +1917,11 @@ class DaskXGBRFRegressor(DaskXGBRegressor):
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
"Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.",
|
||||
"""Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.
|
||||
|
||||
.. versionadded:: 1.4.0
|
||||
|
||||
""",
|
||||
["model", "objective"],
|
||||
extra_parameters="""
|
||||
n_estimators : int
|
||||
|
||||
@ -304,9 +304,15 @@ class TestDistributedGPU:
|
||||
|
||||
futures = []
|
||||
for i in range(len(workers)):
|
||||
futures.append(client.submit(worker_fn, workers[i],
|
||||
m.create_fn_args(workers[i]), pure=False,
|
||||
workers=[workers[i]]))
|
||||
futures.append(
|
||||
client.submit(
|
||||
worker_fn,
|
||||
workers[i],
|
||||
m._create_fn_args(workers[i]),
|
||||
pure=False,
|
||||
workers=[workers[i]]
|
||||
)
|
||||
)
|
||||
client.gather(futures)
|
||||
|
||||
def test_interface_consistency(self) -> None:
|
||||
|
||||
@ -1046,9 +1046,12 @@ class TestWithDask:
|
||||
|
||||
futures = []
|
||||
for i in range(len(workers)):
|
||||
futures.append(client.submit(worker_fn, workers[i],
|
||||
m.create_fn_args(workers[i]), pure=False,
|
||||
workers=[workers[i]]))
|
||||
futures.append(
|
||||
client.submit(
|
||||
worker_fn, workers[i],
|
||||
m._create_fn_args(workers[i]), pure=False,
|
||||
workers=[workers[i]])
|
||||
)
|
||||
client.gather(futures)
|
||||
|
||||
has_what = client.has_what()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user