Add Model and Configurable interface. (#4945)

* Apply Configurable to objective functions.
* Apply Model to Learner and Regtree, gbm.
* Add Load/SaveConfig to objs.
* Refactor obj tests to use smart pointer.
* Dummy methods for Save/Load Model.
This commit is contained in:
Jiaming Yuan
2019-10-18 01:56:02 -04:00
committed by GitHub
parent 9fc681001a
commit ae536756ae
31 changed files with 521 additions and 187 deletions

View File

@@ -2,33 +2,59 @@
#include <xgboost/objective.h>
#include <xgboost/generic_parameters.h>
#include "../helpers.h"
#include <xgboost/json.h>
namespace xgboost {
TEST(Objective, PairwiseRankingGPair) {
xgboost::GenericParameter tparam;
std::vector<std::pair<std::string, std::string>> args;
tparam.InitAllowUnknown(args);
xgboost::ObjFunction * obj =
xgboost::ObjFunction::Create("rank:pairwise", &tparam);
std::unique_ptr<xgboost::ObjFunction> obj {
xgboost::ObjFunction::Create("rank:pairwise", &tparam)
};
obj->Configure(args);
CheckConfigReload(obj, "rank:pairwise");
// Test with setting sample weight to second query group
CheckRankingObjFunction(obj,
{0, 0.1f, 0, 0.1f},
{0, 1, 0, 1},
{2.0f, 0.0f},
{0, 2, 4},
{1.9f, -1.9f, 0.0f, 0.0f},
{1.995f, 1.995f, 0.0f, 0.0f});
{0, 0.1f, 0, 0.1f},
{0, 1, 0, 1},
{2.0f, 0.0f},
{0, 2, 4},
{1.9f, -1.9f, 0.0f, 0.0f},
{1.995f, 1.995f, 0.0f, 0.0f});
CheckRankingObjFunction(obj,
{0, 0.1f, 0, 0.1f},
{0, 1, 0, 1},
{1.0f, 1.0f},
{0, 2, 4},
{0.95f, -0.95f, 0.95f, -0.95f},
{0.9975f, 0.9975f, 0.9975f, 0.9975f});
{0, 0.1f, 0, 0.1f},
{0, 1, 0, 1},
{1.0f, 1.0f},
{0, 2, 4},
{0.95f, -0.95f, 0.95f, -0.95f},
{0.9975f, 0.9975f, 0.9975f, 0.9975f});
ASSERT_NO_THROW(obj->DefaultEvalMetric());
delete obj;
}
TEST(Objective, NDCG_Json_IO) {
xgboost::GenericParameter tparam;
tparam.InitAllowUnknown(Args{});
std::unique_ptr<xgboost::ObjFunction> obj {
xgboost::ObjFunction::Create("rank:ndcg", &tparam)
};
obj->Configure(Args{});
Json j_obj {Object()};
obj->SaveConfig(&j_obj);
ASSERT_EQ(get<String>(j_obj["name"]), "rank:ndcg");;
auto const& j_param = j_obj["lambda_rank_param"];
ASSERT_EQ(get<String>(j_param["num_pairsample"]), "1");
ASSERT_EQ(get<String>(j_param["fix_list_weight"]), "0");
}
} // namespace xgboost