Add global configuration (#6414)
* Add management functions for global configuration: XGBSetGlobalConfig(), XGBGetGlobalConfig(). * Add Python interface: set_config(), get_config(), and config_context(). * Add unit tests for Python * Add R interface: xgb.set.config(), xgb.get.config() * Add unit tests for R Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
committed by
GitHub
parent
c2ba4fb957
commit
fb56da5e8b
@@ -17,6 +17,7 @@ try:
|
||||
from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
|
||||
from .sklearn import XGBRFClassifier, XGBRFRegressor
|
||||
from .plotting import plot_importance, plot_tree, to_graphviz
|
||||
from .config import set_config, get_config, config_context
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -29,4 +30,5 @@ __all__ = ['DMatrix', 'DeviceQuantileDMatrix', 'Booster',
|
||||
'RabitTracker',
|
||||
'XGBModel', 'XGBClassifier', 'XGBRegressor', 'XGBRanker',
|
||||
'XGBRFClassifier', 'XGBRFRegressor',
|
||||
'plot_importance', 'plot_tree', 'to_graphviz', 'dask']
|
||||
'plot_importance', 'plot_tree', 'to_graphviz', 'dask',
|
||||
'set_config', 'get_config', 'config_context']
|
||||
|
||||
142
python-package/xgboost/config.py
Normal file
142
python-package/xgboost/config.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# pylint: disable=missing-function-docstring
|
||||
"""Global configuration for XGBoost"""
|
||||
import ctypes
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
|
||||
from .core import _LIB, _check_call, c_str, py_str
|
||||
|
||||
|
||||
def config_doc(*, header=None, extra_note=None, parameters=None, returns=None,
               see_also=None):
    """Decorator to format docstring for config functions.

    Parameters
    ----------
    header: str
        An introduction to the function
    extra_note: str
        Additional notes
    parameters: str
        Parameters of the function
    returns: str
        Return value
    see_also: str
        Related functions
    """

    doc_template = """
    {header}

    Global configuration consists of a collection of parameters that can be applied in the
    global scope. See https://xgboost.readthedocs.io/en/stable/parameter.html for the full
    list of parameters supported in the global configuration.

    {extra_note}

    .. versionadded:: 1.3.0
    """

    common_example = """
    Example
    -------

    .. code-block:: python

        import xgboost as xgb

        # Show all messages, including ones pertaining to debugging
        xgb.set_config(verbosity=2)

        # Get current value of global configuration
        # This is a dict containing all parameters in the global configuration,
        # including 'verbosity'
        config = xgb.get_config()
        assert config['verbosity'] == 2

        # Example of using the context manager xgb.config_context().
        # The context manager will restore the previous value of the global
        # configuration upon exiting.
        with xgb.config_context(verbosity=0):
            # Suppress warning caused by model generated with XGBoost version < 1.0.0
            bst = xgb.Booster(model_file='./old_model.bin')
        assert xgb.get_config()['verbosity'] == 2  # old value restored
    """

    def none_to_str(value):
        # Optional doc sections default to the empty string so string
        # concatenation below never sees None.
        return '' if value is None else value

    def config_doc_decorator(func):
        # Assemble the final docstring: the shared template, then the optional
        # per-function sections.  ``common_example`` is a local constant and is
        # never None, so it needs no none_to_str() guard.
        func.__doc__ = (doc_template.format(header=none_to_str(header),
                                            extra_note=none_to_str(extra_note))
                        + none_to_str(parameters) + none_to_str(returns)
                        + common_example + none_to_str(see_also))

        # @wraps copies the (now formatted) __doc__ and metadata onto the
        # wrapper so help() and introspection see the generated docstring.
        @wraps(func)
        def wrap(*args, **kwargs):
            return func(*args, **kwargs)
        return wrap
    return config_doc_decorator
|
||||
|
||||
|
||||
@config_doc(header="""
|
||||
Set global configuration.
|
||||
""",
|
||||
parameters="""
|
||||
Parameters
|
||||
----------
|
||||
new_config: Dict[str, Any]
|
||||
Keyword arguments representing the parameters and their values
|
||||
""")
|
||||
def set_config(**new_config):
|
||||
config = json.dumps(new_config)
|
||||
_check_call(_LIB.XGBSetGlobalConfig(c_str(config)))
|
||||
|
||||
|
||||
@config_doc(header="""
|
||||
Get current values of the global configuration.
|
||||
""",
|
||||
returns="""
|
||||
Returns
|
||||
-------
|
||||
args: Dict[str, Any]
|
||||
The list of global parameters and their values
|
||||
""")
|
||||
def get_config():
|
||||
config_str = ctypes.c_char_p()
|
||||
_check_call(_LIB.XGBGetGlobalConfig(ctypes.byref(config_str)))
|
||||
config = json.loads(py_str(config_str.value))
|
||||
return config
|
||||
|
||||
|
||||
@contextmanager
@config_doc(header="""
    Context manager for global XGBoost configuration.
    """,
            parameters="""
    Parameters
    ----------
    new_config: Dict[str, Any]
        Keyword arguments representing the parameters and their values
    """,
            extra_note="""
    .. note::

        All settings, not just those presently modified, will be returned to their
        previous values when the context manager is exited. This is not thread-safe.
    """,
            see_also="""
    See Also
    --------
    set_config: Set global XGBoost configuration
    get_config: Get current values of the global configuration
    """)
