Add Model and Configurable interface. (#4945)
* Apply Configurable to objective functions. * Apply Model to Learner and Regtree, gbm. * Add Load/SaveConfig to objs. * Refactor obj tests to use smart pointer. * Dummy methods for Save/Load Model.
This commit is contained in:
@@ -3,6 +3,10 @@
|
||||
*/
|
||||
#include <dmlc/filesystem.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/json.h>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <random>
|
||||
#include <cinttypes>
|
||||
#include "./helpers.h"
|
||||
@@ -36,7 +40,7 @@ void CreateBigTestData(const std::string& filename, size_t n_entries) {
|
||||
}
|
||||
}
|
||||
|
||||
void CheckObjFunctionImpl(xgboost::ObjFunction * obj,
|
||||
void CheckObjFunctionImpl(std::unique_ptr<xgboost::ObjFunction> const& obj,
|
||||
std::vector<xgboost::bst_float> preds,
|
||||
std::vector<xgboost::bst_float> labels,
|
||||
std::vector<xgboost::bst_float> weights,
|
||||
@@ -59,7 +63,7 @@ void CheckObjFunctionImpl(xgboost::ObjFunction * obj,
|
||||
}
|
||||
}
|
||||
|
||||
void CheckObjFunction(xgboost::ObjFunction * obj,
|
||||
void CheckObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
|
||||
std::vector<xgboost::bst_float> preds,
|
||||
std::vector<xgboost::bst_float> labels,
|
||||
std::vector<xgboost::bst_float> weights,
|
||||
@@ -73,13 +77,33 @@ void CheckObjFunction(xgboost::ObjFunction * obj,
|
||||
CheckObjFunctionImpl(obj, preds, labels, weights, info, out_grad, out_hess);
|
||||
}
|
||||
|
||||
void CheckRankingObjFunction(xgboost::ObjFunction * obj,
|
||||
std::vector<xgboost::bst_float> preds,
|
||||
std::vector<xgboost::bst_float> labels,
|
||||
std::vector<xgboost::bst_float> weights,
|
||||
std::vector<xgboost::bst_uint> groups,
|
||||
std::vector<xgboost::bst_float> out_grad,
|
||||
std::vector<xgboost::bst_float> out_hess) {
|
||||
xgboost::Json CheckConfigReloadImpl(xgboost::Configurable* const configurable,
|
||||
std::string name) {
|
||||
xgboost::Json config_0 { xgboost::Object() };
|
||||
configurable->SaveConfig(&config_0);
|
||||
configurable->LoadConfig(config_0);
|
||||
|
||||
xgboost::Json config_1 { xgboost::Object() };
|
||||
configurable->SaveConfig(&config_1);
|
||||
|
||||
std::string str_0, str_1;
|
||||
xgboost::Json::Dump(config_0, &str_0);
|
||||
xgboost::Json::Dump(config_1, &str_1);
|
||||
EXPECT_EQ(str_0, str_1);
|
||||
|
||||
if (name != "") {
|
||||
EXPECT_EQ(xgboost::get<xgboost::String>(config_1["name"]), name);
|
||||
}
|
||||
return config_1;
|
||||
}
|
||||
|
||||
void CheckRankingObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
|
||||
std::vector<xgboost::bst_float> preds,
|
||||
std::vector<xgboost::bst_float> labels,
|
||||
std::vector<xgboost::bst_float> weights,
|
||||
std::vector<xgboost::bst_uint> groups,
|
||||
std::vector<xgboost::bst_float> out_grad,
|
||||
std::vector<xgboost::bst_float> out_hess) {
|
||||
xgboost::MetaInfo info;
|
||||
info.num_row_ = labels.size();
|
||||
info.labels_.HostVector() = labels;
|
||||
|
||||
Reference in New Issue
Block a user