[dask] Use nthread in DMatrix construction. (#7337)
This is consistent with the thread overriding behavior.
This commit is contained in:
@@ -434,7 +434,7 @@ class TestDistributedGPU:
|
||||
|
||||
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
|
||||
with dxgb.RabitContext(rabit_args):
|
||||
local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref)
|
||||
local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref, nthread=7)
|
||||
fw_rows = local_dtrain.get_float_info("feature_weights").shape[0]
|
||||
assert fw_rows == local_dtrain.num_col()
|
||||
|
||||
|
||||
@@ -1275,7 +1275,9 @@ class TestWithDask:
|
||||
|
||||
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
local_dtrain = xgb.dask._dmatrix_from_list_of_parts(**data_ref)
|
||||
local_dtrain = xgb.dask._dmatrix_from_list_of_parts(
|
||||
**data_ref, nthread=7
|
||||
)
|
||||
total = np.array([local_dtrain.num_row()])
|
||||
total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
|
||||
assert total[0] == kRows
|
||||
|
||||
Reference in New Issue
Block a user