Add global configuration (#6414)

* Add management functions for global configuration: XGBSetGlobalConfig(), XGBGetGlobalConfig().
* Add Python interface: set_config(), get_config(), and config_context().
* Add unit tests for Python
* Add R interface: xgb.set.config(), xgb.get.config()
* Add unit tests for R

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
Philip Hyunsu Cho
2020-12-03 00:05:18 -08:00
committed by GitHub
parent c2ba4fb957
commit fb56da5e8b
29 changed files with 637 additions and 86 deletions

View File

@@ -17,6 +17,7 @@ try:
from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
from .sklearn import XGBRFClassifier, XGBRFRegressor
from .plotting import plot_importance, plot_tree, to_graphviz
from .config import set_config, get_config, config_context
except ImportError:
pass
@@ -29,4 +30,5 @@ __all__ = ['DMatrix', 'DeviceQuantileDMatrix', 'Booster',
'RabitTracker',
'XGBModel', 'XGBClassifier', 'XGBRegressor', 'XGBRanker',
'XGBRFClassifier', 'XGBRFRegressor',
'plot_importance', 'plot_tree', 'to_graphviz', 'dask']
'plot_importance', 'plot_tree', 'to_graphviz', 'dask',
'set_config', 'get_config', 'config_context']

View File

@@ -0,0 +1,142 @@
# pylint: disable=missing-function-docstring
"""Global configuration for XGBoost"""
import ctypes
import json
from contextlib import contextmanager
from functools import wraps
from .core import _LIB, _check_call, c_str, py_str
def config_doc(*, header=None, extra_note=None, parameters=None, returns=None,
               see_also=None):
    """Decorator that assembles a formatted docstring for config functions.

    Parameters
    ----------
    header: str
        An introduction to the function
    extra_note: str
        Additional notes
    parameters: str
        Parameters of the function
    returns: str
        Return value
    see_also: str
        Related functions
    """
    doc_template = """
    {header}

    Global configuration consists of a collection of parameters that can be applied in the
    global scope. See https://xgboost.readthedocs.io/en/stable/parameter.html for the full
    list of parameters supported in the global configuration.

    {extra_note}

    .. versionadded:: 1.3.0
    """

    common_example = """
    Example
    -------

    .. code-block:: python

        import xgboost as xgb

        # Show all messages, including ones pertaining to debugging
        xgb.set_config(verbosity=2)

        # Get current value of global configuration
        # This is a dict containing all parameters in the global configuration,
        # including 'verbosity'
        config = xgb.get_config()
        assert config['verbosity'] == 2

        # Example of using the context manager xgb.config_context().
        # The context manager will restore the previous value of the global
        # configuration upon exiting.
        with xgb.config_context(verbosity=0):
            # Suppress warning caused by model generated with XGBoost version < 1.0.0
            bst = xgb.Booster(model_file='./old_model.bin')
        assert xgb.get_config()['verbosity'] == 2  # old value restored
    """

    def _frag(text):
        # Treat a missing (None) docstring fragment as an empty string.
        return text if text is not None else ''

    def config_doc_decorator(func):
        # Build the final docstring from the template plus optional fragments,
        # then expose it on a thin pass-through wrapper of ``func``.
        pieces = [
            doc_template.format(header=_frag(header),
                                extra_note=_frag(extra_note)),
            _frag(parameters),
            _frag(returns),
            _frag(common_example),
            _frag(see_also),
        ]
        func.__doc__ = ''.join(pieces)

        @wraps(func)
        def wrap(*args, **kwargs):
            return func(*args, **kwargs)
        return wrap

    return config_doc_decorator
@config_doc(header="""
    Set global configuration.
    """,
            parameters="""
    Parameters
    ----------
    new_config: Dict[str, Any]
        Keyword arguments representing the parameters and their values
    """)
def set_config(**new_config):
    # Serialize the keyword arguments to JSON and hand them to the native
    # library, raising on a non-zero status code.
    _check_call(_LIB.XGBSetGlobalConfig(c_str(json.dumps(new_config))))
@config_doc(header="""
    Get current values of the global configuration.
    """,
            returns="""
    Returns
    -------
    args: Dict[str, Any]
        The list of global parameters and their values
    """)
def get_config():
    # Ask the native library for its configuration as a JSON string,
    # then decode it into a Python dict.
    raw = ctypes.c_char_p()
    _check_call(_LIB.XGBGetGlobalConfig(ctypes.byref(raw)))
    return json.loads(py_str(raw.value))
