[dask] Use nthread in DMatrix construction. (#7337)

This is consistent with the thread-overriding behavior: a user-supplied
`nthread`/`n_jobs` in the training parameters takes precedence over the dask
worker's thread count, and the resolved value is now also passed to DMatrix
construction instead of always using `worker.nthreads`.

parent b8e8f0fcd9
commit f999897615
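
The per-worker thread count is resolved once and then reused everywhere. A
minimal sketch of that resolution (the helper name `resolve_n_threads` is
hypothetical; the logic mirrors the new block added to `_train_async` below,
where `worker.nthreads` is the dask worker's thread count):

    from typing import Any, Dict

    def resolve_n_threads(local_param: Dict[str, Any], worker_nthreads: int) -> int:
        # A user-supplied ``nthread``/``n_jobs`` overrides the dask worker's
        # thread count; otherwise fall back to the worker's own setting.
        for p in ["nthread", "n_jobs"]:
            if local_param.get(p, worker_nthreads) != worker_nthreads:
                return local_param[p]
        return worker_nthreads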
@@ -674,6 +674,7 @@ def _create_device_quantile_dmatrix(
     feature_weights: Optional[Any],
     meta_names: List[str],
     missing: float,
+    nthread: int,
     parts: Optional[_DataParts],
     max_bin: int,
     enable_categorical: bool,
@@ -717,7 +718,7 @@ def _create_device_quantile_dmatrix(
         missing=missing,
         feature_names=feature_names,
         feature_types=feature_types,
-        nthread=worker.nthreads,
+        nthread=nthread,
         max_bin=max_bin,
         enable_categorical=enable_categorical,
     )
@@ -731,6 +732,7 @@ def _create_dmatrix(
     feature_weights: Optional[Any],
     meta_names: List[str],
     missing: float,
+    nthread: int,
     enable_categorical: bool,
     parts: Optional[_DataParts]
 ) -> DMatrix:
@@ -778,7 +780,7 @@ def _create_dmatrix(
         missing=missing,
         feature_names=feature_names,
         feature_types=feature_types,
-        nthread=worker.nthreads,
+        nthread=nthread,
         enable_categorical=enable_categorical,
     )
     dmatrix.set_info(
@@ -856,46 +858,53 @@ async def _train_async(
         rabit_args: List[bytes],
         dtrain_ref: Dict,
         dtrain_idt: int,
-        evals_ref: Dict
+        evals_ref: List[Tuple[Dict, str, int]],
     ) -> Optional[Dict[str, Union[Booster, Dict]]]:
-        '''Perform training on a single worker. A local function prevents pickling.
-
-        '''
-        LOGGER.debug('Training on %s', str(worker_addr))
+        """Perform training on a single worker. A local function prevents pickling."""
+        LOGGER.debug("Training on %s", str(worker_addr))
         worker = distributed.get_worker()
+
+        n_threads: int = 0
+        local_param = params.copy()
+        for p in ["nthread", "n_jobs"]:
+            if local_param.get(p, worker.nthreads) != worker.nthreads:
+                LOGGER.info("Overriding `nthreads` defined in dask worker.")
+                n_threads = local_param[p]
+                break
+        if n_threads == 0:
+            n_threads = worker.nthreads
+        local_param.update({"nthread": n_threads, "n_jobs": n_threads})
+
         with RabitContext(rabit_args), config.config_context(**global_config):
-            local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref)
+            local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref, nthread=n_threads)
             local_evals = []
             if evals_ref:
                 for ref, name, idt in evals_ref:
                     if idt == dtrain_idt:
                         local_evals.append((local_dtrain, name))
                         continue
-                    local_evals.append((_dmatrix_from_list_of_parts(**ref), name))
+                    local_evals.append(
+                        (_dmatrix_from_list_of_parts(**ref, nthread=n_threads), name)
+                    )
 
             local_history: Dict = {}
-            local_param = params.copy()  # just to be consistent
-            msg = 'Overriding `nthreads` defined in dask worker.'
-            override = ['nthread', 'n_jobs']
-            for p in override:
-                val = local_param.get(p, None)
-                if val is not None and val != worker.nthreads:
-                    LOGGER.info(msg)
-                else:
-                    local_param[p] = worker.nthreads
-            bst = worker_train(params=local_param,
-                               dtrain=local_dtrain,
-                               num_boost_round=num_boost_round,
-                               evals_result=local_history,
-                               evals=local_evals,
-                               obj=obj,
-                               feval=feval,
-                               early_stopping_rounds=early_stopping_rounds,
-                               verbose_eval=verbose_eval,
-                               xgb_model=xgb_model,
-                               callbacks=callbacks)
+            bst = worker_train(
+                params=local_param,
+                dtrain=local_dtrain,
+                num_boost_round=num_boost_round,
+                evals_result=local_history,
+                evals=local_evals,
+                obj=obj,
+                feval=feval,
+                early_stopping_rounds=early_stopping_rounds,
+                verbose_eval=verbose_eval,
+                xgb_model=xgb_model,
+                callbacks=callbacks,
+            )
             ret: Optional[Dict[str, Union[Booster, Dict]]] = {
-                'booster': bst, 'history': local_history}
+                "booster": bst,
+                "history": local_history,
+            }
             if local_dtrain.num_row() == 0:
                 ret = None
             return ret
@@ -924,7 +933,7 @@ async def _train_async(
                 evals_per_worker,
                 pure=False,
                 workers=[worker_addr],
-                allow_other_workers=False
+                allow_other_workers=False,
             )
             futures.append(f)
 
@@ -434,7 +434,7 @@ class TestDistributedGPU:
 
         def worker_fn(worker_addr: str, data_ref: Dict) -> None:
             with dxgb.RabitContext(rabit_args):
-                local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref)
+                local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref, nthread=7)
                 fw_rows = local_dtrain.get_float_info("feature_weights").shape[0]
                 assert fw_rows == local_dtrain.num_col()
 
@@ -1275,7 +1275,9 @@ class TestWithDask:
 
         def worker_fn(worker_addr: str, data_ref: Dict) -> None:
             with xgb.dask.RabitContext(rabit_args):
-                local_dtrain = xgb.dask._dmatrix_from_list_of_parts(**data_ref)
+                local_dtrain = xgb.dask._dmatrix_from_list_of_parts(
+                    **data_ref, nthread=7
+                )
                 total = np.array([local_dtrain.num_row()])
                 total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
                 assert total[0] == kRows
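
From the user's side, the effect is that `nthread` (or `n_jobs`) in the
training parameters now controls per-worker DMatrix construction as well as
training. A hedged usage sketch (the cluster setup and random data are
illustrative, not part of this commit):

    import dask.array as da
    import xgboost as xgb
    from distributed import Client, LocalCluster

    with LocalCluster(n_workers=2, threads_per_worker=4) as cluster:
        with Client(cluster) as client:
            X = da.random.random((1000, 10), chunks=(250, 10))
            y = da.random.random(1000, chunks=250)
            dtrain = xgb.dask.DaskDMatrix(client, X, y)
            # "nthread": 2 overrides the workers' 4 threads for training
            # and, after this commit, for per-worker DMatrix construction.
            output = xgb.dask.train(
                client,
                {"nthread": 2, "tree_method": "hist"},
                dtrain,
                num_boost_round=10,
            )
            booster = output["booster"]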