[dask] Improve documentation. (#6687)

* Add tag for versions.
* Use autoclass in the Sphinx build.
Made some class methods private so that they are not exported to the generated documentation.
This commit is contained in:
Jiaming Yuan 2021-02-09 09:20:58 +08:00 committed by GitHub
parent 5d48d40d9a
commit 1335db6113
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 77 additions and 30 deletions

View File

@ -93,9 +93,15 @@ Dask API
--------
.. automodule:: xgboost.dask
.. autofunction:: xgboost.dask.DaskDMatrix
.. autoclass:: xgboost.dask.DaskDMatrix
:members:
:inherited-members:
:show-inheritance:
.. autofunction:: xgboost.dask.DaskDeviceQuantileDMatrix
.. autoclass:: xgboost.dask.DaskDeviceQuantileDMatrix
:members:
:inherited-members:
:show-inheritance:
.. autofunction:: xgboost.dask.train
@ -103,12 +109,27 @@ Dask API
.. autofunction:: xgboost.dask.inplace_predict
.. autofunction:: xgboost.dask.DaskXGBClassifier
.. autoclass:: xgboost.dask.DaskXGBClassifier
:members:
:inherited-members:
:show-inheritance:
.. autofunction:: xgboost.dask.DaskXGBRegressor
.. autoclass:: xgboost.dask.DaskXGBRegressor
:members:
:inherited-members:
:show-inheritance:
.. autofunction:: xgboost.dask.DaskXGBRanker
.. autoclass:: xgboost.dask.DaskXGBRanker
:members:
:inherited-members:
:show-inheritance:
.. autofunction:: xgboost.dask.DaskXGBRFRegressor
.. autoclass:: xgboost.dask.DaskXGBRFRegressor
:members:
:inherited-members:
:show-inheritance:
.. autofunction:: xgboost.dask.DaskXGBRFClassifier
.. autoclass:: xgboost.dask.DaskXGBRFClassifier
:members:
:inherited-members:
:show-inheritance:

View File

@ -1,6 +1,7 @@
# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module
# pylint: disable=missing-class-docstring, invalid-name
# pylint: disable=too-many-lines, fixme
# pylint: disable=too-few-public-methods
# pylint: disable=import-error
"""Dask extensions for distributed training. See
https://xgboost.readthedocs.io/en/latest/tutorials/dask.html for simple
@ -259,7 +260,7 @@ class DaskDMatrix:
self.is_quantile: bool = False
self._init = client.sync(
self.map_local_data,
self._map_local_data,
client,
data,
label=label,
@ -274,7 +275,7 @@ class DaskDMatrix:
def __await__(self) -> Generator:
return self._init.__await__()
async def map_local_data(
async def _map_local_data(
self,
client: "distributed.Client",
data: _DaskCollection,
@ -393,7 +394,7 @@ class DaskDMatrix:
return self
def create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
'''Create a dictionary of objects that can be pickled for function
arguments.
@ -627,8 +628,8 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
self.max_bin = max_bin
self.is_quantile = True
def create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
args = super().create_fn_args(worker_addr)
def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
args = super()._create_fn_args(worker_addr)
args["max_bin"] = self.max_bin
return args
@ -864,18 +865,22 @@ async def _train_async(
futures = []
for i, worker_addr in enumerate(workers):
if evals:
evals_per_worker = [(e.create_fn_args(worker_addr), name, id(e))
# pylint: disable=protected-access
evals_per_worker = [(e._create_fn_args(worker_addr), name, id(e))
for e, name in evals]
else:
evals_per_worker = []
f = client.submit(dispatched_train,
worker_addr,
_rabit_args,
dtrain.create_fn_args(workers[i]),
id(dtrain),
evals_per_worker,
pure=False,
workers=[worker_addr])
f = client.submit(
dispatched_train,
worker_addr,
_rabit_args,
# pylint: disable=protected-access
dtrain._create_fn_args(workers[i]),
id(dtrain),
evals_per_worker,
pure=False,
workers=[worker_addr]
)
futures.append(f)
results = await client.gather(futures)
@ -1755,7 +1760,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
@xgboost_model_doc(
"Implementation of the Scikit-Learn API for XGBoost Ranking.",
"""Implementation of the Scikit-Learn API for XGBoost Ranking.
.. versionadded:: 1.4.0
""",
["estimators", "model"],
end_note="""
Note
@ -1868,7 +1877,11 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
@xgboost_model_doc(
"Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor.",
"""Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor.
.. versionadded:: 1.4.0
""",
["model", "objective"],
extra_parameters="""
n_estimators : int
@ -1904,7 +1917,11 @@ class DaskXGBRFRegressor(DaskXGBRegressor):
@xgboost_model_doc(
"Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.",
"""Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.
.. versionadded:: 1.4.0
""",
["model", "objective"],
extra_parameters="""
n_estimators : int

View File

@ -304,9 +304,15 @@ class TestDistributedGPU:
futures = []
for i in range(len(workers)):
futures.append(client.submit(worker_fn, workers[i],
m.create_fn_args(workers[i]), pure=False,
workers=[workers[i]]))
futures.append(
client.submit(
worker_fn,
workers[i],
m._create_fn_args(workers[i]),
pure=False,
workers=[workers[i]]
)
)
client.gather(futures)
def test_interface_consistency(self) -> None:

View File

@ -1046,9 +1046,12 @@ class TestWithDask:
futures = []
for i in range(len(workers)):
futures.append(client.submit(worker_fn, workers[i],
m.create_fn_args(workers[i]), pure=False,
workers=[workers[i]]))
futures.append(
client.submit(
worker_fn, workers[i],
m._create_fn_args(workers[i]), pure=False,
workers=[workers[i]])
)
client.gather(futures)
has_what = client.has_what()