@contextmanager
@config_doc(header="""
    Context manager for global XGBoost configuration.
    """,
            parameters="""
    Parameters
    ----------
    new_config: Dict[str, Any]
        Keyword arguments representing the parameters and their values
    """,
            extra_note="""
    .. note::

        All settings, not just those presently modified, will be returned to their
        previous values when the context manager is exited. This is not thread-safe.
    """,
            see_also="""
    See Also
    --------
    set_config: Set global XGBoost configuration
    get_config: Get current values of the global configuration
    """)
def config_context(**new_config):
    # Snapshot the entire current configuration so that everything — not only
    # the keys being overridden — is restored on exit, even if the body raises.
    saved = dict(get_config())
    set_config(**new_config)
    try:
        yield
    finally:
        set_config(**saved)

View File

@@ -22,7 +22,7 @@ from typing import List
import numpy
from . import rabit
from . import rabit, config
from .compat import LazyLoader
from .compat import sparse, scipy_sparse
@@ -639,6 +639,7 @@ async def _train_async(client,
workers = list(_get_workers_from_data(dtrain, evals))
_rabit_args = await _get_rabit_args(len(workers), client)
_global_config = config.get_config()
def dispatched_train(worker_addr, rabit_args, dtrain_ref, dtrain_idt, evals_ref):
'''Perform training on a single worker. A local function prevents pickling.
@@ -646,7 +647,7 @@ async def _train_async(client,
'''
LOGGER.info('Training on %s', str(worker_addr))
worker = distributed.get_worker()
with RabitContext(rabit_args):
with RabitContext(rabit_args), config.config_context(**_global_config):
local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref)
local_evals = []
if evals_ref:
@@ -770,18 +771,21 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame],
type(data)))
_global_config = config.get_config()
def mapped_predict(partition, is_df):
worker = distributed.get_worker()
booster.set_param({'nthread': worker.nthreads})
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
predt = booster.predict(m, validate_features=False, **kwargs)
if is_df:
if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
import cudf # pylint: disable=import-error
predt = cudf.DataFrame(predt, columns=['prediction'])
else:
predt = DataFrame(predt, columns=['prediction'])
return predt
with config.config_context(**_global_config):
booster.set_param({'nthread': worker.nthreads})
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
predt = booster.predict(m, validate_features=False, **kwargs)
if is_df:
if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
import cudf # pylint: disable=import-error
predt = cudf.DataFrame(predt, columns=['prediction'])
else:
predt = DataFrame(predt, columns=['prediction'])
return predt
# Predict on dask collection directly.
if isinstance(data, (da.Array, dd.DataFrame)):
return await _direct_predict_impl(client, data, mapped_predict)
@@ -797,31 +801,32 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
def dispatched_predict(worker_id, list_of_orders, list_of_parts):
'''Perform prediction on each worker.'''
LOGGER.info('Predicting on %d', worker_id)
worker = distributed.get_worker()
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
predictions = []
with config.config_context(**_global_config):
worker = distributed.get_worker()
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
predictions = []
booster.set_param({'nthread': worker.nthreads})
for i, parts in enumerate(list_of_parts):
(data, _, _, base_margin, _, _) = parts
order = list_of_orders[i]
local_part = DMatrix(
data,
base_margin=base_margin,
feature_names=feature_names,
feature_types=feature_types,
missing=missing,
nthread=worker.nthreads
)
predt = booster.predict(
data=local_part,
validate_features=local_part.num_row() != 0,
**kwargs)
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
ret = ((dask.delayed(predt), columns), order)
predictions.append(ret)
booster.set_param({'nthread': worker.nthreads})
for i, parts in enumerate(list_of_parts):
(data, _, _, base_margin, _, _) = parts
order = list_of_orders[i]
local_part = DMatrix(
data,
base_margin=base_margin,
feature_names=feature_names,
feature_types=feature_types,
missing=missing,
nthread=worker.nthreads
)
predt = booster.predict(
data=local_part,
validate_features=local_part.num_row() != 0,
**kwargs)
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
ret = ((dask.delayed(predt), columns), order)
predictions.append(ret)
return predictions
return predictions
def dispatched_get_shape(worker_id, list_of_orders, list_of_parts):
'''Get shape of data in each worker.'''