[dask] Use nthread in DMatrix construction. (#7337)
This is consistent with the thread overriding behavior.
This commit is contained in:
parent
b8e8f0fcd9
commit
f999897615
@ -674,6 +674,7 @@ def _create_device_quantile_dmatrix(
|
||||
feature_weights: Optional[Any],
|
||||
meta_names: List[str],
|
||||
missing: float,
|
||||
nthread: int,
|
||||
parts: Optional[_DataParts],
|
||||
max_bin: int,
|
||||
enable_categorical: bool,
|
||||
@ -717,7 +718,7 @@ def _create_device_quantile_dmatrix(
|
||||
missing=missing,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
nthread=worker.nthreads,
|
||||
nthread=nthread,
|
||||
max_bin=max_bin,
|
||||
enable_categorical=enable_categorical,
|
||||
)
|
||||
@ -731,6 +732,7 @@ def _create_dmatrix(
|
||||
feature_weights: Optional[Any],
|
||||
meta_names: List[str],
|
||||
missing: float,
|
||||
nthread: int,
|
||||
enable_categorical: bool,
|
||||
parts: Optional[_DataParts]
|
||||
) -> DMatrix:
|
||||
@ -778,7 +780,7 @@ def _create_dmatrix(
|
||||
missing=missing,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
nthread=worker.nthreads,
|
||||
nthread=nthread,
|
||||
enable_categorical=enable_categorical,
|
||||
)
|
||||
dmatrix.set_info(
|
||||
@ -856,46 +858,53 @@ async def _train_async(
|
||||
rabit_args: List[bytes],
|
||||
dtrain_ref: Dict,
|
||||
dtrain_idt: int,
|
||||
evals_ref: Dict
|
||||
evals_ref: List[Tuple[Dict, str, int]],
|
||||
) -> Optional[Dict[str, Union[Booster, Dict]]]:
|
||||
'''Perform training on a single worker. A local function prevents pickling.
|
||||
|
||||
'''
|
||||
LOGGER.debug('Training on %s', str(worker_addr))
|
||||
"""Perform training on a single worker. A local function prevents pickling."""
|
||||
LOGGER.debug("Training on %s", str(worker_addr))
|
||||
worker = distributed.get_worker()
|
||||
|
||||
n_threads: int = 0
|
||||
local_param = params.copy()
|
||||
for p in ["nthread", "n_jobs"]:
|
||||
if local_param.get(p, worker.nthreads) != worker.nthreads:
|
||||
LOGGER.info("Overriding `nthreads` defined in dask worker.")
|
||||
n_threads = local_param[p]
|
||||
break
|
||||
if n_threads == 0:
|
||||
n_threads = worker.nthreads
|
||||
local_param.update({"nthread": n_threads, "n_jobs": n_threads})
|
||||
|
||||
with RabitContext(rabit_args), config.config_context(**global_config):
|
||||
local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref)
|
||||
local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref, nthread=n_threads)
|
||||
local_evals = []
|
||||
if evals_ref:
|
||||
for ref, name, idt in evals_ref:
|
||||
if idt == dtrain_idt:
|
||||
local_evals.append((local_dtrain, name))
|
||||
continue
|
||||
local_evals.append((_dmatrix_from_list_of_parts(**ref), name))
|
||||
local_evals.append(
|
||||
(_dmatrix_from_list_of_parts(**ref, nthread=n_threads), name)
|
||||
)
|
||||
|
||||
local_history: Dict = {}
|
||||
local_param = params.copy() # just to be consistent
|
||||
msg = 'Overriding `nthreads` defined in dask worker.'
|
||||
override = ['nthread', 'n_jobs']
|
||||
for p in override:
|
||||
val = local_param.get(p, None)
|
||||
if val is not None and val != worker.nthreads:
|
||||
LOGGER.info(msg)
|
||||
else:
|
||||
local_param[p] = worker.nthreads
|
||||
bst = worker_train(params=local_param,
|
||||
dtrain=local_dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
evals_result=local_history,
|
||||
evals=local_evals,
|
||||
obj=obj,
|
||||
feval=feval,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
xgb_model=xgb_model,
|
||||
callbacks=callbacks)
|
||||
bst = worker_train(
|
||||
params=local_param,
|
||||
dtrain=local_dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
evals_result=local_history,
|
||||
evals=local_evals,
|
||||
obj=obj,
|
||||
feval=feval,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
xgb_model=xgb_model,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
ret: Optional[Dict[str, Union[Booster, Dict]]] = {
|
||||
'booster': bst, 'history': local_history}
|
||||
"booster": bst,
|
||||
"history": local_history,
|
||||
}
|
||||
if local_dtrain.num_row() == 0:
|
||||
ret = None
|
||||
return ret
|
||||
@ -924,7 +933,7 @@ async def _train_async(
|
||||
evals_per_worker,
|
||||
pure=False,
|
||||
workers=[worker_addr],
|
||||
allow_other_workers=False
|
||||
allow_other_workers=False,
|
||||
)
|
||||
futures.append(f)
|
||||
|
||||
|
||||
@ -434,7 +434,7 @@ class TestDistributedGPU:
|
||||
|
||||
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
|
||||
with dxgb.RabitContext(rabit_args):
|
||||
local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref)
|
||||
local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref, nthread=7)
|
||||
fw_rows = local_dtrain.get_float_info("feature_weights").shape[0]
|
||||
assert fw_rows == local_dtrain.num_col()
|
||||
|
||||
|
||||
@ -1275,7 +1275,9 @@ class TestWithDask:
|
||||
|
||||
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
local_dtrain = xgb.dask._dmatrix_from_list_of_parts(**data_ref)
|
||||
local_dtrain = xgb.dask._dmatrix_from_list_of_parts(
|
||||
**data_ref, nthread=7
|
||||
)
|
||||
total = np.array([local_dtrain.num_row()])
|
||||
total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
|
||||
assert total[0] == kRows
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user