Add global configuration (#6414)

* Add management functions for global configuration: XGBSetGlobalConfig(), XGBGetGlobalConfig().
* Add Python interface: set_config(), get_config(), and config_context().
* Add unit tests for Python
* Add R interface: xgb.set.config(), xgb.get.config()
* Add unit tests for R

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
Philip Hyunsu Cho
2020-12-03 00:05:18 -08:00
committed by GitHub
parent c2ba4fb957
commit fb56da5e8b
29 changed files with 637 additions and 86 deletions

View File

@@ -18,9 +18,11 @@
#include "xgboost/logging.h"
#include "xgboost/version_config.h"
#include "xgboost/json.h"
#include "xgboost/global_config.h"
#include "c_api_error.h"
#include "../common/io.h"
#include "../common/charconv.h"
#include "../data/adapter.h"
#include "../data/simple_dmatrix.h"
#include "../data/proxy_dmatrix.h"
@@ -46,6 +48,91 @@ XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*)) {
API_END();
}
XGB_DLL int XGBSetGlobalConfig(const char* json_str) {
API_BEGIN();
std::string str{json_str};
Json config{Json::Load(StringView{str.data(), str.size()})};
for (auto& items : get<Object>(config)) {
switch (items.second.GetValue().Type()) {
case xgboost::Value::ValueKind::kInteger: {
items.second = String{std::to_string(get<Integer const>(items.second))};
break;
}
case xgboost::Value::ValueKind::kBoolean: {
if (get<Boolean const>(items.second)) {
items.second = String{"true"};
} else {
items.second = String{"false"};
}
break;
}
case xgboost::Value::ValueKind::kNumber: {
auto n = get<Number const>(items.second);
char chars[NumericLimits<float>::kToCharsSize];
auto ec = to_chars(chars, chars + sizeof(chars), n).ec;
CHECK(ec == std::errc());
items.second = String{chars};
break;
}
default:
break;
}
}
auto unknown = FromJson(config, GlobalConfigThreadLocalStore::Get());
if (!unknown.empty()) {
std::stringstream ss;
ss << "Unknown global parameters: { ";
size_t i = 0;
for (auto const& item : unknown) {
ss << item.first;
i++;
if (i != unknown.size()) {
ss << ", ";
}
}
LOG(FATAL) << ss.str() << " }";
}
API_END();
}
using GlobalConfigAPIThreadLocalStore = dmlc::ThreadLocalStore<XGBAPIThreadLocalEntry>;
XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
API_BEGIN();
auto const& global_config = *GlobalConfigThreadLocalStore::Get();
Json config {ToJson(global_config)};
auto const* mgr = global_config.__MANAGER__();
for (auto& item : get<Object>(config)) {
auto const &str = get<String const>(item.second);
auto const &name = item.first;
auto e = mgr->Find(name);
CHECK(e);
if (dynamic_cast<dmlc::parameter::FieldEntry<int32_t> const*>(e) ||
dynamic_cast<dmlc::parameter::FieldEntry<int64_t> const*>(e) ||
dynamic_cast<dmlc::parameter::FieldEntry<uint32_t> const*>(e) ||
dynamic_cast<dmlc::parameter::FieldEntry<uint64_t> const*>(e)) {
auto i = std::strtoimax(str.data(), nullptr, 10);
CHECK_LE(i, static_cast<intmax_t>(std::numeric_limits<int64_t>::max()));
item.second = Integer(static_cast<int64_t>(i));
} else if (dynamic_cast<dmlc::parameter::FieldEntry<float> const *>(e) ||
dynamic_cast<dmlc::parameter::FieldEntry<double> const *>(e)) {
float f;
auto ec = from_chars(str.data(), str.data() + str.size(), f).ec;
CHECK(ec == std::errc());
item.second = Number(f);
} else if (dynamic_cast<dmlc::parameter::FieldEntry<bool> const *>(e)) {
item.second = Boolean(str != "0");
}
}
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
Json::Dump(config, &local.ret_str);
*json_str = local.ret_str.c_str();
API_END();
}
XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
int silent,
DMatrixHandle *out) {