Add global configuration (#6414)

* Add management functions for global configuration: XGBSetGlobalConfig(), XGBGetGlobalConfig().
* Add Python interface: set_config(), get_config(), and config_context().
* Add unit tests for Python
* Add R interface: xgb.set.config(), xgb.get.config()
* Add unit tests for R

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
Philip Hyunsu Cho
2020-12-03 00:05:18 -08:00
committed by GitHub
parent c2ba4fb957
commit fb56da5e8b
29 changed files with 637 additions and 86 deletions

View File

@@ -18,9 +18,11 @@
#include "xgboost/logging.h"
#include "xgboost/version_config.h"
#include "xgboost/json.h"
#include "xgboost/global_config.h"
#include "c_api_error.h"
#include "../common/io.h"
#include "../common/charconv.h"
#include "../data/adapter.h"
#include "../data/simple_dmatrix.h"
#include "../data/proxy_dmatrix.h"
@@ -46,6 +48,91 @@ XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*)) {
API_END();
}
XGB_DLL int XGBSetGlobalConfig(const char* json_str) {
API_BEGIN();
std::string str{json_str};
Json config{Json::Load(StringView{str.data(), str.size()})};
for (auto& items : get<Object>(config)) {
switch (items.second.GetValue().Type()) {
case xgboost::Value::ValueKind::kInteger: {
items.second = String{std::to_string(get<Integer const>(items.second))};
break;
}
case xgboost::Value::ValueKind::kBoolean: {
if (get<Boolean const>(items.second)) {
items.second = String{"true"};
} else {
items.second = String{"false"};
}
break;
}
case xgboost::Value::ValueKind::kNumber: {
auto n = get<Number const>(items.second);
char chars[NumericLimits<float>::kToCharsSize];
auto ec = to_chars(chars, chars + sizeof(chars), n).ec;
CHECK(ec == std::errc());
items.second = String{chars};
break;
}
default:
break;
}
}
auto unknown = FromJson(config, GlobalConfigThreadLocalStore::Get());
if (!unknown.empty()) {
std::stringstream ss;
ss << "Unknown global parameters: { ";
size_t i = 0;
for (auto const& item : unknown) {
ss << item.first;
i++;
if (i != unknown.size()) {
ss << ", ";
}
}
LOG(FATAL) << ss.str() << " }";
}
API_END();
}
using GlobalConfigAPIThreadLocalStore = dmlc::ThreadLocalStore<XGBAPIThreadLocalEntry>;
XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
API_BEGIN();
auto const& global_config = *GlobalConfigThreadLocalStore::Get();
Json config {ToJson(global_config)};
auto const* mgr = global_config.__MANAGER__();
for (auto& item : get<Object>(config)) {
auto const &str = get<String const>(item.second);
auto const &name = item.first;
auto e = mgr->Find(name);
CHECK(e);
if (dynamic_cast<dmlc::parameter::FieldEntry<int32_t> const*>(e) ||
dynamic_cast<dmlc::parameter::FieldEntry<int64_t> const*>(e) ||
dynamic_cast<dmlc::parameter::FieldEntry<uint32_t> const*>(e) ||
dynamic_cast<dmlc::parameter::FieldEntry<uint64_t> const*>(e)) {
auto i = std::strtoimax(str.data(), nullptr, 10);
CHECK_LE(i, static_cast<intmax_t>(std::numeric_limits<int64_t>::max()));
item.second = Integer(static_cast<int64_t>(i));
} else if (dynamic_cast<dmlc::parameter::FieldEntry<float> const *>(e) ||
dynamic_cast<dmlc::parameter::FieldEntry<double> const *>(e)) {
float f;
auto ec = from_chars(str.data(), str.data() + str.size(), f).ec;
CHECK(ec == std::errc());
item.second = Number(f);
} else if (dynamic_cast<dmlc::parameter::FieldEntry<bool> const *>(e)) {
item.second = Boolean(str != "0");
}
}
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
Json::Dump(config, &local.ret_str);
*json_str = local.ret_str.c_str();
API_END();
}
XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
int silent,
DMatrixHandle *out) {

14
src/global_config.cc Normal file
View File

@@ -0,0 +1,14 @@
/*!
* Copyright 2020 by Contributors
* \file global_config.cc
* \brief Global configuration for XGBoost
* \author Hyunsu Cho
*/
#include <dmlc/thread_local.h>
#include "xgboost/global_config.h"
#include "xgboost/json.h"
namespace xgboost {
DMLC_REGISTER_PARAMETER(GlobalConfiguration);
} // namespace xgboost

View File

@@ -490,6 +490,12 @@ class LearnerConfiguration : public Learner {
// Extract all parameters
std::vector<std::string> keys;
// First global parameters
Json const global_config{ToJson(*GlobalConfigThreadLocalStore::Get())};
for (auto const& items : get<Object const>(global_config)) {
keys.emplace_back(items.first);
}
// Parameters in various xgboost components.
while (!stack.empty()) {
auto j_obj = stack.top();
stack.pop();

View File

@@ -11,12 +11,13 @@
#include "xgboost/parameter.h"
#include "xgboost/logging.h"
#include "xgboost/json.h"
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
// Override logging mechanism for non-R interfaces
void dmlc::CustomLogMessage::Log(const std::string& msg) {
const xgboost::LogCallbackRegistry* registry
= xgboost::LogCallbackRegistryStore::Get();
const xgboost::LogCallbackRegistry *registry =
xgboost::LogCallbackRegistryStore::Get();
auto callback = registry->Get();
callback(msg.c_str());
}
@@ -40,35 +41,15 @@ TrackerLogger::~TrackerLogger() {
namespace xgboost {
DMLC_REGISTER_PARAMETER(ConsoleLoggerParam);
ConsoleLogger::LogVerbosity ConsoleLogger::global_verbosity_ =
ConsoleLogger::DefaultVerbosity();
ConsoleLoggerParam ConsoleLogger::param_ = ConsoleLoggerParam();
bool ConsoleLogger::ShouldLog(LogVerbosity verbosity) {
return verbosity <= global_verbosity_ || verbosity == LV::kIgnore;
return static_cast<int>(verbosity) <=
(GlobalConfigThreadLocalStore::Get()->verbosity) ||
verbosity == LV::kIgnore;
}
void ConsoleLogger::Configure(Args const& args) {
param_.UpdateAllowUnknown(args);
switch (param_.verbosity) {
case 0:
global_verbosity_ = LogVerbosity::kSilent;
break;
case 1:
global_verbosity_ = LogVerbosity::kWarning;
break;
case 2:
global_verbosity_ = LogVerbosity::kInfo;
break;
case 3:
global_verbosity_ = LogVerbosity::kDebug;
default:
// global verbosity doesn't require kIgnore
break;
}
auto& param = *GlobalConfigThreadLocalStore::Get();
param.UpdateAllowUnknown(args);
}
ConsoleLogger::LogVerbosity ConsoleLogger::DefaultVerbosity() {
@@ -76,7 +57,25 @@ ConsoleLogger::LogVerbosity ConsoleLogger::DefaultVerbosity() {
}
ConsoleLogger::LogVerbosity ConsoleLogger::GlobalVerbosity() {
return global_verbosity_;
LogVerbosity global_verbosity { LogVerbosity::kWarning };
switch (GlobalConfigThreadLocalStore::Get()->verbosity) {
case 0:
global_verbosity = LogVerbosity::kSilent;
break;
case 1:
global_verbosity = LogVerbosity::kWarning;
break;
case 2:
global_verbosity = LogVerbosity::kInfo;
break;
case 3:
global_verbosity = LogVerbosity::kDebug;
default:
// global verbosity doesn't require kIgnore
break;
}
return global_verbosity;
}
ConsoleLogger::ConsoleLogger(LogVerbosity cur_verb) :