Add global configuration (#6414)

* Add management functions for global configuration: XGBSetGlobalConfig(), XGBGetGlobalConfig().
* Add Python interface: set_config(), get_config(), and config_context().
* Add unit tests for Python
* Add R interface: xgb.set.config(), xgb.get.config()
* Add unit tests for R

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
Philip Hyunsu Cho
2020-12-03 00:05:18 -08:00
committed by GitHub
parent c2ba4fb957
commit fb56da5e8b
29 changed files with 637 additions and 86 deletions

View File

@@ -212,4 +212,50 @@ TEST(CAPI, Exception) {
// Not null
ASSERT_TRUE(error);
}
TEST(CAPI, XGBGlobalConfig) {
int ret;
{
const char *config_str = R"json(
{
"verbosity": 0
}
)json";
ret = XGBSetGlobalConfig(config_str);
ASSERT_EQ(ret, 0);
const char *updated_config_cstr;
ret = XGBGetGlobalConfig(&updated_config_cstr);
ASSERT_EQ(ret, 0);
std::string updated_config_str{updated_config_cstr};
auto updated_config =
Json::Load({updated_config_str.data(), updated_config_str.size()});
ASSERT_EQ(get<Integer>(updated_config["verbosity"]), 0);
}
{
const char *config_str = R"json(
{
"foo": 0
}
)json";
ret = XGBSetGlobalConfig(config_str);
ASSERT_EQ(ret , -1);
auto err = std::string{XGBGetLastError()};
ASSERT_NE(err.find("foo"), std::string::npos);
}
{
const char *config_str = R"json(
{
"foo": 0,
"verbosity": 0
}
)json";
ret = XGBSetGlobalConfig(config_str);
ASSERT_EQ(ret , -1);
auto err = std::string{XGBGetLastError()};
ASSERT_NE(err.find("foo"), std::string::npos);
ASSERT_EQ(err.find("verbosity"), std::string::npos);
}
}
} // namespace xgboost

View File

@@ -0,0 +1,22 @@
#include <gtest/gtest.h>
#include <xgboost/json.h>
#include <xgboost/logging.h>
#include <xgboost/global_config.h>
namespace xgboost {
TEST(GlobalConfiguration, Verbosity) {
// Configure verbosity via global configuration
Json config{JsonObject()};
config["verbosity"] = String("0");
auto& global_config = *GlobalConfigThreadLocalStore::Get();
FromJson(config, &global_config);
// Now verbosity should be updated
EXPECT_EQ(ConsoleLogger::GlobalVerbosity(), ConsoleLogger::LogVerbosity::kSilent);
EXPECT_NE(ConsoleLogger::LogVerbosity::kSilent, ConsoleLogger::DefaultVerbosity());
// GetConfig() should also return updated verbosity
Json current_config { ToJson(*GlobalConfigThreadLocalStore::Get()) };
EXPECT_EQ(get<String>(current_config["verbosity"]), "0");
}
} // namespace xgboost

View File

@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
import xgboost as xgb
import pytest
import testing as tm
@pytest.mark.parametrize('verbosity_level', [0, 1, 2, 3])
def test_global_config_verbosity(verbosity_level):
def get_current_verbosity():
return xgb.get_config()['verbosity']
old_verbosity = get_current_verbosity()
with xgb.config_context(verbosity=verbosity_level):
new_verbosity = get_current_verbosity()
assert new_verbosity == verbosity_level
assert old_verbosity == get_current_verbosity()

View File

@@ -637,6 +637,46 @@ def test_aft_survival():
class TestWithDask:
def test_global_config(self, client):
X, y = generate_array()
xgb.config.set_config(verbosity=0)
dtrain = DaskDMatrix(client, X, y)
before_fname = './before_training-test_global_config'
after_fname = './after_training-test_global_config'
class TestCallback(xgb.callback.TrainingCallback):
def write_file(self, fname):
with open(fname, 'w') as fd:
fd.write(str(xgb.config.get_config()['verbosity']))
def before_training(self, model):
self.write_file(before_fname)
assert xgb.config.get_config()['verbosity'] == 0
return model
def after_training(self, model):
assert xgb.config.get_config()['verbosity'] == 0
return model
def before_iteration(self, model, epoch, evals_log):
assert xgb.config.get_config()['verbosity'] == 0
return False
def after_iteration(self, model, epoch, evals_log):
self.write_file(after_fname)
assert xgb.config.get_config()['verbosity'] == 0
return False
xgb.dask.train(client, {}, dtrain, num_boost_round=4, callbacks=[TestCallback()])[
'booster']
with open(before_fname, 'r') as before, open(after_fname, 'r') as after:
assert before.read() == '0'
assert after.read() == '0'
os.remove(before_fname)
os.remove(after_fname)
def run_updater_test(self, client, params, num_rounds, dataset,
tree_method):
params['tree_method'] = tree_method