Model IO in JSON. (#5110)
This commit is contained in:
@@ -85,6 +85,20 @@ class GBLinear : public GradientBooster {
|
||||
model_.Save(fo);
|
||||
}
|
||||
|
||||
void SaveModel(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["name"] = String{"gblinear"};
|
||||
|
||||
out["model"] = Object();
|
||||
auto& model = out["model"];
|
||||
model_.SaveModel(&model);
|
||||
}
|
||||
void LoadModel(Json const& in) override {
|
||||
CHECK_EQ(get<String>(in["name"]), "gblinear");
|
||||
auto const& model = in["model"];
|
||||
model_.LoadModel(model);
|
||||
}
|
||||
|
||||
void DoBoost(DMatrix *p_fmat,
|
||||
HostDeviceVector<GradientPair> *in_gpair,
|
||||
ObjFunction* obj) override {
|
||||
|
||||
38
src/gbm/gblinear_model.cc
Normal file
38
src/gbm/gblinear_model.cc
Normal file
@@ -0,0 +1,38 @@
|
||||
/*!
|
||||
* Copyright 2019 by Contributors
|
||||
*/
|
||||
#include <utility>
|
||||
#include <limits>
|
||||
#include "xgboost/json.h"
|
||||
#include "gblinear_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace gbm {
|
||||
|
||||
void GBLinearModel::SaveModel(Json* p_out) const {
|
||||
using WeightType = std::remove_reference<decltype(std::declval<decltype(weight)>().back())>::type;
|
||||
using JsonFloat = Number::Float;
|
||||
static_assert(std::is_same<WeightType, JsonFloat>::value,
|
||||
"Weight type should be of the same type with JSON float");
|
||||
auto& out = *p_out;
|
||||
|
||||
size_t const n_weights = weight.size();
|
||||
std::vector<Json> j_weights(n_weights);
|
||||
for (size_t i = 0; i < n_weights; ++i) {
|
||||
j_weights[i] = weight[i];
|
||||
}
|
||||
out["weights"] = std::move(j_weights);
|
||||
}
|
||||
|
||||
void GBLinearModel::LoadModel(Json const& in) {
|
||||
auto const& j_weights = get<Array const>(in["weights"]);
|
||||
auto n_weights = j_weights.size();
|
||||
weight.resize(n_weights);
|
||||
for (size_t i = 0; i < n_weights; ++i) {
|
||||
weight[i] = get<Number const>(j_weights[i]);
|
||||
}
|
||||
}
|
||||
|
||||
DMLC_REGISTER_PARAMETER(DeprecatedGBLinearModelParam);
|
||||
} // namespace gbm
|
||||
} // namespace xgboost
|
||||
@@ -62,27 +62,21 @@ class GBLinearModel : public Model {
|
||||
learner_model_param_->num_output_group);
|
||||
std::fill(weight.begin(), weight.end(), 0.0f);
|
||||
}
|
||||
|
||||
void SaveModel(Json *p_out) const override;
|
||||
void LoadModel(Json const &in) override;
|
||||
|
||||
// save the model to file
|
||||
inline void Save(dmlc::Stream* fo) const {
|
||||
void Save(dmlc::Stream *fo) const {
|
||||
fo->Write(¶m, sizeof(param));
|
||||
fo->Write(weight);
|
||||
}
|
||||
// load model from file
|
||||
inline void Load(dmlc::Stream* fi) {
|
||||
void Load(dmlc::Stream *fi) {
|
||||
CHECK_EQ(fi->Read(¶m, sizeof(param)), sizeof(param));
|
||||
fi->Read(&weight);
|
||||
}
|
||||
|
||||
void LoadModel(dmlc::Stream* fi) override {
|
||||
// They are the same right now until we can split up the saved parameter from model.
|
||||
this->Load(fi);
|
||||
}
|
||||
|
||||
void SaveModel(dmlc::Stream* fo) const override {
|
||||
// They are the same right now until we can split up the saved parameter from model.
|
||||
this->Save(fo);
|
||||
}
|
||||
|
||||
// model bias
|
||||
inline bst_float *bias() {
|
||||
return &weight[learner_model_param_->num_feature *
|
||||
|
||||
@@ -289,8 +289,19 @@ void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& ne
|
||||
monitor_.Stop("CommitModel");
|
||||
}
|
||||
|
||||
void GBTree::LoadModel(Json const& in) {
|
||||
CHECK_EQ(get<String>(in["name"]), "gbtree");
|
||||
model_.LoadModel(in["model"]);
|
||||
}
|
||||
|
||||
void GBTree::SaveModel(Json* p_out) const {
|
||||
auto& out = *p_out;
|
||||
out["name"] = String("gbtree");
|
||||
out["model"] = Object();
|
||||
auto& model = out["model"];
|
||||
model_.SaveModel(&model);
|
||||
}
|
||||
|
||||
// dart
|
||||
class Dart : public GBTree {
|
||||
public:
|
||||
explicit Dart(LearnerModelParam const* booster_config) :
|
||||
@@ -303,6 +314,30 @@ class Dart : public GBTree {
|
||||
}
|
||||
}
|
||||
|
||||
void SaveModel(Json *p_out) const override {
|
||||
auto &out = *p_out;
|
||||
out["name"] = String("dart");
|
||||
out["gbtree"] = Object();
|
||||
GBTree::SaveModel(&(out["gbtree"]));
|
||||
|
||||
std::vector<Json> j_weight_drop(weight_drop_.size());
|
||||
for (size_t i = 0; i < weight_drop_.size(); ++i) {
|
||||
j_weight_drop[i] = Number(weight_drop_[i]);
|
||||
}
|
||||
out["weight_drop"] = Array(j_weight_drop);
|
||||
}
|
||||
void LoadModel(Json const& in) override {
|
||||
CHECK_EQ(get<String>(in["name"]), "dart");
|
||||
auto const& gbtree = in["gbtree"];
|
||||
GBTree::LoadModel(gbtree);
|
||||
|
||||
auto const& j_weight_drop = get<Array>(in["weight_drop"]);
|
||||
weight_drop_.resize(j_weight_drop.size());
|
||||
for (size_t i = 0; i < weight_drop_.size(); ++i) {
|
||||
weight_drop_[i] = get<Number const>(j_weight_drop[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void Load(dmlc::Stream* fi) override {
|
||||
GBTree::Load(fi);
|
||||
weight_drop_.resize(model_.param.num_trees);
|
||||
@@ -387,7 +422,7 @@ class Dart : public GBTree {
|
||||
if (init_out_preds) {
|
||||
size_t n = num_group * p_fmat->Info().num_row_;
|
||||
const auto& base_margin =
|
||||
p_fmat->Info().base_margin_.ConstHostVector();
|
||||
p_fmat->Info().base_margin_.ConstHostVector();
|
||||
out_preds->resize(n);
|
||||
if (base_margin.size() != 0) {
|
||||
CHECK_EQ(out_preds->size(), n);
|
||||
|
||||
@@ -192,6 +192,9 @@ class GBTree : public GradientBooster {
|
||||
model_.Save(fo);
|
||||
}
|
||||
|
||||
void SaveModel(Json* p_out) const override;
|
||||
void LoadModel(Json const& in) override;
|
||||
|
||||
bool AllowLazyCheckPoint() const override {
|
||||
return model_.learner_model_param_->num_output_group == 1 ||
|
||||
tparam_.updater_seq.find("distcol") != std::string::npos;
|
||||
|
||||
85
src/gbm/gbtree_model.cc
Normal file
85
src/gbm/gbtree_model.cc
Normal file
@@ -0,0 +1,85 @@
|
||||
/*!
|
||||
* Copyright 2019 by Contributors
|
||||
*/
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "gbtree_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace gbm {
|
||||
|
||||
void GBTreeModel::Save(dmlc::Stream* fo) const {
|
||||
CHECK_EQ(param.num_trees, static_cast<int32_t>(trees.size()));
|
||||
fo->Write(¶m, sizeof(param));
|
||||
for (const auto & tree : trees) {
|
||||
tree->Save(fo);
|
||||
}
|
||||
if (tree_info.size() != 0) {
|
||||
fo->Write(dmlc::BeginPtr(tree_info), sizeof(int32_t) * tree_info.size());
|
||||
}
|
||||
}
|
||||
|
||||
void GBTreeModel::Load(dmlc::Stream* fi) {
|
||||
CHECK_EQ(fi->Read(¶m, sizeof(param)), sizeof(param))
|
||||
<< "GBTree: invalid model file";
|
||||
trees.clear();
|
||||
trees_to_update.clear();
|
||||
for (int32_t i = 0; i < param.num_trees; ++i) {
|
||||
std::unique_ptr<RegTree> ptr(new RegTree());
|
||||
ptr->Load(fi);
|
||||
trees.push_back(std::move(ptr));
|
||||
}
|
||||
tree_info.resize(param.num_trees);
|
||||
if (param.num_trees != 0) {
|
||||
CHECK_EQ(
|
||||
fi->Read(dmlc::BeginPtr(tree_info), sizeof(int32_t) * param.num_trees),
|
||||
sizeof(int32_t) * param.num_trees);
|
||||
}
|
||||
}
|
||||
|
||||
void GBTreeModel::SaveModel(Json* p_out) const {
|
||||
auto& out = *p_out;
|
||||
CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
|
||||
out["model_param"] = toJson(param);
|
||||
std::vector<Json> trees_json;
|
||||
size_t t = 0;
|
||||
for (auto const& tree : trees) {
|
||||
Json tree_json{Object()};
|
||||
tree->SaveModel(&tree_json);
|
||||
tree_json["id"] = std::to_string(t);
|
||||
trees_json.emplace_back(tree_json);
|
||||
t++;
|
||||
}
|
||||
|
||||
std::vector<Json> tree_info_json(tree_info.size());
|
||||
for (size_t i = 0; i < tree_info.size(); ++i) {
|
||||
tree_info_json[i] = Integer(tree_info[i]);
|
||||
}
|
||||
|
||||
out["trees"] = Array(std::move(trees_json));
|
||||
out["tree_info"] = Array(std::move(tree_info_json));
|
||||
}
|
||||
|
||||
void GBTreeModel::LoadModel(Json const& in) {
|
||||
fromJson(in["model_param"], ¶m);
|
||||
|
||||
trees.clear();
|
||||
trees_to_update.clear();
|
||||
|
||||
auto const& trees_json = get<Array const>(in["trees"]);
|
||||
trees.resize(trees_json.size());
|
||||
|
||||
for (size_t t = 0; t < trees.size(); ++t) {
|
||||
trees[t].reset( new RegTree() );
|
||||
trees[t]->LoadModel(trees_json[t]);
|
||||
}
|
||||
|
||||
tree_info.resize(param.num_trees);
|
||||
auto const& tree_info_json = get<Array const>(in["tree_info"]);
|
||||
for (int32_t i = 0; i < param.num_trees; ++i) {
|
||||
tree_info[i] = get<Integer const>(tree_info_json[i]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gbm
|
||||
} // namespace xgboost
|
||||
@@ -84,43 +84,11 @@ struct GBTreeModel : public Model {
|
||||
}
|
||||
}
|
||||
|
||||
void LoadModel(dmlc::Stream* fi) override {
|
||||
// They are the same right now until we can split up the saved parameter from model.
|
||||
this->Load(fi);
|
||||
}
|
||||
void SaveModel(dmlc::Stream* fo) const override {
|
||||
// They are the same right now until we can split up the saved parameter from model.
|
||||
this->Save(fo);
|
||||
}
|
||||
void Load(dmlc::Stream* fi);
|
||||
void Save(dmlc::Stream* fo) const;
|
||||
|
||||
void Load(dmlc::Stream* fi) {
|
||||
CHECK_EQ(fi->Read(¶m, sizeof(param)), sizeof(param))
|
||||
<< "GBTree: invalid model file";
|
||||
trees.clear();
|
||||
trees_to_update.clear();
|
||||
for (int i = 0; i < param.num_trees; ++i) {
|
||||
std::unique_ptr<RegTree> ptr(new RegTree());
|
||||
ptr->LoadModel(fi);
|
||||
trees.push_back(std::move(ptr));
|
||||
}
|
||||
tree_info.resize(param.num_trees);
|
||||
if (param.num_trees != 0) {
|
||||
CHECK_EQ(
|
||||
fi->Read(dmlc::BeginPtr(tree_info), sizeof(int) * param.num_trees),
|
||||
sizeof(int) * param.num_trees);
|
||||
}
|
||||
}
|
||||
|
||||
void Save(dmlc::Stream* fo) const {
|
||||
CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
|
||||
fo->Write(¶m, sizeof(param));
|
||||
for (const auto & tree : trees) {
|
||||
tree->SaveModel(fo);
|
||||
}
|
||||
if (tree_info.size() != 0) {
|
||||
fo->Write(dmlc::BeginPtr(tree_info), sizeof(int) * tree_info.size());
|
||||
}
|
||||
}
|
||||
void SaveModel(Json* p_out) const override;
|
||||
void LoadModel(Json const& p_out) override;
|
||||
|
||||
std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
|
||||
std::string format) const {
|
||||
|
||||
Reference in New Issue
Block a user