Add Model and Configurable interface. (#4945)

* Apply Configurable to objective functions.
* Apply Model to Learner and Regtree, gbm.
* Add Load/SaveConfig to objs.
* Refactor obj tests to use smart pointer.
* Dummy methods for Save/Load Model.
This commit is contained in:
Jiaming Yuan
2019-10-18 01:56:02 -04:00
committed by GitHub
parent 9fc681001a
commit ae536756ae
31 changed files with 521 additions and 187 deletions

View File

@@ -3,6 +3,10 @@
*/
#include <dmlc/filesystem.h>
#include <xgboost/logging.h>
#include <xgboost/json.h>
#include <gtest/gtest.h>
#include <random>
#include <cinttypes>
#include "./helpers.h"
@@ -36,7 +40,7 @@ void CreateBigTestData(const std::string& filename, size_t n_entries) {
}
}
void CheckObjFunctionImpl(xgboost::ObjFunction * obj,
void CheckObjFunctionImpl(std::unique_ptr<xgboost::ObjFunction> const& obj,
std::vector<xgboost::bst_float> preds,
std::vector<xgboost::bst_float> labels,
std::vector<xgboost::bst_float> weights,
@@ -59,7 +63,7 @@ void CheckObjFunctionImpl(xgboost::ObjFunction * obj,
}
}
void CheckObjFunction(xgboost::ObjFunction * obj,
void CheckObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
std::vector<xgboost::bst_float> preds,
std::vector<xgboost::bst_float> labels,
std::vector<xgboost::bst_float> weights,
@@ -73,13 +77,33 @@ void CheckObjFunction(xgboost::ObjFunction * obj,
CheckObjFunctionImpl(obj, preds, labels, weights, info, out_grad, out_hess);
}
void CheckRankingObjFunction(xgboost::ObjFunction * obj,
std::vector<xgboost::bst_float> preds,
std::vector<xgboost::bst_float> labels,
std::vector<xgboost::bst_float> weights,
std::vector<xgboost::bst_uint> groups,
std::vector<xgboost::bst_float> out_grad,
std::vector<xgboost::bst_float> out_hess) {
xgboost::Json CheckConfigReloadImpl(xgboost::Configurable* const configurable,
std::string name) {
xgboost::Json config_0 { xgboost::Object() };
configurable->SaveConfig(&config_0);
configurable->LoadConfig(config_0);
xgboost::Json config_1 { xgboost::Object() };
configurable->SaveConfig(&config_1);
std::string str_0, str_1;
xgboost::Json::Dump(config_0, &str_0);
xgboost::Json::Dump(config_1, &str_1);
EXPECT_EQ(str_0, str_1);
if (name != "") {
EXPECT_EQ(xgboost::get<xgboost::String>(config_1["name"]), name);
}
return config_1;
}
void CheckRankingObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
std::vector<xgboost::bst_float> preds,
std::vector<xgboost::bst_float> labels,
std::vector<xgboost::bst_float> weights,
std::vector<xgboost::bst_uint> groups,
std::vector<xgboost::bst_float> out_grad,
std::vector<xgboost::bst_float> out_hess) {
xgboost::MetaInfo info;
info.num_row_ = labels.size();
info.labels_.HostVector() = labels;