[dask] Improve documentation. (#6687)

* Add tag for versions.
* Use autoclass in the Sphinx build.
Made some class methods private so that they are not exported to the generated documentation.
This commit is contained in:
Jiaming Yuan 2021-02-09 09:20:58 +08:00 committed by GitHub
parent 5d48d40d9a
commit 1335db6113
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 77 additions and 30 deletions

View File

@ -93,9 +93,15 @@ Dask API
-------- --------
.. automodule:: xgboost.dask .. automodule:: xgboost.dask
.. autofunction:: xgboost.dask.DaskDMatrix .. autoclass:: xgboost.dask.DaskDMatrix
:members:
:inherited-members:
:show-inheritance:
.. autofunction:: xgboost.dask.DaskDeviceQuantileDMatrix .. autoclass:: xgboost.dask.DaskDeviceQuantileDMatrix
:members:
:inherited-members:
:show-inheritance:
.. autofunction:: xgboost.dask.train .. autofunction:: xgboost.dask.train
@ -103,12 +109,27 @@ Dask API
.. autofunction:: xgboost.dask.inplace_predict .. autofunction:: xgboost.dask.inplace_predict
.. autofunction:: xgboost.dask.DaskXGBClassifier .. autoclass:: xgboost.dask.DaskXGBClassifier
:members:
:inherited-members:
:show-inheritance:
.. autofunction:: xgboost.dask.DaskXGBRegressor .. autoclass:: xgboost.dask.DaskXGBRegressor
:members:
:inherited-members:
:show-inheritance:
.. autofunction:: xgboost.dask.DaskXGBRanker .. autoclass:: xgboost.dask.DaskXGBRanker
:members:
:inherited-members:
:show-inheritance:
.. autofunction:: xgboost.dask.DaskXGBRFRegressor .. autoclass:: xgboost.dask.DaskXGBRFRegressor
:members:
:inherited-members:
:show-inheritance:
.. autofunction:: xgboost.dask.DaskXGBRFClassifier .. autoclass:: xgboost.dask.DaskXGBRFClassifier
:members:
:inherited-members:
:show-inheritance:

View File

@ -1,6 +1,7 @@
# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module # pylint: disable=too-many-arguments, too-many-locals, no-name-in-module
# pylint: disable=missing-class-docstring, invalid-name # pylint: disable=missing-class-docstring, invalid-name
# pylint: disable=too-many-lines, fixme # pylint: disable=too-many-lines, fixme
# pylint: disable=too-few-public-methods
# pylint: disable=import-error # pylint: disable=import-error
"""Dask extensions for distributed training. See """Dask extensions for distributed training. See
https://xgboost.readthedocs.io/en/latest/tutorials/dask.html for simple https://xgboost.readthedocs.io/en/latest/tutorials/dask.html for simple
@ -259,7 +260,7 @@ class DaskDMatrix:
self.is_quantile: bool = False self.is_quantile: bool = False
self._init = client.sync( self._init = client.sync(
self.map_local_data, self._map_local_data,
client, client,
data, data,
label=label, label=label,
@ -274,7 +275,7 @@ class DaskDMatrix:
def __await__(self) -> Generator: def __await__(self) -> Generator:
return self._init.__await__() return self._init.__await__()
async def map_local_data( async def _map_local_data(
self, self,
client: "distributed.Client", client: "distributed.Client",
data: _DaskCollection, data: _DaskCollection,
@ -393,7 +394,7 @@ class DaskDMatrix:
return self return self
def create_fn_args(self, worker_addr: str) -> Dict[str, Any]: def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
'''Create a dictionary of objects that can be pickled for function '''Create a dictionary of objects that can be pickled for function
arguments. arguments.
@ -627,8 +628,8 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
self.max_bin = max_bin self.max_bin = max_bin
self.is_quantile = True self.is_quantile = True
def create_fn_args(self, worker_addr: str) -> Dict[str, Any]: def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
args = super().create_fn_args(worker_addr) args = super()._create_fn_args(worker_addr)
args["max_bin"] = self.max_bin args["max_bin"] = self.max_bin
return args return args
@ -864,18 +865,22 @@ async def _train_async(
futures = [] futures = []
for i, worker_addr in enumerate(workers): for i, worker_addr in enumerate(workers):
if evals: if evals:
evals_per_worker = [(e.create_fn_args(worker_addr), name, id(e)) # pylint: disable=protected-access
evals_per_worker = [(e._create_fn_args(worker_addr), name, id(e))
for e, name in evals] for e, name in evals]
else: else:
evals_per_worker = [] evals_per_worker = []
f = client.submit(dispatched_train, f = client.submit(
dispatched_train,
worker_addr, worker_addr,
_rabit_args, _rabit_args,
dtrain.create_fn_args(workers[i]), # pylint: disable=protected-access
dtrain._create_fn_args(workers[i]),
id(dtrain), id(dtrain),
evals_per_worker, evals_per_worker,
pure=False, pure=False,
workers=[worker_addr]) workers=[worker_addr]
)
futures.append(f) futures.append(f)
results = await client.gather(futures) results = await client.gather(futures)
@ -1755,7 +1760,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
@xgboost_model_doc( @xgboost_model_doc(
"Implementation of the Scikit-Learn API for XGBoost Ranking.", """Implementation of the Scikit-Learn API for XGBoost Ranking.
.. versionadded:: 1.4.0
""",
["estimators", "model"], ["estimators", "model"],
end_note=""" end_note="""
Note Note
@ -1868,7 +1877,11 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
@xgboost_model_doc( @xgboost_model_doc(
"Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor.", """Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor.
.. versionadded:: 1.4.0
""",
["model", "objective"], ["model", "objective"],
extra_parameters=""" extra_parameters="""
n_estimators : int n_estimators : int
@ -1904,7 +1917,11 @@ class DaskXGBRFRegressor(DaskXGBRegressor):
@xgboost_model_doc( @xgboost_model_doc(
"Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.", """Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.
.. versionadded:: 1.4.0
""",
["model", "objective"], ["model", "objective"],
extra_parameters=""" extra_parameters="""
n_estimators : int n_estimators : int

View File

@ -304,9 +304,15 @@ class TestDistributedGPU:
futures = [] futures = []
for i in range(len(workers)): for i in range(len(workers)):
futures.append(client.submit(worker_fn, workers[i], futures.append(
m.create_fn_args(workers[i]), pure=False, client.submit(
workers=[workers[i]])) worker_fn,
workers[i],
m._create_fn_args(workers[i]),
pure=False,
workers=[workers[i]]
)
)
client.gather(futures) client.gather(futures)
def test_interface_consistency(self) -> None: def test_interface_consistency(self) -> None:

View File

@ -1046,9 +1046,12 @@ class TestWithDask:
futures = [] futures = []
for i in range(len(workers)): for i in range(len(workers)):
futures.append(client.submit(worker_fn, workers[i], futures.append(
m.create_fn_args(workers[i]), pure=False, client.submit(
workers=[workers[i]])) worker_fn, workers[i],
m._create_fn_args(workers[i]), pure=False,
workers=[workers[i]])
)
client.gather(futures) client.gather(futures)
has_what = client.has_what() has_what = client.has_what()