Cleanup code for distributed training. (#9805)

* Cleanup code for distributed training.

- Merge `GetNcclResult` into nccl stub.
- Split up utilities from the main dask module.
- Let Channel return `Result` to accommodate nccl channel.
- Remove old `use_label_encoder` parameter.
This commit is contained in:
Jiaming Yuan
2023-11-25 09:10:56 +08:00
committed by GitHub
parent e9260de3f3
commit 8fe1a2213c
19 changed files with 221 additions and 192 deletions

View File

@@ -94,6 +94,8 @@ from xgboost.sklearn import (
from xgboost.tracker import RabitTracker, get_host_ip
from xgboost.training import train as worker_train
from .utils import get_n_threads
if TYPE_CHECKING:
import dask
import distributed
@@ -908,6 +910,34 @@ async def _check_workers_are_alive(
raise RuntimeError(f"Missing required workers: {missing_workers}")
def _get_dmatrices(
    train_ref: dict,
    train_id: int,
    *refs: dict,
    evals_id: Sequence[int],
    evals_name: Sequence[str],
    n_threads: int,
) -> Tuple[DMatrix, List[Tuple[DMatrix, str]]]:
    """Reconstruct the training DMatrix and the list of evaluation DMatrices
    from their serialized part references.

    Parameters
    ----------
    train_ref :
        Keyword arguments for building the training DMatrix.
    train_id :
        Identifier of the training dataset, used to detect when an evaluation
        entry is the training data itself.
    refs :
        One part-reference dict per evaluation dataset.
    evals_id :
        Identifier for each evaluation dataset, parallel to ``refs``.
    evals_name :
        Display name for each evaluation dataset, parallel to ``refs``.
    n_threads :
        Number of threads used for DMatrix construction.
    """
    train_Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads)
    evals: List[Tuple[DMatrix, str]] = []
    for eval_id, name, ref in zip(evals_id, evals_name, refs):
        if eval_id == train_id:
            # The training data doubles as an evaluation set; reuse the
            # already-constructed matrix instead of building a second copy.
            evals.append((train_Xy, name))
            continue
        if ref.get("ref", None) is None:
            valid_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads)
        else:
            if ref["ref"] != train_id:
                raise ValueError(
                    "The training DMatrix should be used as a reference to evaluation"
                    " `QuantileDMatrix`."
                )
            # Drop the marker key so the remaining entries are valid ctor kwargs.
            del ref["ref"]
            valid_Xy = _dmatrix_from_list_of_parts(
                **ref, nthread=n_threads, ref=train_Xy
            )
        evals.append((valid_Xy, name))
    return train_Xy, evals
async def _train_async(
client: "distributed.Client",
global_config: Dict[str, Any],
@@ -940,41 +970,20 @@ async def _train_async(
) -> Optional[TrainReturnT]:
worker = distributed.get_worker()
local_param = parameters.copy()
n_threads = 0
# dask worker nthreads, "state" is available in 2022.6.1
dwnt = worker.state.nthreads if hasattr(worker, "state") else worker.nthreads
for p in ["nthread", "n_jobs"]:
if (
local_param.get(p, None) is not None
and local_param.get(p, dwnt) != dwnt
):
LOGGER.info("Overriding `nthreads` defined in dask worker.")
n_threads = local_param[p]
break
if n_threads == 0 or n_threads is None:
n_threads = dwnt
n_threads = get_n_threads(local_param, worker)
local_param.update({"nthread": n_threads, "n_jobs": n_threads})
local_history: TrainingCallback.EvalsLog = {}
with CommunicatorContext(**rabit_args), config.config_context(**global_config):
Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads)
evals: List[Tuple[DMatrix, str]] = []
for i, ref in enumerate(refs):
if evals_id[i] == train_id:
evals.append((Xy, evals_name[i]))
continue
if ref.get("ref", None) is not None:
if ref["ref"] != train_id:
raise ValueError(
"The training DMatrix should be used as a reference"
" to evaluation `QuantileDMatrix`."
)
del ref["ref"]
eval_Xy = _dmatrix_from_list_of_parts(
**ref, nthread=n_threads, ref=Xy
)
else:
eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads)
evals.append((eval_Xy, evals_name[i]))
Xy, evals = _get_dmatrices(
train_ref,
train_id,
*refs,
evals_id=evals_id,
evals_name=evals_name,
n_threads=n_threads,
)
booster = worker_train(
params=local_param,

View File

@@ -0,0 +1,24 @@
"""Utilities for the XGBoost Dask interface."""
import logging
from typing import TYPE_CHECKING, Any, Dict
LOGGER = logging.getLogger("[xgboost.dask]")
if TYPE_CHECKING:
import distributed
def get_n_threads(local_param: Dict[str, Any], worker: "distributed.Worker") -> int:
"""Get the number of threads from a worker and the user-supplied parameters."""
# dask worker nthreads, "state" is available in 2022.6.1
dwnt = worker.state.nthreads if hasattr(worker, "state") else worker.nthreads
n_threads = None
for p in ["nthread", "n_jobs"]:
if local_param.get(p, None) is not None and local_param.get(p, dwnt) != dwnt:
LOGGER.info("Overriding `nthreads` defined in dask worker.")
n_threads = local_param[p]
break
if n_threads == 0 or n_threads is None:
n_threads = dwnt
return n_threads

View File

@@ -808,7 +808,6 @@ class XGBModel(XGBModelBase):
"kwargs",
"missing",
"n_estimators",
"use_label_encoder",
"enable_categorical",
"early_stopping_rounds",
"callbacks",

View File

@@ -138,7 +138,6 @@ _inverse_pyspark_param_alias_map = {v: k for k, v in _pyspark_param_alias_map.it
# XGBoost constructor parameters that the PySpark interface does not accept.
_unsupported_xgb_params = [
    "gpu_id", # we have "device" pyspark param instead.
    "enable_categorical", # Use feature_types param to specify categorical feature instead
    "use_label_encoder", # Old sklearn-interface parameter; not applicable to pyspark.
    "n_jobs", # Do not allow user to set it, will use `spark.task.cpus` value instead.
    "nthread", # Ditto
]