Fix global config default value. (#6470)
commit 703c2d06aa
parent d6386e45e8
@@ -15,7 +15,7 @@ namespace xgboost {
 class Json;
 
 struct GlobalConfiguration : public XGBoostParameter<GlobalConfiguration> {
-  int verbosity;
+  int verbosity { 1 };
   DMLC_DECLARE_PARAMETER(GlobalConfiguration) {
     DMLC_DECLARE_FIELD(verbosity)
         .set_range(0, 3)
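Note on the C++ hunk: `verbosity` was previously left uninitialized, so the struct member did not agree with the default declared through `DMLC_DECLARE_FIELD` until the parameter was explicitly configured. The brace initializer `{ 1 }` makes the in-class default match the documented one. A minimal sketch of the observable behavior from Python, assuming xgboost >= 1.3 where the config helpers are exported at package level:

    import xgboost as xgb

    # After the fix, a fresh process reports the documented default.
    assert xgb.get_config()['verbosity'] == 1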
@@ -626,6 +626,7 @@ def _get_workers_from_data(dtrain: DaskDMatrix, evals=()):
 
 
 async def _train_async(client,
+                       global_config,
                        params,
                        dtrain: DaskDMatrix,
                        *args,
@@ -639,7 +640,6 @@ async def _train_async(client,
 
     workers = list(_get_workers_from_data(dtrain, evals))
     _rabit_args = await _get_rabit_args(len(workers), client)
-    _global_config = config.get_config()
 
     def dispatched_train(worker_addr, rabit_args, dtrain_ref, dtrain_idt, evals_ref):
         '''Perform training on a single worker. A local function prevents pickling.
@@ -647,7 +647,7 @@ async def _train_async(client,
         '''
         LOGGER.info('Training on %s', str(worker_addr))
         worker = distributed.get_worker()
-        with RabitContext(rabit_args), config.config_context(**_global_config):
+        with RabitContext(rabit_args), config.config_context(**global_config):
             local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref)
             local_evals = []
             if evals_ref:
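The Python hunks all follow one pattern: read the configuration on the client, where the user's `config_context` is active, and re-enter it inside the function that actually runs on a worker. A minimal sketch of that capture-and-restore pattern, with `worker_fn` as a hypothetical stand-in for `dispatched_train`:

    import xgboost as xgb

    # On the client, inside the user's configuration context.
    global_config = xgb.get_config()

    def worker_fn():
        # Later, on a Dask worker: re-apply the captured settings so every
        # XGBoost call below sees the client's configuration.
        with xgb.config_context(**global_config):
            ...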
@@ -735,8 +735,11 @@ def train(client, params, dtrain, *args, evals=(), early_stopping_rounds=None,
     '''
     _assert_dask_support()
     client = _xgb_get_client(client)
+    # Get global configuration before transferring computation to another thread or
+    # process.
+    global_config = config.get_config()
     return client.sync(
-        _train_async, client, params, dtrain=dtrain, *args, evals=evals,
+        _train_async, client, global_config, params, dtrain=dtrain, *args, evals=evals,
         early_stopping_rounds=early_stopping_rounds, **kwargs)
 
 
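Capturing the configuration in `train` itself, before `client.sync` hands control to another thread, is what makes a client-side context effective. A hedged usage sketch, assuming an existing Dask `client` and a `DaskDMatrix` named `dtrain`:

    import xgboost as xgb
    from xgboost import dask as dxgb

    # Silencing training across the cluster is one concrete use: the
    # client-side context is now forwarded to every worker.
    with xgb.config_context(verbosity=0):
        output = dxgb.train(client, {'tree_method': 'hist'}, dtrain,
                            num_boost_round=10)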
@@ -760,7 +763,7 @@ async def _direct_predict_impl(client, data, predict_fn):
 
 
 # pylint: disable=too-many-statements
-async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
+async def _predict_async(client, global_config, model, data, missing=numpy.nan, **kwargs):
     if isinstance(model, Booster):
         booster = model
     elif isinstance(model, dict):
@@ -771,11 +774,9 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
         raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame],
                                 type(data)))
 
-    _global_config = config.get_config()
-
     def mapped_predict(partition, is_df):
         worker = distributed.get_worker()
-        with config.config_context(**_global_config):
+        with config.config_context(**global_config):
             booster.set_param({'nthread': worker.nthreads})
             m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
             predt = booster.predict(m, validate_features=False, **kwargs)
@@ -801,7 +802,7 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
     def dispatched_predict(worker_id, list_of_orders, list_of_parts):
         '''Perform prediction on each worker.'''
         LOGGER.info('Predicting on %d', worker_id)
-        with config.config_context(**_global_config):
+        with config.config_context(**global_config):
             worker = distributed.get_worker()
             list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
             predictions = []
@@ -907,11 +908,12 @@ def predict(client, model, data, missing=numpy.nan, **kwargs):
     '''
     _assert_dask_support()
     client = _xgb_get_client(client)
-    return client.sync(_predict_async, client, model, data,
+    global_config = config.get_config()
+    return client.sync(_predict_async, client, global_config, model, data,
                        missing=missing, **kwargs)
 
 
-async def _inplace_predict_async(client, model, data,
+async def _inplace_predict_async(client, global_config, model, data,
                                  iteration_range=(0, 0),
                                  predict_type='value',
                                  missing=numpy.nan):
@@ -927,6 +929,7 @@ async def _inplace_predict_async(client, model, data,
 
     def mapped_predict(data, is_df):
         worker = distributed.get_worker()
+        config.set_config(**global_config)
         booster.set_param({'nthread': worker.nthreads})
         prediction = booster.inplace_predict(
             data,
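Unlike the other call sites, `mapped_predict` here applies the captured settings via `config.set_config` rather than a `config_context` block, so they stay in effect for the remainder of the call instead of being scoped and restored on exit. The difference between the two helpers, as a small self-contained example:

    import xgboost as xgb

    # set_config changes the configuration until it is changed again...
    xgb.set_config(verbosity=0)
    assert xgb.get_config()['verbosity'] == 0

    # ...while config_context restores the previous values on exit.
    with xgb.config_context(verbosity=3):
        assert xgb.get_config()['verbosity'] == 3
    assert xgb.get_config()['verbosity'] == 0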
@@ -976,7 +979,9 @@ def inplace_predict(client, model, data,
     '''
     _assert_dask_support()
     client = _xgb_get_client(client)
-    return client.sync(_inplace_predict_async, client, model=model, data=data,
+    global_config = config.get_config()
+    return client.sync(_inplace_predict_async, client, global_config, model=model,
+                       data=data,
                        iteration_range=iteration_range,
                        predict_type=predict_type,
                        missing=missing)
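With `predict` and `inplace_predict` wired the same way as `train`, the prediction entry points also honor the caller's context. A usage sketch continuing the training example above (`client`, `output`, `dtrain`, and a dask array `X` are assumed to exist):

    import xgboost as xgb
    from xgboost import dask as dxgb

    with xgb.config_context(verbosity=0):
        predt = dxgb.predict(client, output['booster'], dtrain)
        inplace = dxgb.inplace_predict(client, output['booster'], X)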