Add Model and Configurable interface. (#4945)

* Apply Configurable to objective functions.
* Apply Model to Learner and Regtree, gbm.
* Add Load/SaveConfig to objs.
* Refactor obj tests to use smart pointer.
* Dummy methods for Save/Load Model.
This commit is contained in:
Jiaming Yuan
2019-10-18 01:56:02 -04:00
committed by GitHub
parent 9fc681001a
commit ae536756ae
31 changed files with 521 additions and 187 deletions

View File

@@ -617,6 +617,35 @@ std::string RegTree::DumpModel(const FeatureMap& fmap,
return result;
}
void RegTree::LoadModel(dmlc::Stream* fi) {
CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam));
nodes_.resize(param.num_nodes);
stats_.resize(param.num_nodes);
CHECK_NE(param.num_nodes, 0);
CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size()),
sizeof(Node) * nodes_.size());
CHECK_EQ(fi->Read(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * stats_.size()),
sizeof(RTreeNodeStat) * stats_.size());
// chg deleted nodes
deleted_nodes_.resize(0);
for (int i = param.num_roots; i < param.num_nodes; ++i) {
if (nodes_[i].IsDeleted()) deleted_nodes_.push_back(i);
}
CHECK_EQ(static_cast<int>(deleted_nodes_.size()), param.num_deleted);
}
/*!
* \brief save model to stream
* \param fo output stream
*/
void RegTree::SaveModel(dmlc::Stream* fo) const {
CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size()));
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
fo->Write(&param, sizeof(TreeParam));
CHECK_NE(param.num_nodes, 0);
fo->Write(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size());
fo->Write(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * nodes_.size());
}
void RegTree::FillNodeMeanValues() {
size_t num_nodes = this->param.num_nodes;
if (this->node_mean_values_.size() == num_nodes) {

View File

@@ -1053,12 +1053,12 @@ class GPUHistMakerSpecialised {
common::MemoryBufferStream fs(&s_model);
int rank = rabit::GetRank();
if (rank == 0) {
local_trees.front().Save(&fs);
local_trees.front().SaveModel(&fs);
}
fs.Seek(0);
rabit::Broadcast(&s_model, 0);
RegTree reference_tree{};
reference_tree.Load(&fs);
reference_tree.LoadModel(&fs);
for (const auto& tree : local_trees) {
CHECK(tree == reference_tree);
}

View File

@@ -35,13 +35,13 @@ class TreeSyncher: public TreeUpdater {
int rank = rabit::GetRank();
if (rank == 0) {
for (auto tree : trees) {
tree->Save(&fs);
tree->SaveModel(&fs);
}
}
fs.Seek(0);
rabit::Broadcast(&s_model, 0);
for (auto tree : trees) {
tree->Load(&fs);
tree->LoadModel(&fs);
}
}
};