Add Model and Configurable interface. (#4945)

* Apply Configurable to objective functions.
* Apply Model to Learner and Regtree, gbm.
* Add Load/SaveConfig to objs.
* Refactor obj tests to use smart pointer.
* Dummy methods for Save/Load Model.
This commit is contained in:
Jiaming Yuan
2019-10-18 01:56:02 -04:00
committed by GitHub
parent 9fc681001a
commit ae536756ae
31 changed files with 521 additions and 187 deletions

View File

@@ -6,6 +6,7 @@
#include <dmlc/parameter.h>
#include <xgboost/base.h>
#include <xgboost/feature_map.h>
#include <xgboost/model.h>
#include <vector>
#include <string>
#include <cstring>
@@ -34,7 +35,7 @@ struct GBLinearModelParam : public dmlc::Parameter<GBLinearModelParam> {
};
// model for linear booster
class GBLinearModel {
class GBLinearModel : public Model {
public:
// parameter
GBLinearModelParam param;
@@ -57,6 +58,17 @@ class GBLinearModel {
CHECK_EQ(fi->Read(&param, sizeof(param)), sizeof(param));
fi->Read(&weight);
}
void LoadModel(dmlc::Stream* fi) override {
// They are the same right now until we can split up the saved parameter from model.
this->Load(fi);
}
void SaveModel(dmlc::Stream* fo) const override {
// They are the same right now until we can split up the saved parameter from model.
this->Save(fo);
}
// model bias
inline bst_float* bias() {
return &weight[param.num_feature * param.num_output_group];

View File

@@ -14,7 +14,7 @@
#include <xgboost/gbm.h>
#include <xgboost/predictor.h>
#include <xgboost/tree_updater.h>
#include <xgboost/enum_class_param.h>
#include <xgboost/parameter.h>
#include <vector>
#include <map>

View File

@@ -4,6 +4,7 @@
#pragma once
#include <dmlc/parameter.h>
#include <dmlc/io.h>
#include <xgboost/model.h>
#include <xgboost/tree_model.h>
#include <memory>
@@ -61,7 +62,7 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
}
};
struct GBTreeModel {
struct GBTreeModel : public Model {
explicit GBTreeModel(bst_float base_margin) : base_margin(base_margin) {}
void Configure(const Args& cfg) {
// initialize model parameters if not yet been initialized.
@@ -81,6 +82,15 @@ struct GBTreeModel {
}
}
void LoadModel(dmlc::Stream* fi) override {
// They are the same right now until we can split up the saved parameter from model.
this->Load(fi);
}
void SaveModel(dmlc::Stream* fo) const override {
// They are the same right now until we can split up the saved parameter from model.
this->Save(fo);
}
void Load(dmlc::Stream* fi) {
CHECK_EQ(fi->Read(&param, sizeof(param)), sizeof(param))
<< "GBTree: invalid model file";
@@ -88,7 +98,7 @@ struct GBTreeModel {
trees_to_update.clear();
for (int i = 0; i < param.num_trees; ++i) {
std::unique_ptr<RegTree> ptr(new RegTree());
ptr->Load(fi);
ptr->LoadModel(fi);
trees.push_back(std::move(ptr));
}
tree_info.resize(param.num_trees);
@@ -103,7 +113,7 @@ struct GBTreeModel {
CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
fo->Write(&param, sizeof(param));
for (const auto & tree : trees) {
tree->Save(fo);
tree->SaveModel(fo);
}
if (tree_info.size() != 0) {
fo->Write(dmlc::BeginPtr(tree_info), sizeof(int) * tree_info.size());