Add global configuration (#6414)
* Add management functions for global configuration: XGBSetGlobalConfig(), XGBGetGlobalConfig(). * Add Python interface: set_config(), get_config(), and config_context(). * Add unit tests for Python * Add R interface: xgb.set.config(), xgb.get.config() * Add unit tests for R Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
parent
c2ba4fb957
commit
fb56da5e8b
@ -53,7 +53,6 @@ Suggests:
|
||||
testthat,
|
||||
lintr,
|
||||
igraph (>= 1.0.1),
|
||||
jsonlite,
|
||||
float,
|
||||
crayon,
|
||||
titanic
|
||||
@ -64,5 +63,6 @@ Imports:
|
||||
methods,
|
||||
data.table (>= 1.9.6),
|
||||
magrittr (>= 1.5),
|
||||
jsonlite (>= 1.0),
|
||||
RoxygenNote: 7.1.1
|
||||
SystemRequirements: GNU make, C++14
|
||||
|
||||
@ -36,6 +36,7 @@ export(xgb.create.features)
|
||||
export(xgb.cv)
|
||||
export(xgb.dump)
|
||||
export(xgb.gblinear.history)
|
||||
export(xgb.get.config)
|
||||
export(xgb.ggplot.deepness)
|
||||
export(xgb.ggplot.importance)
|
||||
export(xgb.ggplot.shap.summary)
|
||||
@ -52,6 +53,7 @@ export(xgb.plot.tree)
|
||||
export(xgb.save)
|
||||
export(xgb.save.raw)
|
||||
export(xgb.serialize)
|
||||
export(xgb.set.config)
|
||||
export(xgb.train)
|
||||
export(xgb.unserialize)
|
||||
export(xgboost)
|
||||
@ -78,6 +80,8 @@ importFrom(graphics,lines)
|
||||
importFrom(graphics,par)
|
||||
importFrom(graphics,points)
|
||||
importFrom(graphics,title)
|
||||
importFrom(jsonlite,fromJSON)
|
||||
importFrom(jsonlite,toJSON)
|
||||
importFrom(magrittr,"%>%")
|
||||
importFrom(stats,median)
|
||||
importFrom(stats,predict)
|
||||
|
||||
38
R-package/R/xgb.config.R
Normal file
38
R-package/R/xgb.config.R
Normal file
@ -0,0 +1,38 @@
|
||||
#' Global configuration consists of a collection of parameters that can be applied in the global
|
||||
#' scope. See \url{https://xgboost.readthedocs.io/en/stable/parameter.html} for the full list of
|
||||
#' parameters supported in the global configuration. Use \code{xgb.set.config} to update the
|
||||
#' values of one or more global-scope parameters. Use \code{xgb.get.config} to fetch the current
|
||||
#' values of all global-scope parameters (listed in
|
||||
#' \url{https://xgboost.readthedocs.io/en/stable/parameter.html}).
|
||||
#'
|
||||
#' @rdname xgbConfig
|
||||
#' @title Set and get global configuration
|
||||
#' @name xgb.set.config, xgb.get.config
|
||||
#' @export xgb.set.config xgb.get.config
|
||||
#' @param ... List of parameters to be set, as keyword arguments
|
||||
#' @return
|
||||
#' \code{xgb.set.config} returns \code{TRUE} to signal success. \code{xgb.get.config} returns
|
||||
#' a list containing all global-scope parameters and their values.
|
||||
#'
|
||||
#' @examples
|
||||
#' # Set verbosity level to silent (0)
|
||||
#' xgb.set.config(verbosity = 0)
|
||||
#' # Now global verbosity level is 0
|
||||
#' config <- xgb.get.config()
|
||||
#' print(config$verbosity)
|
||||
#' # Set verbosity level to warning (1)
|
||||
#' xgb.set.config(verbosity = 1)
|
||||
#' # Now global verbosity level is 1
|
||||
#' config <- xgb.get.config()
|
||||
#' print(config$verbosity)
|
||||
xgb.set.config <- function(...) {
  # Collect the keyword arguments; each name is a global parameter and each
  # value is the setting to apply.
  new_config <- list(...)

  # jsonlite serialises an empty or unnamed list as a JSON array rather than an
  # object, which the native routine cannot read as key-value pairs. Handle
  # both situations here instead of surfacing an obscure native error.
  if (length(new_config) == 0) {
    return(TRUE)
  }
  config_names <- names(new_config)
  if (is.null(config_names) || any(is.na(config_names) | config_names == "")) {
    stop("All arguments to xgb.set.config() must be named, e.g. verbosity = 0",
         call. = FALSE)
  }

  # Scalars are unboxed so that `verbosity = 0` becomes {"verbosity":0}.
  .Call(XGBSetGlobalConfig_R, jsonlite::toJSON(new_config, auto_unbox = TRUE))
  TRUE
}
|
||||
|
||||
#' @rdname xgbConfig
|
||||
xgb.get.config <- function() {
  # The native routine returns the whole global configuration as one JSON
  # string; decode it into an R list before handing it back.
  jsonlite::fromJSON(.Call(XGBGetGlobalConfig_R))
}
|
||||
@ -91,6 +91,8 @@ NULL
|
||||
#' @importFrom data.table setkeyv
|
||||
#' @importFrom data.table setnames
|
||||
#' @importFrom magrittr %>%
|
||||
#' @importFrom jsonlite fromJSON
|
||||
#' @importFrom jsonlite toJSON
|
||||
#' @importFrom utils object.size str tail
|
||||
#' @importFrom stats predict
|
||||
#' @importFrom stats median
|
||||
|
||||
39
R-package/man/xgbConfig.Rd
Normal file
39
R-package/man/xgbConfig.Rd
Normal file
@ -0,0 +1,39 @@
|
||||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/xgb.config.R
|
||||
\name{xgb.set.config, xgb.get.config}
|
||||
\alias{xgb.set.config, xgb.get.config}
|
||||
\alias{xgb.set.config}
|
||||
\alias{xgb.get.config}
|
||||
\title{Set and get global configuration}
|
||||
\usage{
|
||||
xgb.set.config(...)
|
||||
|
||||
xgb.get.config()
|
||||
}
|
||||
\arguments{
|
||||
\item{...}{List of parameters to be set, as keyword arguments}
|
||||
}
|
||||
\value{
|
||||
\code{xgb.set.config} returns \code{TRUE} to signal success. \code{xgb.get.config} returns
|
||||
a list containing all global-scope parameters and their values.
|
||||
}
|
||||
\description{
|
||||
Global configuration consists of a collection of parameters that can be applied in the global
|
||||
scope. See \url{https://xgboost.readthedocs.io/en/stable/parameter.html} for the full list of
|
||||
parameters supported in the global configuration. Use \code{xgb.set.config} to update the
|
||||
values of one or more global-scope parameters. Use \code{xgb.get.config} to fetch the current
|
||||
values of all global-scope parameters (listed in
|
||||
\url{https://xgboost.readthedocs.io/en/stable/parameter.html}).
|
||||
}
|
||||
\examples{
|
||||
# Set verbosity level to silent (0)
|
||||
xgb.set.config(verbosity = 0)
|
||||
# Now global verbosity level is 0
|
||||
config <- xgb.get.config()
|
||||
print(config$verbosity)
|
||||
# Set verbosity level to warning (1)
|
||||
xgb.set.config(verbosity = 1)
|
||||
# Now global verbosity level is 1
|
||||
config <- xgb.get.config()
|
||||
print(config$verbosity)
|
||||
}
|
||||
@ -43,6 +43,8 @@ extern SEXP XGDMatrixNumRow_R(SEXP);
|
||||
extern SEXP XGDMatrixSaveBinary_R(SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixSetInfo_R(SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP);
|
||||
extern SEXP XGBSetGlobalConfig_R(SEXP);
|
||||
extern SEXP XGBGetGlobalConfig_R();
|
||||
|
||||
static const R_CallMethodDef CallEntries[] = {
|
||||
{"XGBoosterBoostOneIter_R", (DL_FUNC) &XGBoosterBoostOneIter_R, 4},
|
||||
@ -73,6 +75,8 @@ static const R_CallMethodDef CallEntries[] = {
|
||||
{"XGDMatrixSaveBinary_R", (DL_FUNC) &XGDMatrixSaveBinary_R, 3},
|
||||
{"XGDMatrixSetInfo_R", (DL_FUNC) &XGDMatrixSetInfo_R, 3},
|
||||
{"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 2},
|
||||
{"XGBSetGlobalConfig_R", (DL_FUNC) &XGBSetGlobalConfig_R, 1},
|
||||
{"XGBGetGlobalConfig_R", (DL_FUNC) &XGBGetGlobalConfig_R, 0},
|
||||
{NULL, NULL, 0}
|
||||
};
|
||||
|
||||
|
||||
@ -49,6 +49,21 @@ void _DMatrixFinalizer(SEXP ext) {
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
// Pass the JSON text held in `json_str` to XGBSetGlobalConfig.
// The function always hands R_NilValue back to R.
SEXP XGBSetGlobalConfig_R(SEXP json_str) {
  R_API_BEGIN();
  const char* config_cstr = CHAR(asChar(json_str));
  CHECK_CALL(XGBSetGlobalConfig(config_cstr));
  R_API_END();
  return R_NilValue;
}
|
||||
|
||||
// Fetch the global configuration from XGBGetGlobalConfig and return it to R
// as a character vector of length one containing the JSON text.
SEXP XGBGetGlobalConfig_R() {
  const char* config_json;
  R_API_BEGIN();
  CHECK_CALL(XGBGetGlobalConfig(&config_json));
  R_API_END();
  SEXP result = mkString(config_json);
  return result;
}
|
||||
|
||||
SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
|
||||
@ -21,6 +21,19 @@
|
||||
*/
|
||||
XGB_DLL SEXP XGCheckNullPtr_R(SEXP handle);
|
||||
|
||||
/*!
|
||||
* \brief Set global configuration
|
||||
* \param json_str a JSON string representing the list of key-value pairs
|
||||
* \return R_NilValue
|
||||
*/
|
||||
XGB_DLL SEXP XGBSetGlobalConfig_R(SEXP json_str);
|
||||
|
||||
/*!
|
||||
* \brief Get global configuration
|
||||
* \return JSON string
|
||||
*/
|
||||
XGB_DLL SEXP XGBGetGlobalConfig_R();
|
||||
|
||||
/*!
|
||||
* \brief load a data matrix
|
||||
* \param fname name of the content
|
||||
|
||||
@ -16,7 +16,7 @@ void CustomLogMessage::Log(const std::string& msg) {
|
||||
namespace xgboost {
|
||||
ConsoleLogger::~ConsoleLogger() {
|
||||
if (cur_verbosity_ == LogVerbosity::kIgnore ||
|
||||
cur_verbosity_ <= global_verbosity_) {
|
||||
cur_verbosity_ <= GlobalVerbosity()) {
|
||||
dmlc::CustomLogMessage::Log(log_stream_.str());
|
||||
}
|
||||
}
|
||||
|
||||
11
R-package/tests/testthat/test_config.R
Normal file
11
R-package/tests/testthat/test_config.R
Normal file
@ -0,0 +1,11 @@
|
||||
context('Test global configuration')
|
||||
|
||||
test_that('Global configuration works with verbosity', {
  # Remember the verbosity currently in effect so it can be reinstated below.
  original_level <- xgb.get.config()$verbosity

  # Every valid level must round-trip through the setter and the getter.
  for (level in c(0, 1, 2, 3)) {
    xgb.set.config(verbosity = level)
    reported_level <- xgb.get.config()$verbosity
    expect_equal(reported_level, level)
  }

  # Put the original level back and confirm it took effect.
  xgb.set.config(verbosity = original_level)
  restored_level <- xgb.get.config()$verbosity
  expect_equal(restored_level, original_level)
})
|
||||
@ -67,6 +67,7 @@
|
||||
// global
|
||||
#include "../src/learner.cc"
|
||||
#include "../src/logging.cc"
|
||||
#include "../src/global_config.cc"
|
||||
#include "../src/common/common.cc"
|
||||
#include "../src/common/random.cc"
|
||||
#include "../src/common/charconv.cc"
|
||||
|
||||
@ -16,6 +16,13 @@ Before running XGBoost, we must set three types of parameters: general parameter
|
||||
:backlinks: none
|
||||
:local:
|
||||
|
||||
********************
|
||||
Global Configuration
|
||||
********************
|
||||
The following parameters can be set in the global scope, using ``xgb.config_context()`` (Python) or ``xgb.set.config()`` (R).
|
||||
|
||||
* ``verbosity``: Verbosity of printing messages. Valid values are 0 (silent), 1 (warning), 2 (info), and 3 (debug).
|
||||
|
||||
******************
|
||||
General Parameters
|
||||
******************
|
||||
|
||||
@ -6,6 +6,14 @@ This page gives the Python API reference of xgboost, please also refer to Python
|
||||
:backlinks: none
|
||||
:local:
|
||||
|
||||
Global Configuration
|
||||
--------------------
|
||||
.. autofunction:: xgboost.config_context
|
||||
|
||||
.. autofunction:: xgboost.set_config
|
||||
|
||||
.. autofunction:: xgboost.get_config
|
||||
|
||||
Core Data Structure
|
||||
-------------------
|
||||
.. automodule:: xgboost.core
|
||||
|
||||
@ -63,6 +63,23 @@ XGB_DLL const char *XGBGetLastError(void);
|
||||
*/
|
||||
XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*));
|
||||
|
||||
/*!
|
||||
* \brief Set global configuration (collection of parameters that apply globally). This function
|
||||
* accepts the list of key-value pairs representing the global-scope parameters to be
|
||||
* configured. The list of key-value pairs is passed in as a JSON string.
|
||||
* \param json_str a JSON string representing the list of key-value pairs. The JSON object shall
|
||||
* be flat: no value can be a JSON object or an array.
|
||||
* \return 0 for success, -1 for failure
|
||||
*/
|
||||
XGB_DLL int XGBSetGlobalConfig(const char* json_str);
|
||||
|
||||
/*!
|
||||
* \brief Get current global configuration (collection of parameters that apply globally).
|
||||
* \param json_str pointer to receive the returned global configuration, represented as a JSON string.
|
||||
* \return 0 for success, -1 for failure
|
||||
*/
|
||||
XGB_DLL int XGBGetGlobalConfig(const char** json_str);
|
||||
|
||||
/*!
|
||||
* \brief load a data matrix
|
||||
* \param fname the name of the file
|
||||
|
||||
30
include/xgboost/global_config.h
Normal file
30
include/xgboost/global_config.h
Normal file
@ -0,0 +1,30 @@
|
||||
/*!
|
||||
* Copyright 2020 by Contributors
|
||||
* \file global_config.h
|
||||
* \brief Global configuration for XGBoost
|
||||
* \author Hyunsu Cho
|
||||
*/
|
||||
#ifndef XGBOOST_GLOBAL_CONFIG_H_
|
||||
#define XGBOOST_GLOBAL_CONFIG_H_
|
||||
|
||||
#include <xgboost/parameter.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
namespace xgboost {
|
||||
class Json;
|
||||
|
||||
/*!
 * \brief Parameters that apply to the whole library rather than to a single
 *        model. The only field declared here is the logging verbosity.
 */
struct GlobalConfiguration : public XGBoostParameter<GlobalConfiguration> {
  /*! \brief Logging verbosity; accepted range is [0, 3], default is 1. */
  int verbosity;
  DMLC_DECLARE_PARAMETER(GlobalConfiguration) {
    DMLC_DECLARE_FIELD(verbosity)
        .set_range(0, 3)
        .set_default(1)  // shows only warning
        // The description previously talked about a runtime breakdown, which
        // is unrelated to this field; describe the actual levels instead.
        .describe("Verbosity of printing messages. Valid values are 0 (silent), "
                  "1 (warning), 2 (info) and 3 (debug).");
  }
};
|
||||
|
||||
using GlobalConfigThreadLocalStore = dmlc::ThreadLocalStore<GlobalConfiguration>;
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_GLOBAL_CONFIG_H_
|
||||
@ -557,7 +557,6 @@ using String = JsonString;
|
||||
using Null = JsonNull;
|
||||
|
||||
// Utils tailored for XGBoost.
|
||||
|
||||
template <typename Parameter>
|
||||
Object ToJson(Parameter const& param) {
|
||||
Object obj;
|
||||
@ -568,13 +567,13 @@ Object ToJson(Parameter const& param) {
|
||||
}
|
||||
|
||||
template <typename Parameter>
|
||||
void FromJson(Json const& obj, Parameter* param) {
|
||||
Args FromJson(Json const& obj, Parameter* param) {
|
||||
auto const& j_param = get<Object const>(obj);
|
||||
std::map<std::string, std::string> m;
|
||||
for (auto const& kv : j_param) {
|
||||
m[kv.first] = get<String const>(kv.second);
|
||||
}
|
||||
param->UpdateAllowUnknown(m);
|
||||
return param->UpdateAllowUnknown(m);
|
||||
}
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_JSON_H_
|
||||
|
||||
@ -45,7 +45,6 @@ struct XGBAPIThreadLocalEntry {
|
||||
PredictionCacheEntry prediction_entry;
|
||||
};
|
||||
|
||||
|
||||
/*!
|
||||
* \brief Learner class that does training and prediction.
|
||||
* This is the user facing module of xgboost training.
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/parameter.h>
|
||||
#include <xgboost/global_config.h>
|
||||
|
||||
#include <sstream>
|
||||
#include <map>
|
||||
@ -35,19 +36,6 @@ class BaseLogger {
|
||||
std::ostringstream log_stream_;
|
||||
};
|
||||
|
||||
// Parsing both silent and debug_verbose is to provide backward compatibility.
|
||||
struct ConsoleLoggerParam : public XGBoostParameter<ConsoleLoggerParam> {
|
||||
int verbosity;
|
||||
|
||||
DMLC_DECLARE_PARAMETER(ConsoleLoggerParam) {
|
||||
DMLC_DECLARE_FIELD(verbosity)
|
||||
.set_range(0, 3)
|
||||
.set_default(1) // shows only warning
|
||||
.describe("Flag to print out detailed breakdown of runtime.");
|
||||
DMLC_DECLARE_ALIAS(verbosity, debug_verbose);
|
||||
}
|
||||
};
|
||||
|
||||
class ConsoleLogger : public BaseLogger {
|
||||
public:
|
||||
enum class LogVerbosity {
|
||||
@ -60,9 +48,6 @@ class ConsoleLogger : public BaseLogger {
|
||||
using LV = LogVerbosity;
|
||||
|
||||
private:
|
||||
static LogVerbosity global_verbosity_;
|
||||
static ConsoleLoggerParam param_;
|
||||
|
||||
LogVerbosity cur_verbosity_;
|
||||
|
||||
public:
|
||||
|
||||
@ -17,6 +17,7 @@ try:
|
||||
from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
|
||||
from .sklearn import XGBRFClassifier, XGBRFRegressor
|
||||
from .plotting import plot_importance, plot_tree, to_graphviz
|
||||
from .config import set_config, get_config, config_context
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@ -29,4 +30,5 @@ __all__ = ['DMatrix', 'DeviceQuantileDMatrix', 'Booster',
|
||||
'RabitTracker',
|
||||
'XGBModel', 'XGBClassifier', 'XGBRegressor', 'XGBRanker',
|
||||
'XGBRFClassifier', 'XGBRFRegressor',
|
||||
'plot_importance', 'plot_tree', 'to_graphviz', 'dask']
|
||||
'plot_importance', 'plot_tree', 'to_graphviz', 'dask',
|
||||
'set_config', 'get_config', 'config_context']
|
||||
|
||||
142
python-package/xgboost/config.py
Normal file
142
python-package/xgboost/config.py
Normal file
@ -0,0 +1,142 @@
|
||||
# pylint: disable=missing-function-docstring
|
||||
"""Global configuration for XGBoost"""
|
||||
import ctypes
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
|
||||
from .core import _LIB, _check_call, c_str, py_str
|
||||
|
||||
|
||||
def config_doc(*, header=None, extra_note=None, parameters=None, returns=None,
               see_also=None):
    """Decorator to format docstring for config functions.

    The decorated function receives a docstring assembled from a shared
    template, the sections passed in here, and a common usage example.
    Sections left as ``None`` are rendered as empty strings.

    Parameters
    ----------
    header: str
        An introduction to the function
    extra_note: str
        Additional notes
    parameters: str
        Parameters of the function
    returns: str
        Return value
    see_also: str
        Related functions
    """

    doc_template = """
    {header}

    Global configuration consists of a collection of parameters that can be applied in the
    global scope. See https://xgboost.readthedocs.io/en/stable/parameter.html for the full
    list of parameters supported in the global configuration.

    {extra_note}

    .. versionadded:: 1.3.0
    """

    # verbosity=2 is the "info" level; debug output only starts at level 3, so
    # the example comment must not promise debugging messages.
    common_example = """
    Example
    -------

    .. code-block:: python

        import xgboost as xgb

        # Show warnings and informational messages (level 3 would add debug output)
        xgb.set_config(verbosity=2)

        # Get current value of global configuration
        # This is a dict containing all parameters in the global configuration,
        # including 'verbosity'
        config = xgb.get_config()
        assert config['verbosity'] == 2

        # Example of using the context manager xgb.config_context().
        # The context manager will restore the previous value of the global
        # configuration upon exiting.
        with xgb.config_context(verbosity=0):
            # Suppress warning caused by model generated with XGBoost version < 1.0.0
            bst = xgb.Booster(model_file='./old_model.bin')
        assert xgb.get_config()['verbosity'] == 2  # old value restored
    """

    def none_to_str(value):
        return '' if value is None else value

    def config_doc_decorator(func):
        # Build the docstring first so that @wraps copies it onto the wrapper.
        func.__doc__ = (doc_template.format(header=none_to_str(header),
                                            extra_note=none_to_str(extra_note))
                        + none_to_str(parameters) + none_to_str(returns)
                        + none_to_str(common_example) + none_to_str(see_also))

        @wraps(func)
        def wrap(*args, **kwargs):
            return func(*args, **kwargs)
        return wrap
    return config_doc_decorator
|
||||
|
||||
|
||||
@config_doc(header="""
    Set global configuration.
    """,
            parameters="""
    Parameters
    ----------
    new_config: Dict[str, Any]
        Keyword arguments representing the parameters and their values
    """)
def set_config(**new_config):
    # The C API takes the parameters as a single JSON object.
    status = _LIB.XGBSetGlobalConfig(c_str(json.dumps(new_config)))
    _check_call(status)
|
||||
|
||||
|
||||
@config_doc(header="""
    Get current values of the global configuration.
    """,
            returns="""
    Returns
    -------
    args: Dict[str, Any]
        The list of global parameters and their values
    """)
def get_config():
    # The library writes a pointer to a JSON string into `out_json`;
    # decode that string into a dict.
    out_json = ctypes.c_char_p()
    status = _LIB.XGBGetGlobalConfig(ctypes.byref(out_json))
    _check_call(status)
    return json.loads(py_str(out_json.value))
|
||||
|
||||
|
||||
@contextmanager
@config_doc(header="""
    Context manager for global XGBoost configuration.
    """,
            parameters="""
    Parameters
    ----------
    new_config: Dict[str, Any]
        Keyword arguments representing the parameters and their values
    """,
            extra_note="""
    .. note::

        All settings, not just those presently modified, will be returned to their
        previous values when the context manager is exited. This is not thread-safe.
    """,
            see_also="""
    See Also
    --------
    set_config: Set global XGBoost configuration
    get_config: Get current values of the global configuration
    """)
def config_context(**new_config):
    # Snapshot the complete configuration before touching anything.
    old_config = get_config().copy()

    try:
        # set_config() may raise (e.g. on an unknown key), possibly after some
        # of the requested values were already applied. Running it inside the
        # try block guarantees the snapshot is restored in that case too.
        set_config(**new_config)
        yield
    finally:
        set_config(**old_config)
|
||||
@ -22,7 +22,7 @@ from typing import List
|
||||
|
||||
import numpy
|
||||
|
||||
from . import rabit
|
||||
from . import rabit, config
|
||||
|
||||
from .compat import LazyLoader
|
||||
from .compat import sparse, scipy_sparse
|
||||
@ -639,6 +639,7 @@ async def _train_async(client,
|
||||
|
||||
workers = list(_get_workers_from_data(dtrain, evals))
|
||||
_rabit_args = await _get_rabit_args(len(workers), client)
|
||||
_global_config = config.get_config()
|
||||
|
||||
def dispatched_train(worker_addr, rabit_args, dtrain_ref, dtrain_idt, evals_ref):
|
||||
'''Perform training on a single worker. A local function prevents pickling.
|
||||
@ -646,7 +647,7 @@ async def _train_async(client,
|
||||
'''
|
||||
LOGGER.info('Training on %s', str(worker_addr))
|
||||
worker = distributed.get_worker()
|
||||
with RabitContext(rabit_args):
|
||||
with RabitContext(rabit_args), config.config_context(**_global_config):
|
||||
local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref)
|
||||
local_evals = []
|
||||
if evals_ref:
|
||||
@ -770,8 +771,11 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
|
||||
raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame],
|
||||
type(data)))
|
||||
|
||||
_global_config = config.get_config()
|
||||
|
||||
def mapped_predict(partition, is_df):
|
||||
worker = distributed.get_worker()
|
||||
with config.config_context(**_global_config):
|
||||
booster.set_param({'nthread': worker.nthreads})
|
||||
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
|
||||
predt = booster.predict(m, validate_features=False, **kwargs)
|
||||
@ -797,6 +801,7 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
|
||||
def dispatched_predict(worker_id, list_of_orders, list_of_parts):
|
||||
'''Perform prediction on each worker.'''
|
||||
LOGGER.info('Predicting on %d', worker_id)
|
||||
with config.config_context(**_global_config):
|
||||
worker = distributed.get_worker()
|
||||
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
|
||||
predictions = []
|
||||
|
||||
@ -18,9 +18,11 @@
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/version_config.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/global_config.h"
|
||||
|
||||
#include "c_api_error.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/charconv.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "../data/simple_dmatrix.h"
|
||||
#include "../data/proxy_dmatrix.h"
|
||||
@ -46,6 +48,91 @@ XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*)) {
|
||||
API_END();
|
||||
}
|
||||
|
||||
// Update the global configuration from a flat JSON object.
// Returns 0 on success and -1 on failure (via API_BEGIN/API_END), including
// when the object contains keys that are not global parameters.
XGB_DLL int XGBSetGlobalConfig(const char* json_str) {
  API_BEGIN();
  // Constructing std::string from a null pointer is undefined behaviour.
  CHECK(json_str) << "Invalid pointer argument: json_str";
  std::string str{json_str};
  Json config{Json::Load(StringView{str.data(), str.size()})};
  // FromJson reads every value as a string, so scalar JSON values are
  // rewritten as their textual form first; other kinds are left untouched.
  for (auto& items : get<Object>(config)) {
    switch (items.second.GetValue().Type()) {
      case xgboost::Value::ValueKind::kInteger: {
        items.second = String{std::to_string(get<Integer const>(items.second))};
        break;
      }
      case xgboost::Value::ValueKind::kBoolean: {
        if (get<Boolean const>(items.second)) {
          items.second = String{"true"};
        } else {
          items.second = String{"false"};
        }
        break;
      }
      case xgboost::Value::ValueKind::kNumber: {
        auto n = get<Number const>(items.second);
        char chars[NumericLimits<float>::kToCharsSize];
        auto result = to_chars(chars, chars + sizeof(chars), n);
        CHECK(result.ec == std::errc());
        // to_chars does not write a terminating NUL, so build the string from
        // exactly the range it produced instead of scanning for a terminator.
        items.second = String{std::string(chars, result.ptr)};
        break;
      }
      default:
        break;
    }
  }
  auto unknown = FromJson(config, GlobalConfigThreadLocalStore::Get());
  if (!unknown.empty()) {
    // Report every unrecognised key in a single error message.
    std::stringstream ss;
    ss << "Unknown global parameters: { ";
    size_t i = 0;
    for (auto const& item : unknown) {
      ss << item.first;
      i++;
      if (i != unknown.size()) {
        ss << ", ";
      }
    }
    LOG(FATAL) << ss.str() << " }";
  }
  API_END();
}
|
||||
|
||||
using GlobalConfigAPIThreadLocalStore = dmlc::ThreadLocalStore<XGBAPIThreadLocalEntry>;
|
||||
|
||||
// Serialise the current global configuration to JSON and expose it through
// `json_str`. The text is stored in the thread-local entry's `ret_str`, and
// the returned pointer refers to that buffer.
XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
  API_BEGIN();
  auto const& global_config = *GlobalConfigThreadLocalStore::Get();
  Json config{ToJson(global_config)};
  auto const* manager = global_config.__MANAGER__();

  // The values are read back as strings below; convert each one to a typed
  // JSON value according to the declared type of its field.
  for (auto& kv : get<Object>(config)) {
    auto const& text = get<String const>(kv.second);
    auto const& key = kv.first;
    auto entry = manager->Find(key);
    CHECK(entry);

    bool const is_integer =
        dynamic_cast<dmlc::parameter::FieldEntry<int32_t> const*>(entry) ||
        dynamic_cast<dmlc::parameter::FieldEntry<int64_t> const*>(entry) ||
        dynamic_cast<dmlc::parameter::FieldEntry<uint32_t> const*>(entry) ||
        dynamic_cast<dmlc::parameter::FieldEntry<uint64_t> const*>(entry);
    bool const is_floating =
        dynamic_cast<dmlc::parameter::FieldEntry<float> const*>(entry) ||
        dynamic_cast<dmlc::parameter::FieldEntry<double> const*>(entry);
    bool const is_boolean =
        dynamic_cast<dmlc::parameter::FieldEntry<bool> const*>(entry) != nullptr;

    if (is_integer) {
      auto parsed = std::strtoimax(text.data(), nullptr, 10);
      CHECK_LE(parsed, static_cast<intmax_t>(std::numeric_limits<int64_t>::max()));
      kv.second = Integer(static_cast<int64_t>(parsed));
    } else if (is_floating) {
      float parsed;
      auto ec = from_chars(text.data(), text.data() + text.size(), parsed).ec;
      CHECK(ec == std::errc());
      kv.second = Number(parsed);
    } else if (is_boolean) {
      // Any text other than "0" is treated as true.
      kv.second = Boolean(text != "0");
    }
  }

  auto& local = *GlobalConfigAPIThreadLocalStore::Get();
  Json::Dump(config, &local.ret_str);
  *json_str = local.ret_str.c_str();
  API_END();
}
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
|
||||
int silent,
|
||||
DMatrixHandle *out) {
|
||||
|
||||
14
src/global_config.cc
Normal file
14
src/global_config.cc
Normal file
@ -0,0 +1,14 @@
|
||||
/*!
|
||||
* Copyright 2020 by Contributors
|
||||
* \file global_config.cc
|
||||
* \brief Global configuration for XGBoost
|
||||
* \author Hyunsu Cho
|
||||
*/
|
||||
|
||||
#include <dmlc/thread_local.h>
|
||||
#include "xgboost/global_config.h"
|
||||
#include "xgboost/json.h"
|
||||
|
||||
namespace xgboost {
|
||||
DMLC_REGISTER_PARAMETER(GlobalConfiguration);
|
||||
} // namespace xgboost
|
||||
@ -490,6 +490,12 @@ class LearnerConfiguration : public Learner {
|
||||
|
||||
// Extract all parameters
|
||||
std::vector<std::string> keys;
|
||||
// First global parameters
|
||||
Json const global_config{ToJson(*GlobalConfigThreadLocalStore::Get())};
|
||||
for (auto const& items : get<Object const>(global_config)) {
|
||||
keys.emplace_back(items.first);
|
||||
}
|
||||
// Parameters in various xgboost components.
|
||||
while (!stack.empty()) {
|
||||
auto j_obj = stack.top();
|
||||
stack.pop();
|
||||
|
||||
@ -11,12 +11,13 @@
|
||||
|
||||
#include "xgboost/parameter.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/json.h"
|
||||
|
||||
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
|
||||
// Override logging mechanism for non-R interfaces
|
||||
void dmlc::CustomLogMessage::Log(const std::string& msg) {
|
||||
const xgboost::LogCallbackRegistry* registry
|
||||
= xgboost::LogCallbackRegistryStore::Get();
|
||||
const xgboost::LogCallbackRegistry *registry =
|
||||
xgboost::LogCallbackRegistryStore::Get();
|
||||
auto callback = registry->Get();
|
||||
callback(msg.c_str());
|
||||
}
|
||||
@ -40,35 +41,15 @@ TrackerLogger::~TrackerLogger() {
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
DMLC_REGISTER_PARAMETER(ConsoleLoggerParam);
|
||||
|
||||
ConsoleLogger::LogVerbosity ConsoleLogger::global_verbosity_ =
|
||||
ConsoleLogger::DefaultVerbosity();
|
||||
|
||||
ConsoleLoggerParam ConsoleLogger::param_ = ConsoleLoggerParam();
|
||||
|
||||
bool ConsoleLogger::ShouldLog(LogVerbosity verbosity) {
|
||||
return verbosity <= global_verbosity_ || verbosity == LV::kIgnore;
|
||||
return static_cast<int>(verbosity) <=
|
||||
(GlobalConfigThreadLocalStore::Get()->verbosity) ||
|
||||
verbosity == LV::kIgnore;
|
||||
}
|
||||
|
||||
void ConsoleLogger::Configure(Args const& args) {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
switch (param_.verbosity) {
|
||||
case 0:
|
||||
global_verbosity_ = LogVerbosity::kSilent;
|
||||
break;
|
||||
case 1:
|
||||
global_verbosity_ = LogVerbosity::kWarning;
|
||||
break;
|
||||
case 2:
|
||||
global_verbosity_ = LogVerbosity::kInfo;
|
||||
break;
|
||||
case 3:
|
||||
global_verbosity_ = LogVerbosity::kDebug;
|
||||
default:
|
||||
// global verbosity doesn't require kIgnore
|
||||
break;
|
||||
}
|
||||
auto& param = *GlobalConfigThreadLocalStore::Get();
|
||||
param.UpdateAllowUnknown(args);
|
||||
}
|
||||
|
||||
ConsoleLogger::LogVerbosity ConsoleLogger::DefaultVerbosity() {
|
||||
@ -76,7 +57,25 @@ ConsoleLogger::LogVerbosity ConsoleLogger::DefaultVerbosity() {
|
||||
}
|
||||
|
||||
ConsoleLogger::LogVerbosity ConsoleLogger::GlobalVerbosity() {
|
||||
return global_verbosity_;
|
||||
LogVerbosity global_verbosity { LogVerbosity::kWarning };
|
||||
switch (GlobalConfigThreadLocalStore::Get()->verbosity) {
|
||||
case 0:
|
||||
global_verbosity = LogVerbosity::kSilent;
|
||||
break;
|
||||
case 1:
|
||||
global_verbosity = LogVerbosity::kWarning;
|
||||
break;
|
||||
case 2:
|
||||
global_verbosity = LogVerbosity::kInfo;
|
||||
break;
|
||||
case 3:
|
||||
global_verbosity = LogVerbosity::kDebug;
|
||||
default:
|
||||
// global verbosity doesn't require kIgnore
|
||||
break;
|
||||
}
|
||||
|
||||
return global_verbosity;
|
||||
}
|
||||
|
||||
ConsoleLogger::ConsoleLogger(LogVerbosity cur_verb) :
|
||||
|
||||
@ -212,4 +212,50 @@ TEST(CAPI, Exception) {
|
||||
// Not null
|
||||
ASSERT_TRUE(error);
|
||||
}
|
||||
|
||||
TEST(CAPI, XGBGlobalConfig) {
  // Case 1: a recognised parameter is accepted and can be read back.
  {
    char const* request = R"json(
    {
      "verbosity": 0
    }
  )json";
    int status = XGBSetGlobalConfig(request);
    ASSERT_EQ(status, 0);

    char const* reply_cstr;
    status = XGBGetGlobalConfig(&reply_cstr);
    ASSERT_EQ(status, 0);

    std::string reply{reply_cstr};
    auto parsed = Json::Load({reply.data(), reply.size()});
    ASSERT_EQ(get<Integer>(parsed["verbosity"]), 0);
  }
  // Case 2: an unrecognised parameter is rejected and named in the error.
  {
    char const* request = R"json(
    {
      "foo": 0
    }
  )json";
    int status = XGBSetGlobalConfig(request);
    ASSERT_EQ(status, -1);
    std::string message{XGBGetLastError()};
    ASSERT_NE(message.find("foo"), std::string::npos);
  }
  // Case 3: with mixed keys only the unrecognised one appears in the error.
  {
    char const* request = R"json(
    {
      "foo": 0,
      "verbosity": 0
    }
  )json";
    int status = XGBSetGlobalConfig(request);
    ASSERT_EQ(status, -1);
    std::string message{XGBGetLastError()};
    ASSERT_NE(message.find("foo"), std::string::npos);
    ASSERT_EQ(message.find("verbosity"), std::string::npos);
  }
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
22
tests/cpp/test_global_config.cc
Normal file
22
tests/cpp/test_global_config.cc
Normal file
@ -0,0 +1,22 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/json.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/global_config.h>
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
TEST(GlobalConfiguration, Verbosity) {
  // Request silent logging through the JSON-based update path.
  Json request{JsonObject()};
  request["verbosity"] = String("0");
  auto* store = GlobalConfigThreadLocalStore::Get();
  FromJson(request, store);

  // The logger must now report the silent level, which is not the default.
  EXPECT_EQ(ConsoleLogger::GlobalVerbosity(), ConsoleLogger::LogVerbosity::kSilent);
  EXPECT_NE(ConsoleLogger::LogVerbosity::kSilent, ConsoleLogger::DefaultVerbosity());

  // Serialising the stored configuration must show the new value as well.
  Json snapshot{ToJson(*GlobalConfigThreadLocalStore::Get())};
  EXPECT_EQ(get<String>(snapshot["verbosity"]), "0");
}
|
||||
|
||||
} // namespace xgboost
|
||||
16
tests/python/test_config.py
Normal file
16
tests/python/test_config.py
Normal file
@ -0,0 +1,16 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import xgboost as xgb
|
||||
import pytest
|
||||
import testing as tm
|
||||
|
||||
|
||||
@pytest.mark.parametrize('verbosity_level', [0, 1, 2, 3])
def test_global_config_verbosity(verbosity_level):
    """config_context applies the requested verbosity and restores the old one."""
    verbosity_before = xgb.get_config()['verbosity']
    with xgb.config_context(verbosity=verbosity_level):
        verbosity_inside = xgb.get_config()['verbosity']
        assert verbosity_inside == verbosity_level
    verbosity_after = xgb.get_config()['verbosity']
    assert verbosity_before == verbosity_after
|
||||
@ -637,6 +637,46 @@ def test_aft_survival():
|
||||
|
||||
|
||||
class TestWithDask:
|
||||
def test_global_config(self, client):
    """The verbosity set before training must be what the callbacks observe."""
    features, labels = generate_array()
    xgb.config.set_config(verbosity=0)
    train_data = DaskDMatrix(client, features, labels)
    before_fname = './before_training-test_global_config'
    after_fname = './after_training-test_global_config'

    class TestCallback(xgb.callback.TrainingCallback):
        # Every hook checks the verbosity it sees; two of them also record
        # it to a file so the test can inspect the value afterwards.
        def write_file(self, fname):
            with open(fname, 'w') as fd:
                fd.write(str(xgb.config.get_config()['verbosity']))

        def before_training(self, model):
            self.write_file(before_fname)
            assert xgb.config.get_config()['verbosity'] == 0
            return model

        def after_training(self, model):
            assert xgb.config.get_config()['verbosity'] == 0
            return model

        def before_iteration(self, model, epoch, evals_log):
            assert xgb.config.get_config()['verbosity'] == 0
            return False

        def after_iteration(self, model, epoch, evals_log):
            self.write_file(after_fname)
            assert xgb.config.get_config()['verbosity'] == 0
            return False

    callbacks = [TestCallback()]
    xgb.dask.train(client, {}, train_data, num_boost_round=4,
                   callbacks=callbacks)['booster']

    # Both recorded values must match the verbosity configured above.
    with open(before_fname, 'r') as before, open(after_fname, 'r') as after:
        assert before.read() == '0'
        assert after.read() == '0'

    for recorded in (before_fname, after_fname):
        os.remove(recorded)
|
||||
|
||||
def run_updater_test(self, client, params, num_rounds, dataset,
|
||||
tree_method):
|
||||
params['tree_method'] = tree_method
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user