[dask] Improve documents. (#6687)
* Add tag for versions. * use autoclass in sphinx build. Made some class methods to be private to avoid exporting documents.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module
|
||||
# pylint: disable=missing-class-docstring, invalid-name
|
||||
# pylint: disable=too-many-lines, fixme
|
||||
# pylint: disable=too-few-public-methods
|
||||
# pylint: disable=import-error
|
||||
"""Dask extensions for distributed training. See
|
||||
https://xgboost.readthedocs.io/en/latest/tutorials/dask.html for simple
|
||||
@@ -259,7 +260,7 @@ class DaskDMatrix:
|
||||
self.is_quantile: bool = False
|
||||
|
||||
self._init = client.sync(
|
||||
self.map_local_data,
|
||||
self._map_local_data,
|
||||
client,
|
||||
data,
|
||||
label=label,
|
||||
@@ -274,7 +275,7 @@ class DaskDMatrix:
|
||||
def __await__(self) -> Generator:
|
||||
return self._init.__await__()
|
||||
|
||||
async def map_local_data(
|
||||
async def _map_local_data(
|
||||
self,
|
||||
client: "distributed.Client",
|
||||
data: _DaskCollection,
|
||||
@@ -393,7 +394,7 @@ class DaskDMatrix:
|
||||
|
||||
return self
|
||||
|
||||
def create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
|
||||
def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
|
||||
'''Create a dictionary of objects that can be pickled for function
|
||||
arguments.
|
||||
|
||||
@@ -627,8 +628,8 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
||||
self.max_bin = max_bin
|
||||
self.is_quantile = True
|
||||
|
||||
def create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
|
||||
args = super().create_fn_args(worker_addr)
|
||||
def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
|
||||
args = super()._create_fn_args(worker_addr)
|
||||
args["max_bin"] = self.max_bin
|
||||
return args
|
||||
|
||||
@@ -864,18 +865,22 @@ async def _train_async(
|
||||
futures = []
|
||||
for i, worker_addr in enumerate(workers):
|
||||
if evals:
|
||||
evals_per_worker = [(e.create_fn_args(worker_addr), name, id(e))
|
||||
# pylint: disable=protected-access
|
||||
evals_per_worker = [(e._create_fn_args(worker_addr), name, id(e))
|
||||
for e, name in evals]
|
||||
else:
|
||||
evals_per_worker = []
|
||||
f = client.submit(dispatched_train,
|
||||
worker_addr,
|
||||
_rabit_args,
|
||||
dtrain.create_fn_args(workers[i]),
|
||||
id(dtrain),
|
||||
evals_per_worker,
|
||||
pure=False,
|
||||
workers=[worker_addr])
|
||||
f = client.submit(
|
||||
dispatched_train,
|
||||
worker_addr,
|
||||
_rabit_args,
|
||||
# pylint: disable=protected-access
|
||||
dtrain._create_fn_args(workers[i]),
|
||||
id(dtrain),
|
||||
evals_per_worker,
|
||||
pure=False,
|
||||
workers=[worker_addr]
|
||||
)
|
||||
futures.append(f)
|
||||
|
||||
results = await client.gather(futures)
|
||||
@@ -1755,7 +1760,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
"Implementation of the Scikit-Learn API for XGBoost Ranking.",
|
||||
"""Implementation of the Scikit-Learn API for XGBoost Ranking.
|
||||
|
||||
.. versionadded:: 1.4.0
|
||||
|
||||
""",
|
||||
["estimators", "model"],
|
||||
end_note="""
|
||||
Note
|
||||
@@ -1868,7 +1877,11 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
"Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor.",
|
||||
"""Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor.
|
||||
|
||||
.. versionadded:: 1.4.0
|
||||
|
||||
""",
|
||||
["model", "objective"],
|
||||
extra_parameters="""
|
||||
n_estimators : int
|
||||
@@ -1904,7 +1917,11 @@ class DaskXGBRFRegressor(DaskXGBRegressor):
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
"Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.",
|
||||
"""Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.
|
||||
|
||||
.. versionadded:: 1.4.0
|
||||
|
||||
""",
|
||||
["model", "objective"],
|
||||
extra_parameters="""
|
||||
n_estimators : int
|
||||
|
||||
Reference in New Issue
Block a user