Add global configuration (#6414)
* Add management functions for global configuration: XGBSetGlobalConfig(), XGBGetGlobalConfig(). * Add Python interface: set_config(), get_config(), and config_context(). * Add unit tests for Python * Add R interface: xgb.set.config(), xgb.get.config() * Add unit tests for R Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
committed by
GitHub
parent
c2ba4fb957
commit
fb56da5e8b
@@ -18,9 +18,11 @@
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/version_config.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/global_config.h"
|
||||
|
||||
#include "c_api_error.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/charconv.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "../data/simple_dmatrix.h"
|
||||
#include "../data/proxy_dmatrix.h"
|
||||
@@ -46,6 +48,91 @@ XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*)) {
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBSetGlobalConfig(const char* json_str) {
|
||||
API_BEGIN();
|
||||
std::string str{json_str};
|
||||
Json config{Json::Load(StringView{str.data(), str.size()})};
|
||||
for (auto& items : get<Object>(config)) {
|
||||
switch (items.second.GetValue().Type()) {
|
||||
case xgboost::Value::ValueKind::kInteger: {
|
||||
items.second = String{std::to_string(get<Integer const>(items.second))};
|
||||
break;
|
||||
}
|
||||
case xgboost::Value::ValueKind::kBoolean: {
|
||||
if (get<Boolean const>(items.second)) {
|
||||
items.second = String{"true"};
|
||||
} else {
|
||||
items.second = String{"false"};
|
||||
}
|
||||
break;
|
||||
}
|
||||
case xgboost::Value::ValueKind::kNumber: {
|
||||
auto n = get<Number const>(items.second);
|
||||
char chars[NumericLimits<float>::kToCharsSize];
|
||||
auto ec = to_chars(chars, chars + sizeof(chars), n).ec;
|
||||
CHECK(ec == std::errc());
|
||||
items.second = String{chars};
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
auto unknown = FromJson(config, GlobalConfigThreadLocalStore::Get());
|
||||
if (!unknown.empty()) {
|
||||
std::stringstream ss;
|
||||
ss << "Unknown global parameters: { ";
|
||||
size_t i = 0;
|
||||
for (auto const& item : unknown) {
|
||||
ss << item.first;
|
||||
i++;
|
||||
if (i != unknown.size()) {
|
||||
ss << ", ";
|
||||
}
|
||||
}
|
||||
LOG(FATAL) << ss.str() << " }";
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
using GlobalConfigAPIThreadLocalStore = dmlc::ThreadLocalStore<XGBAPIThreadLocalEntry>;
|
||||
|
||||
XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
|
||||
API_BEGIN();
|
||||
auto const& global_config = *GlobalConfigThreadLocalStore::Get();
|
||||
Json config {ToJson(global_config)};
|
||||
auto const* mgr = global_config.__MANAGER__();
|
||||
|
||||
for (auto& item : get<Object>(config)) {
|
||||
auto const &str = get<String const>(item.second);
|
||||
auto const &name = item.first;
|
||||
auto e = mgr->Find(name);
|
||||
CHECK(e);
|
||||
|
||||
if (dynamic_cast<dmlc::parameter::FieldEntry<int32_t> const*>(e) ||
|
||||
dynamic_cast<dmlc::parameter::FieldEntry<int64_t> const*>(e) ||
|
||||
dynamic_cast<dmlc::parameter::FieldEntry<uint32_t> const*>(e) ||
|
||||
dynamic_cast<dmlc::parameter::FieldEntry<uint64_t> const*>(e)) {
|
||||
auto i = std::strtoimax(str.data(), nullptr, 10);
|
||||
CHECK_LE(i, static_cast<intmax_t>(std::numeric_limits<int64_t>::max()));
|
||||
item.second = Integer(static_cast<int64_t>(i));
|
||||
} else if (dynamic_cast<dmlc::parameter::FieldEntry<float> const *>(e) ||
|
||||
dynamic_cast<dmlc::parameter::FieldEntry<double> const *>(e)) {
|
||||
float f;
|
||||
auto ec = from_chars(str.data(), str.data() + str.size(), f).ec;
|
||||
CHECK(ec == std::errc());
|
||||
item.second = Number(f);
|
||||
} else if (dynamic_cast<dmlc::parameter::FieldEntry<bool> const *>(e)) {
|
||||
item.second = Boolean(str != "0");
|
||||
}
|
||||
}
|
||||
|
||||
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
|
||||
Json::Dump(config, &local.ret_str);
|
||||
*json_str = local.ret_str.c_str();
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
|
||||
int silent,
|
||||
DMatrixHandle *out) {
|
||||
|
||||
Reference in New Issue
Block a user