Add global configuration (#6414)

* Add management functions for global configuration: XGBSetGlobalConfig(), XGBGetGlobalConfig().
* Add Python interface: set_config(), get_config(), and config_context().
* Add unit tests for Python
* Add R interface: xgb.set.config(), xgb.get.config()
* Add unit tests for R

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
Philip Hyunsu Cho 2020-12-03 00:05:18 -08:00 committed by GitHub
parent c2ba4fb957
commit fb56da5e8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 637 additions and 86 deletions

View File

@ -53,7 +53,6 @@ Suggests:
testthat,
lintr,
igraph (>= 1.0.1),
jsonlite,
float,
crayon,
titanic
@ -64,5 +63,6 @@ Imports:
methods,
data.table (>= 1.9.6),
magrittr (>= 1.5),
jsonlite (>= 1.0),
RoxygenNote: 7.1.1
SystemRequirements: GNU make, C++14

View File

@ -36,6 +36,7 @@ export(xgb.create.features)
export(xgb.cv)
export(xgb.dump)
export(xgb.gblinear.history)
export(xgb.get.config)
export(xgb.ggplot.deepness)
export(xgb.ggplot.importance)
export(xgb.ggplot.shap.summary)
@ -52,6 +53,7 @@ export(xgb.plot.tree)
export(xgb.save)
export(xgb.save.raw)
export(xgb.serialize)
export(xgb.set.config)
export(xgb.train)
export(xgb.unserialize)
export(xgboost)
@ -78,6 +80,8 @@ importFrom(graphics,lines)
importFrom(graphics,par)
importFrom(graphics,points)
importFrom(graphics,title)
importFrom(jsonlite,fromJSON)
importFrom(jsonlite,toJSON)
importFrom(magrittr,"%>%")
importFrom(stats,median)
importFrom(stats,predict)

38
R-package/R/xgb.config.R Normal file
View File

@ -0,0 +1,38 @@
#' Global configuration consists of a collection of parameters that can be applied in the global
#' scope. See \url{https://xgboost.readthedocs.io/en/stable/parameter.html} for the full list of
#' parameters supported in the global configuration. Use \code{xgb.set.config} to update the
#' values of one or more global-scope parameters. Use \code{xgb.get.config} to fetch the current
#' values of all global-scope parameters (listed in
#' \url{https://xgboost.readthedocs.io/en/stable/parameter.html}).
#'
#' @rdname xgbConfig
#' @title Set and get global configuration
#' @name xgb.set.config, xgb.get.config
#' @export xgb.set.config xgb.get.config
#' @param ... List of parameters to be set, as keyword arguments
#' @return
#' \code{xgb.set.config} returns \code{TRUE} to signal success. \code{xgb.get.config} returns
#' a list containing all global-scope parameters and their values.
#'
#' @examples
#' # Set verbosity level to silent (0)
#' xgb.set.config(verbosity = 0)
#' # Now global verbosity level is 0
#' config <- xgb.get.config()
#' print(config$verbosity)
#' # Set verbosity level to warning (1)
#' xgb.set.config(verbosity = 1)
#' # Now global verbosity level is 1
#' config <- xgb.get.config()
#' print(config$verbosity)
xgb.set.config <- function(...) {
  new_config <- list(...)
  # Nothing to update. jsonlite would serialise an empty list as the JSON
  # array "[]", while the C API requires a flat JSON object.
  if (length(new_config) == 0L) {
    return(TRUE)
  }
  # Every parameter has to be supplied as a keyword argument; an unnamed value
  # cannot be mapped to a global parameter name.
  param_names <- names(new_config)
  if (is.null(param_names) || anyNA(param_names) || !all(nzchar(param_names))) {
    stop(
      "All arguments to xgb.set.config() must be named, ",
      "e.g. xgb.set.config(verbosity = 0)",
      call. = FALSE
    )
  }
  # Scalars are unboxed so that `verbosity = 0` becomes {"verbosity":0}
  # rather than {"verbosity":[0]}.
  .Call(XGBSetGlobalConfig_R, jsonlite::toJSON(new_config, auto_unbox = TRUE))
  return(TRUE)
}
#' @rdname xgbConfig
xgb.get.config <- function() {
  # The native routine returns the configuration as a JSON string;
  # decode it into an R list for the caller.
  config_json <- .Call(XGBGetGlobalConfig_R)
  parsed_config <- jsonlite::fromJSON(config_json)
  return(parsed_config)
}

View File

@ -91,6 +91,8 @@ NULL
#' @importFrom data.table setkeyv
#' @importFrom data.table setnames
#' @importFrom magrittr %>%
#' @importFrom jsonlite fromJSON
#' @importFrom jsonlite toJSON
#' @importFrom utils object.size str tail
#' @importFrom stats predict
#' @importFrom stats median

View File

@ -0,0 +1,39 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/xgb.config.R
\name{xgb.set.config, xgb.get.config}
\alias{xgb.set.config, xgb.get.config}
\alias{xgb.set.config}
\alias{xgb.get.config}
\title{Set and get global configuration}
\usage{
xgb.set.config(...)
xgb.get.config()
}
\arguments{
\item{...}{List of parameters to be set, as keyword arguments}
}
\value{
\code{xgb.set.config} returns \code{TRUE} to signal success. \code{xgb.get.config} returns
a list containing all global-scope parameters and their values.
}
\description{
Global configuration consists of a collection of parameters that can be applied in the global
scope. See \url{https://xgboost.readthedocs.io/en/stable/parameter.html} for the full list of
parameters supported in the global configuration. Use \code{xgb.set.config} to update the
values of one or more global-scope parameters. Use \code{xgb.get.config} to fetch the current
values of all global-scope parameters (listed in
\url{https://xgboost.readthedocs.io/en/stable/parameter.html}).
}
\examples{
# Set verbosity level to silent (0)
xgb.set.config(verbosity = 0)
# Now global verbosity level is 0
config <- xgb.get.config()
print(config$verbosity)
# Set verbosity level to warning (1)
xgb.set.config(verbosity = 1)
# Now global verbosity level is 1
config <- xgb.get.config()
print(config$verbosity)
}

View File

@ -43,6 +43,8 @@ extern SEXP XGDMatrixNumRow_R(SEXP);
extern SEXP XGDMatrixSaveBinary_R(SEXP, SEXP, SEXP);
extern SEXP XGDMatrixSetInfo_R(SEXP, SEXP, SEXP);
extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP);
extern SEXP XGBSetGlobalConfig_R(SEXP);
extern SEXP XGBGetGlobalConfig_R();
static const R_CallMethodDef CallEntries[] = {
{"XGBoosterBoostOneIter_R", (DL_FUNC) &XGBoosterBoostOneIter_R, 4},
@ -73,6 +75,8 @@ static const R_CallMethodDef CallEntries[] = {
{"XGDMatrixSaveBinary_R", (DL_FUNC) &XGDMatrixSaveBinary_R, 3},
{"XGDMatrixSetInfo_R", (DL_FUNC) &XGDMatrixSetInfo_R, 3},
{"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 2},
{"XGBSetGlobalConfig_R", (DL_FUNC) &XGBSetGlobalConfig_R, 1},
{"XGBGetGlobalConfig_R", (DL_FUNC) &XGBGetGlobalConfig_R, 0},
{NULL, NULL, 0}
};

View File

@ -49,6 +49,21 @@ void _DMatrixFinalizer(SEXP ext) {
R_API_END();
}
// R entry point for updating the global configuration.
// json_str is coerced to a single string with asChar() and forwarded as a C
// string to XGBSetGlobalConfig(). The function always returns R_NilValue.
// NOTE(review): error propagation is handled by R_API_BEGIN / CHECK_CALL /
// R_API_END, which are defined outside this view -- confirm their behaviour.
SEXP XGBSetGlobalConfig_R(SEXP json_str) {
R_API_BEGIN();
CHECK_CALL(XGBSetGlobalConfig(CHAR(asChar(json_str))));
R_API_END();
return R_NilValue;
}
// R entry point for reading the global configuration.
// Returns a length-one character vector (built with mkString) holding the
// JSON text produced by XGBGetGlobalConfig().
SEXP XGBGetGlobalConfig_R() {
// Filled in by XGBGetGlobalConfig(); deliberately not initialised here.
const char* json_str;
R_API_BEGIN();
CHECK_CALL(XGBGetGlobalConfig(&json_str));
R_API_END();
// NOTE(review): json_str is only meaningful when the call above succeeded;
// presumably CHECK_CALL / R_API_END never fall through on failure -- confirm.
return mkString(json_str);
}
SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent) {
SEXP ret;
R_API_BEGIN();

View File

@ -21,6 +21,19 @@
*/
XGB_DLL SEXP XGCheckNullPtr_R(SEXP handle);
/*!
* \brief Set global configuration
* \param json_str a JSON string representing the list of key-value pairs
* \return R_NilValue
*/
XGB_DLL SEXP XGBSetGlobalConfig_R(SEXP json_str);
/*!
* \brief Get global configuration
* \return JSON string
*/
XGB_DLL SEXP XGBGetGlobalConfig_R();
/*!
* \brief load a data matrix
* \param fname name of the content

View File

@ -16,7 +16,7 @@ void CustomLogMessage::Log(const std::string& msg) {
namespace xgboost {
ConsoleLogger::~ConsoleLogger() {
if (cur_verbosity_ == LogVerbosity::kIgnore ||
cur_verbosity_ <= global_verbosity_) {
cur_verbosity_ <= GlobalVerbosity()) {
dmlc::CustomLogMessage::Log(log_stream_.str());
}
}

View File

@ -0,0 +1,11 @@
context('Test global configuration')
test_that('Global configuration works with verbosity', {
  # Remember the current setting so the test leaves the global state untouched.
  old_verbosity <- xgb.get.config()$verbosity
  tryCatch({
    for (v in c(0, 1, 2, 3)) {
      xgb.set.config(verbosity = v)
      expect_equal(xgb.get.config()$verbosity, v)
    }
  },
  # Restore the original verbosity even when a call above raises an error;
  # otherwise later tests would run with a modified global configuration.
  finally = xgb.set.config(verbosity = old_verbosity))
  expect_equal(xgb.get.config()$verbosity, old_verbosity)
})

View File

@ -67,6 +67,7 @@
// global
#include "../src/learner.cc"
#include "../src/logging.cc"
#include "../src/global_config.cc"
#include "../src/common/common.cc"
#include "../src/common/random.cc"
#include "../src/common/charconv.cc"

View File

@ -16,6 +16,13 @@ Before running XGBoost, we must set three types of parameters: general parameter
:backlinks: none
:local:
********************
Global Configuration
********************
The following parameters can be set in the global scope, using ``xgb.config_context()`` (Python) or ``xgb.set.config()`` (R).
* ``verbosity``: Verbosity of printing messages. Valid values are 0 (silent), 1 (warning), 2 (info), and 3 (debug).
******************
General Parameters
******************

View File

@ -6,6 +6,14 @@ This page gives the Python API reference of xgboost, please also refer to Python
:backlinks: none
:local:
Global Configuration
--------------------
.. autofunction:: xgboost.config_context
.. autofunction:: xgboost.set_config
.. autofunction:: xgboost.get_config
Core Data Structure
-------------------
.. automodule:: xgboost.core

View File

@ -63,6 +63,23 @@ XGB_DLL const char *XGBGetLastError(void);
*/
XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*));
/*!
* \brief Set global configuration (collection of parameters that apply globally). This function
* accepts the list of key-value pairs representing the global-scope parameters to be
* configured. The list of key-value pairs are passed in as a JSON string.
* \param json_str a JSON string representing the list of key-value pairs. The JSON object shall
* be flat: no value can be a JSON object or an array.
* \return 0 for success, -1 for failure
*/
XGB_DLL int XGBSetGlobalConfig(const char* json_str);
/*!
* \brief Get current global configuration (collection of parameters that apply globally).
* \param json_str pointer used to receive the returned global configuration, represented as a JSON string.
* \return 0 for success, -1 for failure
*/
XGB_DLL int XGBGetGlobalConfig(const char** json_str);
/*!
* \brief load a data matrix
* \param fname the name of the file

View File

@ -0,0 +1,30 @@
/*!
* Copyright 2020 by Contributors
* \file global_config.h
* \brief Global configuration for XGBoost
* \author Hyunsu Cho
*/
#ifndef XGBOOST_GLOBAL_CONFIG_H_
#define XGBOOST_GLOBAL_CONFIG_H_
#include <xgboost/parameter.h>
#include <vector>
#include <string>
namespace xgboost {
class Json;
/*!
 * \brief Parameters that apply to the whole library rather than to a single model.
 *        Instances are accessed through GlobalConfigThreadLocalStore.
 */
struct GlobalConfiguration : public XGBoostParameter<GlobalConfiguration> {
  /*!
   * \brief Verbosity of printed messages, in the range [0, 3].
   *
   * The in-class initializer mirrors set_default() below so that an instance which has
   * not yet been initialised through the parameter manager already holds the default
   * (1, warnings only) instead of whatever value zero/default initialisation leaves.
   */
  int verbosity{1};
  DMLC_DECLARE_PARAMETER(GlobalConfiguration) {
    DMLC_DECLARE_FIELD(verbosity)
        .set_range(0, 3)
        .set_default(1)  // shows only warning
        .describe("Flag to print out detailed breakdown of runtime.");
  }
};
using GlobalConfigThreadLocalStore = dmlc::ThreadLocalStore<GlobalConfiguration>;
} // namespace xgboost
#endif // XGBOOST_GLOBAL_CONFIG_H_

View File

@ -557,7 +557,6 @@ using String = JsonString;
using Null = JsonNull;
// Utils tailored for XGBoost.
template <typename Parameter>
Object ToJson(Parameter const& param) {
Object obj;
@ -568,13 +567,13 @@ Object ToJson(Parameter const& param) {
}
template <typename Parameter>
void FromJson(Json const& obj, Parameter* param) {
Args FromJson(Json const& obj, Parameter* param) {
auto const& j_param = get<Object const>(obj);
std::map<std::string, std::string> m;
for (auto const& kv : j_param) {
m[kv.first] = get<String const>(kv.second);
}
param->UpdateAllowUnknown(m);
return param->UpdateAllowUnknown(m);
}
} // namespace xgboost
#endif // XGBOOST_JSON_H_

View File

@ -45,7 +45,6 @@ struct XGBAPIThreadLocalEntry {
PredictionCacheEntry prediction_entry;
};
/*!
* \brief Learner class that does training and prediction.
* This is the user facing module of xgboost training.

View File

@ -13,6 +13,7 @@
#include <xgboost/base.h>
#include <xgboost/parameter.h>
#include <xgboost/global_config.h>
#include <sstream>
#include <map>
@ -35,19 +36,6 @@ class BaseLogger {
std::ostringstream log_stream_;
};
// Parsing both silent and debug_verbose is to provide backward compatibility.
struct ConsoleLoggerParam : public XGBoostParameter<ConsoleLoggerParam> {
int verbosity;
DMLC_DECLARE_PARAMETER(ConsoleLoggerParam) {
DMLC_DECLARE_FIELD(verbosity)
.set_range(0, 3)
.set_default(1) // shows only warning
.describe("Flag to print out detailed breakdown of runtime.");
DMLC_DECLARE_ALIAS(verbosity, debug_verbose);
}
};
class ConsoleLogger : public BaseLogger {
public:
enum class LogVerbosity {
@ -60,9 +48,6 @@ class ConsoleLogger : public BaseLogger {
using LV = LogVerbosity;
private:
static LogVerbosity global_verbosity_;
static ConsoleLoggerParam param_;
LogVerbosity cur_verbosity_;
public:

View File

@ -17,6 +17,7 @@ try:
from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
from .sklearn import XGBRFClassifier, XGBRFRegressor
from .plotting import plot_importance, plot_tree, to_graphviz
from .config import set_config, get_config, config_context
except ImportError:
pass
@ -29,4 +30,5 @@ __all__ = ['DMatrix', 'DeviceQuantileDMatrix', 'Booster',
'RabitTracker',
'XGBModel', 'XGBClassifier', 'XGBRegressor', 'XGBRanker',
'XGBRFClassifier', 'XGBRFRegressor',
'plot_importance', 'plot_tree', 'to_graphviz', 'dask']
'plot_importance', 'plot_tree', 'to_graphviz', 'dask',
'set_config', 'get_config', 'config_context']

View File

@ -0,0 +1,142 @@
# pylint: disable=missing-function-docstring
"""Global configuration for XGBoost"""
import ctypes
import json
from contextlib import contextmanager
from functools import wraps
from .core import _LIB, _check_call, c_str, py_str
def config_doc(*, header=None, extra_note=None, parameters=None, returns=None,
               see_also=None):
    """Decorator factory that assembles the docstring of a config function.

    The generated docstring is the shared template (filled with ``header`` and
    ``extra_note``) followed by ``parameters``, ``returns``, a common usage
    example and ``see_also``, in that order. Sections left as ``None`` are
    rendered as empty strings.

    Parameters
    ----------
    header: str
        An introduction to the function
    extra_note: str
        Additional notes
    parameters: str
        Parameters of the function
    returns: str
        Return value
    see_also: str
        Related functions
    """

    doc_template = """
    {header}

    Global configuration consists of a collection of parameters that can be applied in the
    global scope. See https://xgboost.readthedocs.io/en/stable/parameter.html for the full
    list of parameters supported in the global configuration.

    {extra_note}

    .. versionadded:: 1.3.0
    """

    common_example = """
    Example
    -------

    .. code-block:: python

        import xgboost as xgb

        # Show all messages, including ones pertaining to debugging
        xgb.set_config(verbosity=2)

        # Get current value of global configuration
        # This is a dict containing all parameters in the global configuration,
        # including 'verbosity'
        config = xgb.get_config()
        assert config['verbosity'] == 2

        # Example of using the context manager xgb.config_context().
        # The context manager will restore the previous value of the global
        # configuration upon exiting.
        with xgb.config_context(verbosity=0):
            # Suppress warning caused by model generated with XGBoost version < 1.0.0
            bst = xgb.Booster(model_file='./old_model.bin')
        assert xgb.get_config()['verbosity'] == 2  # old value restored
    """

    def _blank_if_none(text):
        # Only None is replaced; any other value is passed through untouched.
        return text if text is not None else ''

    def decorate(func):
        sections = [
            doc_template.format(header=_blank_if_none(header),
                                extra_note=_blank_if_none(extra_note)),
            _blank_if_none(parameters),
            _blank_if_none(returns),
            _blank_if_none(common_example),
            _blank_if_none(see_also),
        ]
        # The docstring is attached before wrapping so that functools.wraps
        # copies it onto the forwarding function.
        func.__doc__ = ''.join(sections)

        @wraps(func)
        def forward(*args, **kwargs):
            return func(*args, **kwargs)

        return forward

    return decorate
@config_doc(header="""
    Set global configuration.
    """,
            parameters="""
    Parameters
    ----------
    new_config: Dict[str, Any]
        Keyword arguments representing the parameters and their values
    """)
def set_config(**new_config):
    # The C API takes the key-value pairs as a single JSON document.
    serialized = json.dumps(new_config)
    status = _LIB.XGBSetGlobalConfig(c_str(serialized))
    _check_call(status)
@config_doc(header="""
    Get current values of the global configuration.
    """,
            returns="""
    Returns
    -------
    args: Dict[str, Any]
        The list of global parameters and their values
    """)
def get_config():
    # The library writes a pointer to its JSON text into this out-parameter.
    raw_json = ctypes.c_char_p()
    status = _LIB.XGBGetGlobalConfig(ctypes.byref(raw_json))
    _check_call(status)
    return json.loads(py_str(raw_json.value))
@contextmanager
@config_doc(header="""
    Context manager for global XGBoost configuration.
    """,
            parameters="""
    Parameters
    ----------
    new_config: Dict[str, Any]
        Keyword arguments representing the parameters and their values
    """,
            extra_note="""
    .. note::

        All settings, not just those presently modified, will be returned to their
        previous values when the context manager is exited. This is not thread-safe.
    """,
            see_also="""
    See Also
    --------
    set_config: Set global XGBoost configuration
    get_config: Get current values of the global configuration
    """)
def config_context(**new_config):
    # Snapshot the complete configuration so that every setting can be restored.
    old_config = get_config().copy()
    try:
        # Applied inside the try block: the library updates the parameters it
        # recognises before reporting unknown ones, so a failing call can leave
        # a partial update behind that must still be rolled back.
        set_config(**new_config)
        yield
    finally:
        set_config(**old_config)

View File

@ -22,7 +22,7 @@ from typing import List
import numpy
from . import rabit
from . import rabit, config
from .compat import LazyLoader
from .compat import sparse, scipy_sparse
@ -639,6 +639,7 @@ async def _train_async(client,
workers = list(_get_workers_from_data(dtrain, evals))
_rabit_args = await _get_rabit_args(len(workers), client)
_global_config = config.get_config()
def dispatched_train(worker_addr, rabit_args, dtrain_ref, dtrain_idt, evals_ref):
'''Perform training on a single worker. A local function prevents pickling.
@ -646,7 +647,7 @@ async def _train_async(client,
'''
LOGGER.info('Training on %s', str(worker_addr))
worker = distributed.get_worker()
with RabitContext(rabit_args):
with RabitContext(rabit_args), config.config_context(**_global_config):
local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref)
local_evals = []
if evals_ref:
@ -770,18 +771,21 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame],
type(data)))
_global_config = config.get_config()
def mapped_predict(partition, is_df):
worker = distributed.get_worker()
booster.set_param({'nthread': worker.nthreads})
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
predt = booster.predict(m, validate_features=False, **kwargs)
if is_df:
if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
import cudf # pylint: disable=import-error
predt = cudf.DataFrame(predt, columns=['prediction'])
else:
predt = DataFrame(predt, columns=['prediction'])
return predt
with config.config_context(**_global_config):
booster.set_param({'nthread': worker.nthreads})
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
predt = booster.predict(m, validate_features=False, **kwargs)
if is_df:
if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
import cudf # pylint: disable=import-error
predt = cudf.DataFrame(predt, columns=['prediction'])
else:
predt = DataFrame(predt, columns=['prediction'])
return predt
# Predict on dask collection directly.
if isinstance(data, (da.Array, dd.DataFrame)):
return await _direct_predict_impl(client, data, mapped_predict)
@ -797,31 +801,32 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
def dispatched_predict(worker_id, list_of_orders, list_of_parts):
'''Perform prediction on each worker.'''
LOGGER.info('Predicting on %d', worker_id)
worker = distributed.get_worker()
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
predictions = []
with config.config_context(**_global_config):
worker = distributed.get_worker()
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
predictions = []
booster.set_param({'nthread': worker.nthreads})
for i, parts in enumerate(list_of_parts):
(data, _, _, base_margin, _, _) = parts
order = list_of_orders[i]
local_part = DMatrix(
data,
base_margin=base_margin,
feature_names=feature_names,
feature_types=feature_types,
missing=missing,
nthread=worker.nthreads
)
predt = booster.predict(
data=local_part,
validate_features=local_part.num_row() != 0,
**kwargs)
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
ret = ((dask.delayed(predt), columns), order)
predictions.append(ret)
booster.set_param({'nthread': worker.nthreads})
for i, parts in enumerate(list_of_parts):
(data, _, _, base_margin, _, _) = parts
order = list_of_orders[i]
local_part = DMatrix(
data,
base_margin=base_margin,
feature_names=feature_names,
feature_types=feature_types,
missing=missing,
nthread=worker.nthreads
)
predt = booster.predict(
data=local_part,
validate_features=local_part.num_row() != 0,
**kwargs)
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
ret = ((dask.delayed(predt), columns), order)
predictions.append(ret)
return predictions
return predictions
def dispatched_get_shape(worker_id, list_of_orders, list_of_parts):
'''Get shape of data in each worker.'''

View File

@ -18,9 +18,11 @@
#include "xgboost/logging.h"
#include "xgboost/version_config.h"
#include "xgboost/json.h"
#include "xgboost/global_config.h"
#include "c_api_error.h"
#include "../common/io.h"
#include "../common/charconv.h"
#include "../data/adapter.h"
#include "../data/simple_dmatrix.h"
#include "../data/proxy_dmatrix.h"
@ -46,6 +48,91 @@ XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*)) {
API_END();
}
// Update the global configuration from a flat JSON object.
//
// json_str: NUL-terminated JSON text whose top-level value is an object mapping
//           parameter names to scalar values.
// Returns 0 on success and -1 on failure (through API_BEGIN / API_END); unknown
// parameter names are reported as a failure that lists the offending names.
XGB_DLL int XGBSetGlobalConfig(const char* json_str) {
  API_BEGIN();
  CHECK(json_str != nullptr) << "Invalid pointer argument: json_str must not be null.";
  std::string str{json_str};
  Json config{Json::Load(StringView{str.data(), str.size()})};

  // FromJson() reads every value as a JSON string, so scalar values of other JSON
  // types are converted to their textual form first.
  for (auto& items : get<Object>(config)) {
    switch (items.second.GetValue().Type()) {
      case xgboost::Value::ValueKind::kInteger: {
        items.second = String{std::to_string(get<Integer const>(items.second))};
        break;
      }
      case xgboost::Value::ValueKind::kBoolean: {
        if (get<Boolean const>(items.second)) {
          items.second = String{"true"};
        } else {
          items.second = String{"false"};
        }
        break;
      }
      case xgboost::Value::ValueKind::kNumber: {
        auto n = get<Number const>(items.second);
        char chars[NumericLimits<float>::kToCharsSize];
        auto ret = to_chars(chars, chars + sizeof(chars), n);
        CHECK(ret.ec == std::errc());
        // The end of the written text is reported through ret.ptr; build the string
        // from that range instead of relying on a terminating NUL in the buffer.
        items.second = String{std::string(chars, ret.ptr)};
        break;
      }
      default:
        // Strings are already in the expected form; other kinds are left untouched.
        break;
    }
  }

  auto unknown = FromJson(config, GlobalConfigThreadLocalStore::Get());
  if (!unknown.empty()) {
    std::stringstream ss;
    ss << "Unknown global parameters: { ";
    size_t i = 0;
    for (auto const& item : unknown) {
      ss << item.first;
      i++;
      if (i != unknown.size()) {
        ss << ", ";
      }
    }
    LOG(FATAL) << ss.str() << " }";
  }
  API_END();
}
using GlobalConfigAPIThreadLocalStore = dmlc::ThreadLocalStore<XGBAPIThreadLocalEntry>;
// Serialise the current global configuration to JSON.
//
// json_str: out-parameter that receives a pointer to the JSON text. The text is
//           stored in the ret_str member of the entry obtained from
//           GlobalConfigAPIThreadLocalStore, so the caller does not own it.
// NOTE(review): the pointer presumably stays valid only until ret_str is next
// overwritten for this store -- confirm before caching it.
XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
API_BEGIN();
auto const& global_config = *GlobalConfigThreadLocalStore::Get();
Json config {ToJson(global_config)};
auto const* mgr = global_config.__MANAGER__();
// Every value is read back as a string below; look up the registered field type of
// each parameter and replace the string with a typed JSON value.
for (auto& item : get<Object>(config)) {
auto const &str = get<String const>(item.second);
auto const &name = item.first;
auto e = mgr->Find(name);
CHECK(e);
// Integer-typed fields (signed or unsigned, 32 or 64 bit) become JSON integers.
if (dynamic_cast<dmlc::parameter::FieldEntry<int32_t> const*>(e) ||
dynamic_cast<dmlc::parameter::FieldEntry<int64_t> const*>(e) ||
dynamic_cast<dmlc::parameter::FieldEntry<uint32_t> const*>(e) ||
dynamic_cast<dmlc::parameter::FieldEntry<uint64_t> const*>(e)) {
auto i = std::strtoimax(str.data(), nullptr, 10);
CHECK_LE(i, static_cast<intmax_t>(std::numeric_limits<int64_t>::max()));
item.second = Integer(static_cast<int64_t>(i));
// Floating-point fields become JSON numbers; note that both float and double
// fields are parsed into a float here.
} else if (dynamic_cast<dmlc::parameter::FieldEntry<float> const *>(e) ||
dynamic_cast<dmlc::parameter::FieldEntry<double> const *>(e)) {
float f;
auto ec = from_chars(str.data(), str.data() + str.size(), f).ec;
CHECK(ec == std::errc());
item.second = Number(f);
// Boolean fields: any text other than "0" is treated as true.
} else if (dynamic_cast<dmlc::parameter::FieldEntry<bool> const *>(e)) {
item.second = Boolean(str != "0");
}
// Fields of any other type keep their string representation.
}
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
Json::Dump(config, &local.ret_str);
*json_str = local.ret_str.c_str();
API_END();
}
XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
int silent,
DMatrixHandle *out) {

14
src/global_config.cc Normal file
View File

@ -0,0 +1,14 @@
/*!
* Copyright 2020 by Contributors
* \file global_config.cc
* \brief Global configuration for XGBoost
* \author Hyunsu Cho
*/
#include <dmlc/thread_local.h>
#include "xgboost/global_config.h"
#include "xgboost/json.h"
namespace xgboost {
DMLC_REGISTER_PARAMETER(GlobalConfiguration);
} // namespace xgboost

View File

@ -490,6 +490,12 @@ class LearnerConfiguration : public Learner {
// Extract all parameters
std::vector<std::string> keys;
// First global parameters
Json const global_config{ToJson(*GlobalConfigThreadLocalStore::Get())};
for (auto const& items : get<Object const>(global_config)) {
keys.emplace_back(items.first);
}
// Parameters in various xgboost components.
while (!stack.empty()) {
auto j_obj = stack.top();
stack.pop();

View File

@ -11,12 +11,13 @@
#include "xgboost/parameter.h"
#include "xgboost/logging.h"
#include "xgboost/json.h"
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
// Override logging mechanism for non-R interfaces
void dmlc::CustomLogMessage::Log(const std::string& msg) {
const xgboost::LogCallbackRegistry* registry
= xgboost::LogCallbackRegistryStore::Get();
const xgboost::LogCallbackRegistry *registry =
xgboost::LogCallbackRegistryStore::Get();
auto callback = registry->Get();
callback(msg.c_str());
}
@ -40,35 +41,15 @@ TrackerLogger::~TrackerLogger() {
namespace xgboost {
DMLC_REGISTER_PARAMETER(ConsoleLoggerParam);
ConsoleLogger::LogVerbosity ConsoleLogger::global_verbosity_ =
ConsoleLogger::DefaultVerbosity();
ConsoleLoggerParam ConsoleLogger::param_ = ConsoleLoggerParam();
bool ConsoleLogger::ShouldLog(LogVerbosity verbosity) {
return verbosity <= global_verbosity_ || verbosity == LV::kIgnore;
return static_cast<int>(verbosity) <=
(GlobalConfigThreadLocalStore::Get()->verbosity) ||
verbosity == LV::kIgnore;
}
void ConsoleLogger::Configure(Args const& args) {
param_.UpdateAllowUnknown(args);
switch (param_.verbosity) {
case 0:
global_verbosity_ = LogVerbosity::kSilent;
break;
case 1:
global_verbosity_ = LogVerbosity::kWarning;
break;
case 2:
global_verbosity_ = LogVerbosity::kInfo;
break;
case 3:
global_verbosity_ = LogVerbosity::kDebug;
default:
// global verbosity doesn't require kIgnore
break;
}
auto& param = *GlobalConfigThreadLocalStore::Get();
param.UpdateAllowUnknown(args);
}
ConsoleLogger::LogVerbosity ConsoleLogger::DefaultVerbosity() {
@ -76,7 +57,25 @@ ConsoleLogger::LogVerbosity ConsoleLogger::DefaultVerbosity() {
}
ConsoleLogger::LogVerbosity ConsoleLogger::GlobalVerbosity() {
return global_verbosity_;
LogVerbosity global_verbosity { LogVerbosity::kWarning };
switch (GlobalConfigThreadLocalStore::Get()->verbosity) {
case 0:
global_verbosity = LogVerbosity::kSilent;
break;
case 1:
global_verbosity = LogVerbosity::kWarning;
break;
case 2:
global_verbosity = LogVerbosity::kInfo;
break;
case 3:
global_verbosity = LogVerbosity::kDebug;
default:
// global verbosity doesn't require kIgnore
break;
}
return global_verbosity;
}
ConsoleLogger::ConsoleLogger(LogVerbosity cur_verb) :

View File

@ -212,4 +212,50 @@ TEST(CAPI, Exception) {
// Not null
ASSERT_TRUE(error);
}
// Exercises XGBSetGlobalConfig() / XGBGetGlobalConfig() through the C API:
// a valid update, an unknown parameter, and a mix of unknown and known parameters.
TEST(CAPI, XGBGlobalConfig) {
int ret;
// Case 1: a known parameter is accepted and reads back as a JSON integer.
{
const char *config_str = R"json(
{
"verbosity": 0
}
)json";
ret = XGBSetGlobalConfig(config_str);
ASSERT_EQ(ret, 0);
const char *updated_config_cstr;
ret = XGBGetGlobalConfig(&updated_config_cstr);
ASSERT_EQ(ret, 0);
std::string updated_config_str{updated_config_cstr};
auto updated_config =
Json::Load({updated_config_str.data(), updated_config_str.size()});
ASSERT_EQ(get<Integer>(updated_config["verbosity"]), 0);
}
// Case 2: an unknown parameter makes the call fail and is named in the error.
{
const char *config_str = R"json(
{
"foo": 0
}
)json";
ret = XGBSetGlobalConfig(config_str);
ASSERT_EQ(ret , -1);
auto err = std::string{XGBGetLastError()};
ASSERT_NE(err.find("foo"), std::string::npos);
}
// Case 3: with both an unknown and a known parameter, only the unknown one is
// reported in the error message.
{
const char *config_str = R"json(
{
"foo": 0,
"verbosity": 0
}
)json";
ret = XGBSetGlobalConfig(config_str);
ASSERT_EQ(ret , -1);
auto err = std::string{XGBGetLastError()};
ASSERT_NE(err.find("foo"), std::string::npos);
ASSERT_EQ(err.find("verbosity"), std::string::npos);
}
}
} // namespace xgboost

View File

@ -0,0 +1,22 @@
#include <gtest/gtest.h>
#include <xgboost/json.h>
#include <xgboost/logging.h>
#include <xgboost/global_config.h>
namespace xgboost {
// Checks that updating the global configuration object directly (without the C API)
// is reflected by the logger and by ToJson().
TEST(GlobalConfiguration, Verbosity) {
// Configure verbosity via global configuration
Json config{JsonObject()};
// FromJson() expects string values, hence "0" rather than an integer.
config["verbosity"] = String("0");
auto& global_config = *GlobalConfigThreadLocalStore::Get();
FromJson(config, &global_config);
// Now verbosity should be updated
EXPECT_EQ(ConsoleLogger::GlobalVerbosity(), ConsoleLogger::LogVerbosity::kSilent);
// Sanity check: the value just set differs from the default, so the assertion
// above cannot pass by accident.
EXPECT_NE(ConsoleLogger::LogVerbosity::kSilent, ConsoleLogger::DefaultVerbosity());
// GetConfig() should also return updated verbosity
Json current_config { ToJson(*GlobalConfigThreadLocalStore::Get()) };
EXPECT_EQ(get<String>(current_config["verbosity"]), "0");
}
} // namespace xgboost

View File

@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
import xgboost as xgb
import pytest
import testing as tm
@pytest.mark.parametrize('verbosity_level', [0, 1, 2, 3])
def test_global_config_verbosity(verbosity_level):
    """The context manager applies the requested verbosity and restores the old one."""
    def get_current_verbosity():
        current = xgb.get_config()
        return current['verbosity']

    verbosity_before = get_current_verbosity()
    with xgb.config_context(verbosity=verbosity_level):
        verbosity_inside = get_current_verbosity()
        assert verbosity_inside == verbosity_level
    # Leaving the context must bring back the value seen before entering it.
    assert verbosity_before == get_current_verbosity()

View File

@ -637,6 +637,46 @@ def test_aft_survival():
class TestWithDask:
def test_global_config(self, client):
    """Global configuration set on the client must be visible inside dask workers.

    The callback records the verbosity it observes in two scratch files, which are
    checked after training. The verbosity change is scoped with ``config_context``
    and the scratch files are removed in a ``finally`` block, so a failing assertion
    neither leaks files nor leaves the global configuration modified for later tests.
    """
    X, y = generate_array()
    before_fname = './before_training-test_global_config'
    after_fname = './after_training-test_global_config'

    class TestCallback(xgb.callback.TrainingCallback):
        def write_file(self, fname):
            with open(fname, 'w') as fd:
                fd.write(str(xgb.config.get_config()['verbosity']))

        def before_training(self, model):
            self.write_file(before_fname)
            assert xgb.config.get_config()['verbosity'] == 0
            return model

        def after_training(self, model):
            assert xgb.config.get_config()['verbosity'] == 0
            return model

        def before_iteration(self, model, epoch, evals_log):
            assert xgb.config.get_config()['verbosity'] == 0
            return False

        def after_iteration(self, model, epoch, evals_log):
            self.write_file(after_fname)
            assert xgb.config.get_config()['verbosity'] == 0
            return False

    try:
        with xgb.config.config_context(verbosity=0):
            dtrain = DaskDMatrix(client, X, y)
            xgb.dask.train(client, {}, dtrain, num_boost_round=4,
                           callbacks=[TestCallback()])['booster']

        with open(before_fname, 'r') as before, open(after_fname, 'r') as after:
            assert before.read() == '0'
            assert after.read() == '0'
    finally:
        # Only remove what was actually created, so a cleanup error cannot mask
        # the original failure.
        for fname in (before_fname, after_fname):
            if os.path.exists(fname):
                os.remove(fname)
def run_updater_test(self, client, params, num_rounds, dataset,
tree_method):
params['tree_method'] = tree_method