45 lines
1.3 KiB
C++
45 lines
1.3 KiB
C++
/*!
 * Copyright 2015 by Contributors
 * \file global.cc
 * \brief Enable all kinds of global static registry and variables.
 */
|
|
#include <xgboost/objective.h>
|
|
#include <xgboost/metric.h>
|
|
|
|
namespace dmlc {
|
|
DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg);
|
|
DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);
|
|
} // namespace dmlc
|
|
|
|
namespace xgboost {
|
|
// implement factory functions
|
|
/*!
 * \brief Create an objective function from its registered name.
 * \param name Name to look up in the ObjFunctionReg registry.
 * \return The result of invoking the registered entry's factory body.
 */
ObjFunction* ObjFunction::Create(const char* name) {
  using ObjRegistry = ::dmlc::Registry< ::xgboost::ObjFunctionReg>;
  auto *entry = ObjRegistry::Get()->Find(name);
  if (!entry) {
    // No entry registered under this name: report it through LOG(FATAL).
    LOG(FATAL) << "Unknown objective function " << name;
  }
  return (entry->body)();
}
|
|
|
|
/*!
 * \brief Create a metric from its registered name.
 *
 * The name may carry an argument after an '@' separator ("prefix@arg").
 * In that case the text before the first '@' is looked up in the MetricReg
 * registry and the text after it is passed to the factory as a C string.
 * Without an '@', the whole name is looked up and the factory receives
 * nullptr.
 *
 * \param name Metric name, optionally followed by '@' and an argument.
 * \return The result of invoking the registered entry's factory body.
 */
Metric* Metric::Create(const char* name) {
  const std::string buf = name;
  const auto pos = buf.find('@');
  if (pos == std::string::npos) {
    // Plain metric name with no argument.
    auto *e = ::dmlc::Registry< ::xgboost::MetricReg>::Get()->Find(name);
    if (e == nullptr) {
      LOG(FATAL) << "Unknown metric function " << name;
    }
    return (e->body)(nullptr);
  }
  // "prefix@arg" form: the registry key is the part before the first '@',
  // everything after it is forwarded to the factory as its argument.
  const std::string prefix = buf.substr(0, pos);
  const std::string param = buf.substr(pos + 1);
  auto *e = ::dmlc::Registry< ::xgboost::MetricReg>::Get()->Find(prefix.c_str());
  if (e == nullptr) {
    LOG(FATAL) << "Unknown metric function " << name;
  }
  return (e->body)(param.c_str());
}
|
|
} // namespace xgboost
|
|
|