[dask] Improve documentation. (#6687)

* Add tag for versions.
* Use autoclass in the Sphinx build.
Made some class methods private to avoid exporting their documentation.
This commit is contained in:
Jiaming Yuan
2021-02-09 09:20:58 +08:00
committed by GitHub
parent 5d48d40d9a
commit 1335db6113
4 changed files with 77 additions and 30 deletions

View File

@@ -1,6 +1,7 @@
# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module
# pylint: disable=missing-class-docstring, invalid-name
# pylint: disable=too-many-lines, fixme
# pylint: disable=too-few-public-methods
# pylint: disable=import-error
"""Dask extensions for distributed training. See
https://xgboost.readthedocs.io/en/latest/tutorials/dask.html for simple
@@ -259,7 +260,7 @@ class DaskDMatrix:
self.is_quantile: bool = False
self._init = client.sync(
self.map_local_data,
self._map_local_data,
client,
data,
label=label,
@@ -274,7 +275,7 @@ class DaskDMatrix:
def __await__(self) -> Generator:
return self._init.__await__()
async def map_local_data(
async def _map_local_data(
self,
client: "distributed.Client",
data: _DaskCollection,
@@ -393,7 +394,7 @@ class DaskDMatrix:
return self
def create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
'''Create a dictionary of objects that can be pickled for function
arguments.
@@ -627,8 +628,8 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
self.max_bin = max_bin
self.is_quantile = True
def create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
args = super().create_fn_args(worker_addr)
def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
args = super()._create_fn_args(worker_addr)
args["max_bin"] = self.max_bin
return args
@@ -864,18 +865,22 @@ async def _train_async(
futures = []
for i, worker_addr in enumerate(workers):
if evals:
evals_per_worker = [(e.create_fn_args(worker_addr), name, id(e))
# pylint: disable=protected-access
evals_per_worker = [(e._create_fn_args(worker_addr), name, id(e))
for e, name in evals]
else:
evals_per_worker = []
f = client.submit(dispatched_train,
worker_addr,
_rabit_args,
dtrain.create_fn_args(workers[i]),
id(dtrain),
evals_per_worker,
pure=False,
workers=[worker_addr])
f = client.submit(
dispatched_train,
worker_addr,
_rabit_args,
# pylint: disable=protected-access
dtrain._create_fn_args(workers[i]),
id(dtrain),
evals_per_worker,
pure=False,
workers=[worker_addr]
)
futures.append(f)
results = await client.gather(futures)
@@ -1755,7 +1760,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
@xgboost_model_doc(
"Implementation of the Scikit-Learn API for XGBoost Ranking.",
"""Implementation of the Scikit-Learn API for XGBoost Ranking.
.. versionadded:: 1.4.0
""",
["estimators", "model"],
end_note="""
Note
@@ -1868,7 +1877,11 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
@xgboost_model_doc(
"Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor.",
"""Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor.
.. versionadded:: 1.4.0
""",
["model", "objective"],
extra_parameters="""
n_estimators : int
@@ -1904,7 +1917,11 @@ class DaskXGBRFRegressor(DaskXGBRegressor):
@xgboost_model_doc(
"Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.",
"""Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.
.. versionadded:: 1.4.0
""",
["model", "objective"],
extra_parameters="""
n_estimators : int