def config_context(**new_config):
    # Snapshot the full configuration before applying the overrides so the
    # exit path can restore every parameter, not only the modified ones.
    saved = get_config().copy()
    set_config(**new_config)
    try:
        yield
    finally:
        # Always restore the snapshot, even if the body raised.
        set_config(**saved)
|
||||
@@ -22,7 +22,7 @@ from typing import List
|
||||
|
||||
import numpy
|
||||
|
||||
from . import rabit
|
||||
from . import rabit, config
|
||||
|
||||
from .compat import LazyLoader
|
||||
from .compat import sparse, scipy_sparse
|
||||
@@ -639,6 +639,7 @@ async def _train_async(client,
|
||||
|
||||
workers = list(_get_workers_from_data(dtrain, evals))
|
||||
_rabit_args = await _get_rabit_args(len(workers), client)
|
||||
_global_config = config.get_config()
|
||||
|
||||
def dispatched_train(worker_addr, rabit_args, dtrain_ref, dtrain_idt, evals_ref):
|
||||
'''Perform training on a single worker. A local function prevents pickling.
|
||||
@@ -646,7 +647,7 @@ async def _train_async(client,
|
||||
'''
|
||||
LOGGER.info('Training on %s', str(worker_addr))
|
||||
worker = distributed.get_worker()
|
||||
with RabitContext(rabit_args):
|
||||
with RabitContext(rabit_args), config.config_context(**_global_config):
|
||||
local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref)
|
||||
local_evals = []
|
||||
if evals_ref:
|
||||
@@ -770,18 +771,21 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
|
||||
raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame],
|
||||
type(data)))
|
||||
|
||||
_global_config = config.get_config()
|
||||
|
||||
def mapped_predict(partition, is_df):
|
||||
worker = distributed.get_worker()
|
||||
booster.set_param({'nthread': worker.nthreads})
|
||||
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
|
||||
predt = booster.predict(m, validate_features=False, **kwargs)
|
||||
if is_df:
|
||||
if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
|
||||
import cudf # pylint: disable=import-error
|
||||
predt = cudf.DataFrame(predt, columns=['prediction'])
|
||||
else:
|
||||
predt = DataFrame(predt, columns=['prediction'])
|
||||
return predt
|
||||
with config.config_context(**_global_config):
|
||||
booster.set_param({'nthread': worker.nthreads})
|
||||
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
|
||||
predt = booster.predict(m, validate_features=False, **kwargs)
|
||||
if is_df:
|
||||
if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
|
||||
import cudf # pylint: disable=import-error
|
||||
predt = cudf.DataFrame(predt, columns=['prediction'])
|
||||
else:
|
||||
predt = DataFrame(predt, columns=['prediction'])
|
||||
return predt
|
||||
# Predict on dask collection directly.
|
||||
if isinstance(data, (da.Array, dd.DataFrame)):
|
||||
return await _direct_predict_impl(client, data, mapped_predict)
|
||||
@@ -797,31 +801,32 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
|
||||
def dispatched_predict(worker_id, list_of_orders, list_of_parts):
|
||||
'''Perform prediction on each worker.'''
|
||||
LOGGER.info('Predicting on %d', worker_id)
|
||||
worker = distributed.get_worker()
|
||||
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
|
||||
predictions = []
|
||||
with config.config_context(**_global_config):
|
||||
worker = distributed.get_worker()
|
||||
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
|
||||
predictions = []
|
||||
|
||||
booster.set_param({'nthread': worker.nthreads})
|
||||
for i, parts in enumerate(list_of_parts):
|
||||
(data, _, _, base_margin, _, _) = parts
|
||||
order = list_of_orders[i]
|
||||
local_part = DMatrix(
|
||||
data,
|
||||
base_margin=base_margin,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
missing=missing,
|
||||
nthread=worker.nthreads
|
||||
)
|
||||
predt = booster.predict(
|
||||
data=local_part,
|
||||
validate_features=local_part.num_row() != 0,
|
||||
**kwargs)
|
||||
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
|
||||
ret = ((dask.delayed(predt), columns), order)
|
||||
predictions.append(ret)
|
||||
booster.set_param({'nthread': worker.nthreads})
|
||||
for i, parts in enumerate(list_of_parts):
|
||||
(data, _, _, base_margin, _, _) = parts
|
||||
order = list_of_orders[i]
|
||||
local_part = DMatrix(
|
||||
data,
|
||||
base_margin=base_margin,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
missing=missing,
|
||||
nthread=worker.nthreads
|
||||
)
|
||||
predt = booster.predict(
|
||||
data=local_part,
|
||||
validate_features=local_part.num_row() != 0,
|
||||
**kwargs)
|
||||
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
|
||||
ret = ((dask.delayed(predt), columns), order)
|
||||
predictions.append(ret)
|
||||
|
||||
return predictions
|
||||
return predictions
|
||||
|
||||
def dispatched_get_shape(worker_id, list_of_orders, list_of_parts):
|
||||
'''Get shape of data in each worker.'''
|
||||
|
||||
Reference in New Issue
Block a user