[dask] Improve documents. (#6687)
* Add tag for versions. * use autoclass in sphinx build. Made some class methods to be private to avoid exporting documents.
This commit is contained in:
parent
5d48d40d9a
commit
1335db6113
@ -93,9 +93,15 @@ Dask API
|
|||||||
--------
|
--------
|
||||||
.. automodule:: xgboost.dask
|
.. automodule:: xgboost.dask
|
||||||
|
|
||||||
.. autofunction:: xgboost.dask.DaskDMatrix
|
.. autoclass:: xgboost.dask.DaskDMatrix
|
||||||
|
:members:
|
||||||
|
:inherited-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
.. autofunction:: xgboost.dask.DaskDeviceQuantileDMatrix
|
.. autoclass:: xgboost.dask.DaskDeviceQuantileDMatrix
|
||||||
|
:members:
|
||||||
|
:inherited-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
.. autofunction:: xgboost.dask.train
|
.. autofunction:: xgboost.dask.train
|
||||||
|
|
||||||
@ -103,12 +109,27 @@ Dask API
|
|||||||
|
|
||||||
.. autofunction:: xgboost.dask.inplace_predict
|
.. autofunction:: xgboost.dask.inplace_predict
|
||||||
|
|
||||||
.. autofunction:: xgboost.dask.DaskXGBClassifier
|
.. autoclass:: xgboost.dask.DaskXGBClassifier
|
||||||
|
:members:
|
||||||
|
:inherited-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
.. autofunction:: xgboost.dask.DaskXGBRegressor
|
.. autoclass:: xgboost.dask.DaskXGBRegressor
|
||||||
|
:members:
|
||||||
|
:inherited-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
.. autofunction:: xgboost.dask.DaskXGBRanker
|
.. autoclass:: xgboost.dask.DaskXGBRanker
|
||||||
|
:members:
|
||||||
|
:inherited-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
.. autofunction:: xgboost.dask.DaskXGBRFRegressor
|
.. autoclass:: xgboost.dask.DaskXGBRFRegressor
|
||||||
|
:members:
|
||||||
|
:inherited-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
.. autofunction:: xgboost.dask.DaskXGBRFClassifier
|
.. autoclass:: xgboost.dask.DaskXGBRFClassifier
|
||||||
|
:members:
|
||||||
|
:inherited-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module
|
# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module
|
||||||
# pylint: disable=missing-class-docstring, invalid-name
|
# pylint: disable=missing-class-docstring, invalid-name
|
||||||
# pylint: disable=too-many-lines, fixme
|
# pylint: disable=too-many-lines, fixme
|
||||||
|
# pylint: disable=too-few-public-methods
|
||||||
# pylint: disable=import-error
|
# pylint: disable=import-error
|
||||||
"""Dask extensions for distributed training. See
|
"""Dask extensions for distributed training. See
|
||||||
https://xgboost.readthedocs.io/en/latest/tutorials/dask.html for simple
|
https://xgboost.readthedocs.io/en/latest/tutorials/dask.html for simple
|
||||||
@ -259,7 +260,7 @@ class DaskDMatrix:
|
|||||||
self.is_quantile: bool = False
|
self.is_quantile: bool = False
|
||||||
|
|
||||||
self._init = client.sync(
|
self._init = client.sync(
|
||||||
self.map_local_data,
|
self._map_local_data,
|
||||||
client,
|
client,
|
||||||
data,
|
data,
|
||||||
label=label,
|
label=label,
|
||||||
@ -274,7 +275,7 @@ class DaskDMatrix:
|
|||||||
def __await__(self) -> Generator:
|
def __await__(self) -> Generator:
|
||||||
return self._init.__await__()
|
return self._init.__await__()
|
||||||
|
|
||||||
async def map_local_data(
|
async def _map_local_data(
|
||||||
self,
|
self,
|
||||||
client: "distributed.Client",
|
client: "distributed.Client",
|
||||||
data: _DaskCollection,
|
data: _DaskCollection,
|
||||||
@ -393,7 +394,7 @@ class DaskDMatrix:
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
|
def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
|
||||||
'''Create a dictionary of objects that can be pickled for function
|
'''Create a dictionary of objects that can be pickled for function
|
||||||
arguments.
|
arguments.
|
||||||
|
|
||||||
@ -627,8 +628,8 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
|||||||
self.max_bin = max_bin
|
self.max_bin = max_bin
|
||||||
self.is_quantile = True
|
self.is_quantile = True
|
||||||
|
|
||||||
def create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
|
def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
|
||||||
args = super().create_fn_args(worker_addr)
|
args = super()._create_fn_args(worker_addr)
|
||||||
args["max_bin"] = self.max_bin
|
args["max_bin"] = self.max_bin
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@ -864,18 +865,22 @@ async def _train_async(
|
|||||||
futures = []
|
futures = []
|
||||||
for i, worker_addr in enumerate(workers):
|
for i, worker_addr in enumerate(workers):
|
||||||
if evals:
|
if evals:
|
||||||
evals_per_worker = [(e.create_fn_args(worker_addr), name, id(e))
|
# pylint: disable=protected-access
|
||||||
|
evals_per_worker = [(e._create_fn_args(worker_addr), name, id(e))
|
||||||
for e, name in evals]
|
for e, name in evals]
|
||||||
else:
|
else:
|
||||||
evals_per_worker = []
|
evals_per_worker = []
|
||||||
f = client.submit(dispatched_train,
|
f = client.submit(
|
||||||
|
dispatched_train,
|
||||||
worker_addr,
|
worker_addr,
|
||||||
_rabit_args,
|
_rabit_args,
|
||||||
dtrain.create_fn_args(workers[i]),
|
# pylint: disable=protected-access
|
||||||
|
dtrain._create_fn_args(workers[i]),
|
||||||
id(dtrain),
|
id(dtrain),
|
||||||
evals_per_worker,
|
evals_per_worker,
|
||||||
pure=False,
|
pure=False,
|
||||||
workers=[worker_addr])
|
workers=[worker_addr]
|
||||||
|
)
|
||||||
futures.append(f)
|
futures.append(f)
|
||||||
|
|
||||||
results = await client.gather(futures)
|
results = await client.gather(futures)
|
||||||
@ -1755,7 +1760,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
|||||||
|
|
||||||
|
|
||||||
@xgboost_model_doc(
|
@xgboost_model_doc(
|
||||||
"Implementation of the Scikit-Learn API for XGBoost Ranking.",
|
"""Implementation of the Scikit-Learn API for XGBoost Ranking.
|
||||||
|
|
||||||
|
.. versionadded:: 1.4.0
|
||||||
|
|
||||||
|
""",
|
||||||
["estimators", "model"],
|
["estimators", "model"],
|
||||||
end_note="""
|
end_note="""
|
||||||
Note
|
Note
|
||||||
@ -1868,7 +1877,11 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
|||||||
|
|
||||||
|
|
||||||
@xgboost_model_doc(
|
@xgboost_model_doc(
|
||||||
"Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor.",
|
"""Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor.
|
||||||
|
|
||||||
|
.. versionadded:: 1.4.0
|
||||||
|
|
||||||
|
""",
|
||||||
["model", "objective"],
|
["model", "objective"],
|
||||||
extra_parameters="""
|
extra_parameters="""
|
||||||
n_estimators : int
|
n_estimators : int
|
||||||
@ -1904,7 +1917,11 @@ class DaskXGBRFRegressor(DaskXGBRegressor):
|
|||||||
|
|
||||||
|
|
||||||
@xgboost_model_doc(
|
@xgboost_model_doc(
|
||||||
"Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.",
|
"""Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.
|
||||||
|
|
||||||
|
.. versionadded:: 1.4.0
|
||||||
|
|
||||||
|
""",
|
||||||
["model", "objective"],
|
["model", "objective"],
|
||||||
extra_parameters="""
|
extra_parameters="""
|
||||||
n_estimators : int
|
n_estimators : int
|
||||||
|
|||||||
@ -304,9 +304,15 @@ class TestDistributedGPU:
|
|||||||
|
|
||||||
futures = []
|
futures = []
|
||||||
for i in range(len(workers)):
|
for i in range(len(workers)):
|
||||||
futures.append(client.submit(worker_fn, workers[i],
|
futures.append(
|
||||||
m.create_fn_args(workers[i]), pure=False,
|
client.submit(
|
||||||
workers=[workers[i]]))
|
worker_fn,
|
||||||
|
workers[i],
|
||||||
|
m._create_fn_args(workers[i]),
|
||||||
|
pure=False,
|
||||||
|
workers=[workers[i]]
|
||||||
|
)
|
||||||
|
)
|
||||||
client.gather(futures)
|
client.gather(futures)
|
||||||
|
|
||||||
def test_interface_consistency(self) -> None:
|
def test_interface_consistency(self) -> None:
|
||||||
|
|||||||
@ -1046,9 +1046,12 @@ class TestWithDask:
|
|||||||
|
|
||||||
futures = []
|
futures = []
|
||||||
for i in range(len(workers)):
|
for i in range(len(workers)):
|
||||||
futures.append(client.submit(worker_fn, workers[i],
|
futures.append(
|
||||||
m.create_fn_args(workers[i]), pure=False,
|
client.submit(
|
||||||
workers=[workers[i]]))
|
worker_fn, workers[i],
|
||||||
|
m._create_fn_args(workers[i]), pure=False,
|
||||||
|
workers=[workers[i]])
|
||||||
|
)
|
||||||
client.gather(futures)
|
client.gather(futures)
|
||||||
|
|
||||||
has_what = client.has_what()
|
has_what = client.has_what()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user