Fix num_roots to be 1. (#5165)

This commit is contained in:
Jiaming Yuan 2019-12-30 02:18:45 +08:00 committed by GitHub
parent d55489af14
commit 139ccc9902
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 3 additions and 0 deletions

View File

@ -56,6 +56,7 @@ struct TreeParam : public dmlc::Parameter<TreeParam> {
"TreeParam: 64 bit align"); "TreeParam: 64 bit align");
std::memset(this, 0, sizeof(TreeParam)); std::memset(this, 0, sizeof(TreeParam));
num_nodes = 1; num_nodes = 1;
deprecated_num_roots = 1;
} }
// declare the parameters // declare the parameters
DMLC_DECLARE_PARAMETER(TreeParam) { DMLC_DECLARE_PARAMETER(TreeParam) {

View File

@ -47,6 +47,7 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
std::memset(this, 0, sizeof(GBTreeModelParam)); // FIXME(trivialfis): Why? std::memset(this, 0, sizeof(GBTreeModelParam)); // FIXME(trivialfis): Why?
static_assert(sizeof(GBTreeModelParam) == (4 + 2 + 2 + 32) * sizeof(int32_t), static_assert(sizeof(GBTreeModelParam) == (4 + 2 + 2 + 32) * sizeof(int32_t),
"64/32 bit compatibility issue"); "64/32 bit compatibility issue");
deprecated_num_roots = 1;
} }
// declare parameters, only declare those that need to be set. // declare parameters, only declare those that need to be set.

View File

@ -640,6 +640,7 @@ void RegTree::Save(dmlc::Stream* fo) const {
CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size())); CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size()));
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size())); CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
fo->Write(&param, sizeof(TreeParam)); fo->Write(&param, sizeof(TreeParam));
CHECK_EQ(param.deprecated_num_roots, 1);
CHECK_NE(param.num_nodes, 0); CHECK_NE(param.num_nodes, 0);
fo->Write(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size()); fo->Write(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size());
fo->Write(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * nodes_.size()); fo->Write(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * nodes_.size());