Cleanup code for distributed training. (#9805)
* Cleanup code for distributed training.
  - Merge `GetNcclResult` into the nccl stub.
  - Split utilities out of the main dask module.
  - Let `Channel` return `Result` to accommodate the nccl channel.
  - Remove the old `use_label_encoder` parameter.
This commit is contained in:
@@ -94,6 +94,8 @@ from xgboost.sklearn import (
|
||||
from xgboost.tracker import RabitTracker, get_host_ip
|
||||
from xgboost.training import train as worker_train
|
||||
|
||||
from .utils import get_n_threads
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import dask
|
||||
import distributed
|
||||
@@ -908,6 +910,34 @@ async def _check_workers_are_alive(
|
||||
raise RuntimeError(f"Missing required workers: {missing_workers}")
|
||||
|
||||
|
||||
def _get_dmatrices(
    train_ref: dict,
    train_id: int,
    *refs: dict,
    evals_id: Sequence[int],
    evals_name: Sequence[str],
    n_threads: int,
) -> Tuple[DMatrix, List[Tuple[DMatrix, str]]]:
    """Reconstruct the training DMatrix and the evaluation matrices from parts.

    ``train_ref`` holds the partitioned training data, while each entry of
    ``refs`` (paired index-wise with ``evals_id`` and ``evals_name``) describes
    one evaluation dataset.  An evaluation entry whose id equals ``train_id``
    reuses the training matrix instead of constructing a new one.
    """
    Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads)

    evals: List[Tuple[DMatrix, str]] = []
    for idx, eval_ref in enumerate(refs):
        name = evals_name[idx]
        # The evaluation set is the training set itself: share the matrix.
        if evals_id[idx] == train_id:
            evals.append((Xy, name))
            continue
        if eval_ref.get("ref", None) is None:
            matrix = _dmatrix_from_list_of_parts(**eval_ref, nthread=n_threads)
        else:
            # A QuantileDMatrix may only use the training data as reference.
            if eval_ref["ref"] != train_id:
                raise ValueError(
                    "The training DMatrix should be used as a reference to evaluation"
                    " `QuantileDMatrix`."
                )
            del eval_ref["ref"]
            matrix = _dmatrix_from_list_of_parts(**eval_ref, nthread=n_threads, ref=Xy)
        evals.append((matrix, name))
    return Xy, evals
|
||||
|
||||
|
||||
async def _train_async(
|
||||
client: "distributed.Client",
|
||||
global_config: Dict[str, Any],
|
||||
@@ -940,41 +970,20 @@ async def _train_async(
|
||||
) -> Optional[TrainReturnT]:
|
||||
worker = distributed.get_worker()
|
||||
local_param = parameters.copy()
|
||||
n_threads = 0
|
||||
# dask worker nthreads, "state" is available in 2022.6.1
|
||||
dwnt = worker.state.nthreads if hasattr(worker, "state") else worker.nthreads
|
||||
for p in ["nthread", "n_jobs"]:
|
||||
if (
|
||||
local_param.get(p, None) is not None
|
||||
and local_param.get(p, dwnt) != dwnt
|
||||
):
|
||||
LOGGER.info("Overriding `nthreads` defined in dask worker.")
|
||||
n_threads = local_param[p]
|
||||
break
|
||||
if n_threads == 0 or n_threads is None:
|
||||
n_threads = dwnt
|
||||
n_threads = get_n_threads(local_param, worker)
|
||||
local_param.update({"nthread": n_threads, "n_jobs": n_threads})
|
||||
|
||||
local_history: TrainingCallback.EvalsLog = {}
|
||||
|
||||
with CommunicatorContext(**rabit_args), config.config_context(**global_config):
|
||||
Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads)
|
||||
evals: List[Tuple[DMatrix, str]] = []
|
||||
for i, ref in enumerate(refs):
|
||||
if evals_id[i] == train_id:
|
||||
evals.append((Xy, evals_name[i]))
|
||||
continue
|
||||
if ref.get("ref", None) is not None:
|
||||
if ref["ref"] != train_id:
|
||||
raise ValueError(
|
||||
"The training DMatrix should be used as a reference"
|
||||
" to evaluation `QuantileDMatrix`."
|
||||
)
|
||||
del ref["ref"]
|
||||
eval_Xy = _dmatrix_from_list_of_parts(
|
||||
**ref, nthread=n_threads, ref=Xy
|
||||
)
|
||||
else:
|
||||
eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads)
|
||||
evals.append((eval_Xy, evals_name[i]))
|
||||
Xy, evals = _get_dmatrices(
|
||||
train_ref,
|
||||
train_id,
|
||||
*refs,
|
||||
evals_id=evals_id,
|
||||
evals_name=evals_name,
|
||||
n_threads=n_threads,
|
||||
)
|
||||
|
||||
booster = worker_train(
|
||||
params=local_param,
|
||||
|
||||
24
python-package/xgboost/dask/utils.py
Normal file
24
python-package/xgboost/dask/utils.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Utilities for the XGBoost Dask interface."""
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
LOGGER = logging.getLogger("[xgboost.dask]")
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import distributed
|
||||
|
||||
|
||||
def get_n_threads(local_param: Dict[str, Any], worker: "distributed.Worker") -> int:
    """Get the number of threads from a worker and the user-supplied parameters.

    A user-supplied ``nthread``/``n_jobs`` takes precedence over the dask
    worker's configured thread count, except when it is absent, ``None``,
    ``0``, or identical to the worker's value.
    """
    # Thread count configured on the dask worker; `worker.state` is available
    # from distributed 2022.6.1 onward.
    if hasattr(worker, "state"):
        worker_nthreads = worker.state.nthreads
    else:
        worker_nthreads = worker.nthreads

    chosen = None
    for alias in ("nthread", "n_jobs"):
        user_value = local_param.get(alias, None)
        if user_value is not None and user_value != worker_nthreads:
            LOGGER.info("Overriding `nthreads` defined in dask worker.")
            chosen = user_value
            break
    # Fall back to the worker's own thread count when nothing (or 0) was given.
    if chosen is None or chosen == 0:
        return worker_nthreads
    return chosen
|
||||
@@ -808,7 +808,6 @@ class XGBModel(XGBModelBase):
|
||||
"kwargs",
|
||||
"missing",
|
||||
"n_estimators",
|
||||
"use_label_encoder",
|
||||
"enable_categorical",
|
||||
"early_stopping_rounds",
|
||||
"callbacks",
|
||||
|
||||
@@ -138,7 +138,6 @@ _inverse_pyspark_param_alias_map = {v: k for k, v in _pyspark_param_alias_map.it
|
||||
_unsupported_xgb_params = [
|
||||
"gpu_id", # we have "device" pyspark param instead.
|
||||
"enable_categorical", # Use feature_types param to specify categorical feature instead
|
||||
"use_label_encoder",
|
||||
"n_jobs", # Do not allow user to set it, will use `spark.task.cpus` value instead.
|
||||
"nthread", # Ditto
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user