Add Model and Configurable interface. (#4945)
* Apply Configurable to objective functions. * Apply Model to Learner and Regtree, gbm. * Add Load/SaveConfig to objs. * Refactor obj tests to use smart pointer. * Dummy methods for Save/Load Model.
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include <dmlc/parameter.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/feature_map.h>
|
||||
#include <xgboost/model.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
@@ -34,7 +35,7 @@ struct GBLinearModelParam : public dmlc::Parameter<GBLinearModelParam> {
|
||||
};
|
||||
|
||||
// model for linear booster
|
||||
class GBLinearModel {
|
||||
class GBLinearModel : public Model {
|
||||
public:
|
||||
// parameter
|
||||
GBLinearModelParam param;
|
||||
@@ -57,6 +58,17 @@ class GBLinearModel {
|
||||
CHECK_EQ(fi->Read(¶m, sizeof(param)), sizeof(param));
|
||||
fi->Read(&weight);
|
||||
}
|
||||
|
||||
void LoadModel(dmlc::Stream* fi) override {
|
||||
// They are the same right now until we can split up the saved parameter from model.
|
||||
this->Load(fi);
|
||||
}
|
||||
|
||||
void SaveModel(dmlc::Stream* fo) const override {
|
||||
// They are the same right now until we can split up the saved parameter from model.
|
||||
this->Save(fo);
|
||||
}
|
||||
|
||||
// model bias
|
||||
inline bst_float* bias() {
|
||||
return &weight[param.num_feature * param.num_output_group];
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#include <xgboost/gbm.h>
|
||||
#include <xgboost/predictor.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <xgboost/enum_class_param.h>
|
||||
#include <xgboost/parameter.h>
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
#include <dmlc/parameter.h>
|
||||
#include <dmlc/io.h>
|
||||
#include <xgboost/model.h>
|
||||
#include <xgboost/tree_model.h>
|
||||
|
||||
#include <memory>
|
||||
@@ -61,7 +62,7 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
|
||||
}
|
||||
};
|
||||
|
||||
struct GBTreeModel {
|
||||
struct GBTreeModel : public Model {
|
||||
explicit GBTreeModel(bst_float base_margin) : base_margin(base_margin) {}
|
||||
void Configure(const Args& cfg) {
|
||||
// initialize model parameters if not yet been initialized.
|
||||
@@ -81,6 +82,15 @@ struct GBTreeModel {
|
||||
}
|
||||
}
|
||||
|
||||
void LoadModel(dmlc::Stream* fi) override {
|
||||
// They are the same right now until we can split up the saved parameter from model.
|
||||
this->Load(fi);
|
||||
}
|
||||
void SaveModel(dmlc::Stream* fo) const override {
|
||||
// They are the same right now until we can split up the saved parameter from model.
|
||||
this->Save(fo);
|
||||
}
|
||||
|
||||
void Load(dmlc::Stream* fi) {
|
||||
CHECK_EQ(fi->Read(¶m, sizeof(param)), sizeof(param))
|
||||
<< "GBTree: invalid model file";
|
||||
@@ -88,7 +98,7 @@ struct GBTreeModel {
|
||||
trees_to_update.clear();
|
||||
for (int i = 0; i < param.num_trees; ++i) {
|
||||
std::unique_ptr<RegTree> ptr(new RegTree());
|
||||
ptr->Load(fi);
|
||||
ptr->LoadModel(fi);
|
||||
trees.push_back(std::move(ptr));
|
||||
}
|
||||
tree_info.resize(param.num_trees);
|
||||
@@ -103,7 +113,7 @@ struct GBTreeModel {
|
||||
CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
|
||||
fo->Write(¶m, sizeof(param));
|
||||
for (const auto & tree : trees) {
|
||||
tree->Save(fo);
|
||||
tree->SaveModel(fo);
|
||||
}
|
||||
if (tree_info.size() != 0) {
|
||||
fo->Write(dmlc::BeginPtr(tree_info), sizeof(int) * tree_info.size());
|
||||
|
||||
Reference in New Issue
Block a user