[dask] Use nthread in DMatrix construction. (#7337)

This makes DMatrix construction consistent with the thread-overriding behavior already applied during training: an explicit `nthread`/`n_jobs` in the parameters takes precedence over the dask worker's thread count.
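For reference, the overriding behavior resolves roughly as follows. This is a minimal standalone sketch distilled from the `_train_async` hunk below, not part of the diff; `resolve_n_threads`, `params`, and `worker_nthreads` are illustrative names, with `worker_nthreads` standing in for `distributed.get_worker().nthreads`:

def resolve_n_threads(params: dict, worker_nthreads: int) -> int:
    # A user-supplied `nthread` or `n_jobs` that differs from the dask
    # worker's thread count takes precedence over the worker default.
    n_threads = 0
    for p in ["nthread", "n_jobs"]:
        if params.get(p, worker_nthreads) != worker_nthreads:
            n_threads = params[p]
            break
    if n_threads == 0:
        n_threads = worker_nthreads
    # With this commit, the resolved count is also passed to DMatrix
    # construction instead of unconditionally using worker.nthreads.
    return n_threads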
Jiaming Yuan 2021-10-20 15:16:40 +08:00 committed by GitHub
parent b8e8f0fcd9
commit f999897615
3 changed files with 44 additions and 33 deletions

python-package/xgboost/dask.py

@@ -674,6 +674,7 @@ def _create_device_quantile_dmatrix(
     feature_weights: Optional[Any],
     meta_names: List[str],
     missing: float,
+    nthread: int,
     parts: Optional[_DataParts],
     max_bin: int,
     enable_categorical: bool,
@@ -717,7 +718,7 @@ def _create_device_quantile_dmatrix(
         missing=missing,
         feature_names=feature_names,
         feature_types=feature_types,
-        nthread=worker.nthreads,
+        nthread=nthread,
         max_bin=max_bin,
         enable_categorical=enable_categorical,
     )
@@ -731,6 +732,7 @@ def _create_dmatrix(
     feature_weights: Optional[Any],
     meta_names: List[str],
     missing: float,
+    nthread: int,
     enable_categorical: bool,
     parts: Optional[_DataParts]
 ) -> DMatrix:
@@ -778,7 +780,7 @@ def _create_dmatrix(
         missing=missing,
         feature_names=feature_names,
         feature_types=feature_types,
-        nthread=worker.nthreads,
+        nthread=nthread,
         enable_categorical=enable_categorical,
     )
     dmatrix.set_info(
@@ -856,46 +858,53 @@ async def _train_async(
         rabit_args: List[bytes],
         dtrain_ref: Dict,
         dtrain_idt: int,
-        evals_ref: Dict
+        evals_ref: List[Tuple[Dict, str, int]],
     ) -> Optional[Dict[str, Union[Booster, Dict]]]:
-        '''Perform training on a single worker. A local function prevents pickling.
-        '''
-        LOGGER.debug('Training on %s', str(worker_addr))
+        """Perform training on a single worker. A local function prevents pickling."""
+        LOGGER.debug("Training on %s", str(worker_addr))
         worker = distributed.get_worker()
+        n_threads: int = 0
+        local_param = params.copy()
+        for p in ["nthread", "n_jobs"]:
+            if local_param.get(p, worker.nthreads) != worker.nthreads:
+                LOGGER.info("Overriding `nthreads` defined in dask worker.")
+                n_threads = local_param[p]
+                break
+        if n_threads == 0:
+            n_threads = worker.nthreads
+        local_param.update({"nthread": n_threads, "n_jobs": n_threads})
         with RabitContext(rabit_args), config.config_context(**global_config):
-            local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref)
+            local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref, nthread=n_threads)
             local_evals = []
             if evals_ref:
                 for ref, name, idt in evals_ref:
                     if idt == dtrain_idt:
                         local_evals.append((local_dtrain, name))
                         continue
-                    local_evals.append((_dmatrix_from_list_of_parts(**ref), name))
+                    local_evals.append(
+                        (_dmatrix_from_list_of_parts(**ref, nthread=n_threads), name)
+                    )
             local_history: Dict = {}
-            local_param = params.copy()  # just to be consistent
-            msg = 'Overriding `nthreads` defined in dask worker.'
-            override = ['nthread', 'n_jobs']
-            for p in override:
-                val = local_param.get(p, None)
-                if val is not None and val != worker.nthreads:
-                    LOGGER.info(msg)
-                else:
-                    local_param[p] = worker.nthreads
-            bst = worker_train(params=local_param,
-                               dtrain=local_dtrain,
-                               num_boost_round=num_boost_round,
-                               evals_result=local_history,
-                               evals=local_evals,
-                               obj=obj,
-                               feval=feval,
-                               early_stopping_rounds=early_stopping_rounds,
-                               verbose_eval=verbose_eval,
-                               xgb_model=xgb_model,
-                               callbacks=callbacks)
+            bst = worker_train(
+                params=local_param,
+                dtrain=local_dtrain,
+                num_boost_round=num_boost_round,
+                evals_result=local_history,
+                evals=local_evals,
+                obj=obj,
+                feval=feval,
+                early_stopping_rounds=early_stopping_rounds,
+                verbose_eval=verbose_eval,
+                xgb_model=xgb_model,
+                callbacks=callbacks,
+            )
             ret: Optional[Dict[str, Union[Booster, Dict]]] = {
-                'booster': bst, 'history': local_history}
+                "booster": bst,
+                "history": local_history,
+            }
             if local_dtrain.num_row() == 0:
                 ret = None
             return ret
@@ -924,7 +933,7 @@ async def _train_async(
             evals_per_worker,
             pure=False,
             workers=[worker_addr],
-            allow_other_workers=False
+            allow_other_workers=False,
         )
         futures.append(f)
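Taken together, the resolved thread count now reaches both Booster training and per-worker DMatrix construction. A hedged end-to-end sketch of the user-facing effect; `client`, `X`, and `y` are assumed to be an existing dask client and dask collections, not part of this diff:

from xgboost import dask as dxgb

# With this change, "nthread": 4 is forwarded to DMatrix construction on
# each worker instead of unconditionally using worker.nthreads.
dtrain = dxgb.DaskDMatrix(client, X, y)
output = dxgb.train(client, {"nthread": 4}, dtrain, num_boost_round=10)
booster = output["booster"]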

tests/python-gpu/test_gpu_with_dask.py

@@ -434,7 +434,7 @@ class TestDistributedGPU:
         def worker_fn(worker_addr: str, data_ref: Dict) -> None:
             with dxgb.RabitContext(rabit_args):
-                local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref)
+                local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref, nthread=7)
                 fw_rows = local_dtrain.get_float_info("feature_weights").shape[0]
                 assert fw_rows == local_dtrain.num_col()

tests/python/test_with_dask.py

@@ -1275,7 +1275,9 @@ class TestWithDask:
         def worker_fn(worker_addr: str, data_ref: Dict) -> None:
             with xgb.dask.RabitContext(rabit_args):
-                local_dtrain = xgb.dask._dmatrix_from_list_of_parts(**data_ref)
+                local_dtrain = xgb.dask._dmatrix_from_list_of_parts(
+                    **data_ref, nthread=7
+                )
                 total = np.array([local_dtrain.num_row()])
                 total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
                 assert total[0] == kRows