[dask] Use nthread in DMatrix construction. (#7337)

This makes per-worker DMatrix construction consistent with the thread-overriding behavior of the training parameters: an explicit `nthread`/`n_jobs` in `params` takes precedence over the Dask worker's thread count, and the resolved value is now also forwarded when the per-worker DMatrix is built.
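For reference, the resolution order that `dispatched_train` applies after this change (and whose result is forwarded to DMatrix construction) can be summarized by the sketch below. `resolve_nthread` is a hypothetical standalone helper written for illustration only, not a function in the code base; it mirrors the loop added in this commit.

from typing import Any, Dict, Tuple


def resolve_nthread(params: Dict[str, Any], worker_nthreads: int) -> Tuple[Dict[str, Any], int]:
    """Return (updated params, thread count) following the override order in this commit.

    Hypothetical helper for illustration only.
    """
    local_param = params.copy()
    n_threads = 0
    for p in ["nthread", "n_jobs"]:
        # A value that differs from the dask worker's own thread count overrides it.
        if local_param.get(p, worker_nthreads) != worker_nthreads:
            n_threads = local_param[p]
            break
    if n_threads == 0:
        n_threads = worker_nthreads
    # Both aliases are normalized, and the same value is what now gets passed
    # as `nthread` when the per-worker DMatrix is constructed.
    local_param.update({"nthread": n_threads, "n_jobs": n_threads})
    return local_param, n_threads


# Example: params override the worker's 8 threads.
assert resolve_nthread({"n_jobs": 2}, worker_nthreads=8)[1] == 2
# Example: nothing explicit, fall back to the worker's thread count.
assert resolve_nthread({"tree_method": "hist"}, worker_nthreads=8)[1] == 8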
Jiaming Yuan 2021-10-20 15:16:40 +08:00 committed by GitHub
parent b8e8f0fcd9
commit f999897615
3 changed files with 44 additions and 33 deletions

View File

@@ -674,6 +674,7 @@ def _create_device_quantile_dmatrix(
     feature_weights: Optional[Any],
     meta_names: List[str],
     missing: float,
+    nthread: int,
     parts: Optional[_DataParts],
     max_bin: int,
     enable_categorical: bool,
@@ -717,7 +718,7 @@ def _create_device_quantile_dmatrix(
         missing=missing,
         feature_names=feature_names,
         feature_types=feature_types,
-        nthread=worker.nthreads,
+        nthread=nthread,
         max_bin=max_bin,
         enable_categorical=enable_categorical,
     )
@@ -731,6 +732,7 @@ def _create_dmatrix(
     feature_weights: Optional[Any],
     meta_names: List[str],
     missing: float,
+    nthread: int,
     enable_categorical: bool,
     parts: Optional[_DataParts]
 ) -> DMatrix:
@@ -778,7 +780,7 @@ def _create_dmatrix(
         missing=missing,
         feature_names=feature_names,
         feature_types=feature_types,
-        nthread=worker.nthreads,
+        nthread=nthread,
         enable_categorical=enable_categorical,
     )
     dmatrix.set_info(
@@ -856,46 +858,53 @@ async def _train_async(
         rabit_args: List[bytes],
         dtrain_ref: Dict,
         dtrain_idt: int,
-        evals_ref: Dict
+        evals_ref: List[Tuple[Dict, str, int]],
     ) -> Optional[Dict[str, Union[Booster, Dict]]]:
-        '''Perform training on a single worker. A local function prevents pickling.
-        '''
-        LOGGER.debug('Training on %s', str(worker_addr))
+        """Perform training on a single worker. A local function prevents pickling."""
+        LOGGER.debug("Training on %s", str(worker_addr))
         worker = distributed.get_worker()
+        n_threads: int = 0
+        local_param = params.copy()
+        for p in ["nthread", "n_jobs"]:
+            if local_param.get(p, worker.nthreads) != worker.nthreads:
+                LOGGER.info("Overriding `nthreads` defined in dask worker.")
+                n_threads = local_param[p]
+                break
+        if n_threads == 0:
+            n_threads = worker.nthreads
+        local_param.update({"nthread": n_threads, "n_jobs": n_threads})
         with RabitContext(rabit_args), config.config_context(**global_config):
-            local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref)
+            local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref, nthread=n_threads)
             local_evals = []
             if evals_ref:
                 for ref, name, idt in evals_ref:
                     if idt == dtrain_idt:
                         local_evals.append((local_dtrain, name))
                         continue
-                    local_evals.append((_dmatrix_from_list_of_parts(**ref), name))
+                    local_evals.append(
+                        (_dmatrix_from_list_of_parts(**ref, nthread=n_threads), name)
+                    )
             local_history: Dict = {}
-            local_param = params.copy()  # just to be consistent
-            msg = 'Overriding `nthreads` defined in dask worker.'
-            override = ['nthread', 'n_jobs']
-            for p in override:
-                val = local_param.get(p, None)
-                if val is not None and val != worker.nthreads:
-                    LOGGER.info(msg)
-                else:
-                    local_param[p] = worker.nthreads
-            bst = worker_train(params=local_param,
-                               dtrain=local_dtrain,
-                               num_boost_round=num_boost_round,
-                               evals_result=local_history,
-                               evals=local_evals,
-                               obj=obj,
-                               feval=feval,
-                               early_stopping_rounds=early_stopping_rounds,
-                               verbose_eval=verbose_eval,
-                               xgb_model=xgb_model,
-                               callbacks=callbacks)
+            bst = worker_train(
+                params=local_param,
+                dtrain=local_dtrain,
+                num_boost_round=num_boost_round,
+                evals_result=local_history,
+                evals=local_evals,
+                obj=obj,
+                feval=feval,
+                early_stopping_rounds=early_stopping_rounds,
+                verbose_eval=verbose_eval,
+                xgb_model=xgb_model,
+                callbacks=callbacks,
+            )
             ret: Optional[Dict[str, Union[Booster, Dict]]] = {
-                'booster': bst, 'history': local_history}
+                "booster": bst,
+                "history": local_history,
+            }
             if local_dtrain.num_row() == 0:
                 ret = None
             return ret
@@ -924,7 +933,7 @@ async def _train_async(
             evals_per_worker,
             pure=False,
             workers=[worker_addr],
-            allow_other_workers=False
+            allow_other_workers=False,
         )
         futures.append(f)
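Downstream, `_dmatrix_from_list_of_parts` and the two `_create_*_dmatrix` helpers simply forward this value to the DMatrix constructors. As a minimal sketch of what the forwarded argument controls at the DMatrix level, using the public API outside of Dask (the array sizes and thread count below are arbitrary placeholders):

import numpy as np
import xgboost as xgb

X = np.random.rand(256, 16)
y = np.random.rand(256)

# `nthread` bounds the number of threads used while building the DMatrix;
# with this commit the dask layer passes the resolved thread count here
# instead of unconditionally using the dask worker's `nthreads`.
dtrain = xgb.DMatrix(X, label=y, nthread=4)
print(dtrain.num_row(), dtrain.num_col())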

View File

@@ -434,7 +434,7 @@ class TestDistributedGPU:
         def worker_fn(worker_addr: str, data_ref: Dict) -> None:
             with dxgb.RabitContext(rabit_args):
-                local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref)
+                local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref, nthread=7)
                 fw_rows = local_dtrain.get_float_info("feature_weights").shape[0]
                 assert fw_rows == local_dtrain.num_col()

View File

@@ -1275,7 +1275,9 @@ class TestWithDask:
         def worker_fn(worker_addr: str, data_ref: Dict) -> None:
             with xgb.dask.RabitContext(rabit_args):
-                local_dtrain = xgb.dask._dmatrix_from_list_of_parts(**data_ref)
+                local_dtrain = xgb.dask._dmatrix_from_list_of_parts(
+                    **data_ref, nthread=7
+                )
                 total = np.array([local_dtrain.num_row()])
                 total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
                 assert total[0] == kRows