|
|
|
|
@@ -815,9 +815,9 @@ void RegTree::ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split
|
|
|
|
|
linalg::VectorView<float const> left_weight,
|
|
|
|
|
linalg::VectorView<float const> right_weight) {
|
|
|
|
|
CHECK(IsMultiTarget());
|
|
|
|
|
CHECK_LT(split_index, this->param.num_feature);
|
|
|
|
|
CHECK_LT(split_index, this->param_.num_feature);
|
|
|
|
|
CHECK(this->p_mt_tree_);
|
|
|
|
|
CHECK_GT(param.size_leaf_vector, 1);
|
|
|
|
|
CHECK_GT(param_.size_leaf_vector, 1);
|
|
|
|
|
|
|
|
|
|
this->p_mt_tree_->Expand(nidx, split_index, split_cond, default_left, base_weight, left_weight,
|
|
|
|
|
right_weight);
|
|
|
|
|
@@ -826,7 +826,7 @@ void RegTree::ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split
|
|
|
|
|
split_categories_segments_.resize(this->Size());
|
|
|
|
|
this->split_types_.at(nidx) = FeatureType::kNumerical;
|
|
|
|
|
|
|
|
|
|
this->param.num_nodes = this->p_mt_tree_->Size();
|
|
|
|
|
this->param_.num_nodes = this->p_mt_tree_->Size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RegTree::ExpandCategorical(bst_node_t nid, bst_feature_t split_index,
|
|
|
|
|
@@ -850,13 +850,13 @@ void RegTree::ExpandCategorical(bst_node_t nid, bst_feature_t split_index,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RegTree::Load(dmlc::Stream* fi) {
|
|
|
|
|
CHECK_EQ(fi->Read(¶m, sizeof(TreeParam)), sizeof(TreeParam));
|
|
|
|
|
CHECK_EQ(fi->Read(¶m_, sizeof(TreeParam)), sizeof(TreeParam));
|
|
|
|
|
if (!DMLC_IO_NO_ENDIAN_SWAP) {
|
|
|
|
|
param = param.ByteSwap();
|
|
|
|
|
param_ = param_.ByteSwap();
|
|
|
|
|
}
|
|
|
|
|
nodes_.resize(param.num_nodes);
|
|
|
|
|
stats_.resize(param.num_nodes);
|
|
|
|
|
CHECK_NE(param.num_nodes, 0);
|
|
|
|
|
nodes_.resize(param_.num_nodes);
|
|
|
|
|
stats_.resize(param_.num_nodes);
|
|
|
|
|
CHECK_NE(param_.num_nodes, 0);
|
|
|
|
|
CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size()),
|
|
|
|
|
sizeof(Node) * nodes_.size());
|
|
|
|
|
if (!DMLC_IO_NO_ENDIAN_SWAP) {
|
|
|
|
|
@@ -873,29 +873,29 @@ void RegTree::Load(dmlc::Stream* fi) {
|
|
|
|
|
}
|
|
|
|
|
// chg deleted nodes
|
|
|
|
|
deleted_nodes_.resize(0);
|
|
|
|
|
for (int i = 1; i < param.num_nodes; ++i) {
|
|
|
|
|
for (int i = 1; i < param_.num_nodes; ++i) {
|
|
|
|
|
if (nodes_[i].IsDeleted()) {
|
|
|
|
|
deleted_nodes_.push_back(i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
CHECK_EQ(static_cast<int>(deleted_nodes_.size()), param.num_deleted);
|
|
|
|
|
CHECK_EQ(static_cast<int>(deleted_nodes_.size()), param_.num_deleted);
|
|
|
|
|
|
|
|
|
|
split_types_.resize(param.num_nodes, FeatureType::kNumerical);
|
|
|
|
|
split_categories_segments_.resize(param.num_nodes);
|
|
|
|
|
split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
|
|
|
|
|
split_categories_segments_.resize(param_.num_nodes);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RegTree::Save(dmlc::Stream* fo) const {
|
|
|
|
|
CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size()));
|
|
|
|
|
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
|
|
|
|
|
CHECK_EQ(param.deprecated_num_roots, 1);
|
|
|
|
|
CHECK_NE(param.num_nodes, 0);
|
|
|
|
|
CHECK_EQ(param_.num_nodes, static_cast<int>(nodes_.size()));
|
|
|
|
|
CHECK_EQ(param_.num_nodes, static_cast<int>(stats_.size()));
|
|
|
|
|
CHECK_EQ(param_.deprecated_num_roots, 1);
|
|
|
|
|
CHECK_NE(param_.num_nodes, 0);
|
|
|
|
|
CHECK(!HasCategoricalSplit())
|
|
|
|
|
<< "Please use JSON/UBJSON for saving models with categorical splits.";
|
|
|
|
|
|
|
|
|
|
if (DMLC_IO_NO_ENDIAN_SWAP) {
|
|
|
|
|
fo->Write(¶m, sizeof(TreeParam));
|
|
|
|
|
fo->Write(¶m_, sizeof(TreeParam));
|
|
|
|
|
} else {
|
|
|
|
|
TreeParam x = param.ByteSwap();
|
|
|
|
|
TreeParam x = param_.ByteSwap();
|
|
|
|
|
fo->Write(&x, sizeof(x));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1081,7 +1081,7 @@ void RegTree::LoadModel(Json const& in) {
|
|
|
|
|
bool typed = IsA<I32Array>(in[tf::kParent]);
|
|
|
|
|
auto const& in_obj = get<Object const>(in);
|
|
|
|
|
// basic properties
|
|
|
|
|
FromJson(in["tree_param"], ¶m);
|
|
|
|
|
FromJson(in["tree_param"], ¶m_);
|
|
|
|
|
// categorical splits
|
|
|
|
|
bool has_cat = in_obj.find("split_type") != in_obj.cend();
|
|
|
|
|
if (has_cat) {
|
|
|
|
|
@@ -1092,55 +1092,55 @@ void RegTree::LoadModel(Json const& in) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// multi-target
|
|
|
|
|
if (param.size_leaf_vector > 1) {
|
|
|
|
|
this->p_mt_tree_.reset(new MultiTargetTree{¶m});
|
|
|
|
|
if (param_.size_leaf_vector > 1) {
|
|
|
|
|
this->p_mt_tree_.reset(new MultiTargetTree{¶m_});
|
|
|
|
|
this->GetMultiTargetTree()->LoadModel(in);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool feature_is_64 = IsA<I64Array>(in["split_indices"]);
|
|
|
|
|
if (typed && feature_is_64) {
|
|
|
|
|
LoadModelImpl<true, true>(in, param, &stats_, &nodes_);
|
|
|
|
|
LoadModelImpl<true, true>(in, param_, &stats_, &nodes_);
|
|
|
|
|
} else if (typed && !feature_is_64) {
|
|
|
|
|
LoadModelImpl<true, false>(in, param, &stats_, &nodes_);
|
|
|
|
|
LoadModelImpl<true, false>(in, param_, &stats_, &nodes_);
|
|
|
|
|
} else if (!typed && feature_is_64) {
|
|
|
|
|
LoadModelImpl<false, true>(in, param, &stats_, &nodes_);
|
|
|
|
|
LoadModelImpl<false, true>(in, param_, &stats_, &nodes_);
|
|
|
|
|
} else {
|
|
|
|
|
LoadModelImpl<false, false>(in, param, &stats_, &nodes_);
|
|
|
|
|
LoadModelImpl<false, false>(in, param_, &stats_, &nodes_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!has_cat) {
|
|
|
|
|
this->split_categories_segments_.resize(this->param.num_nodes);
|
|
|
|
|
this->split_types_.resize(this->param.num_nodes);
|
|
|
|
|
this->split_categories_segments_.resize(this->param_.num_nodes);
|
|
|
|
|
this->split_types_.resize(this->param_.num_nodes);
|
|
|
|
|
std::fill(split_types_.begin(), split_types_.end(), FeatureType::kNumerical);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
deleted_nodes_.clear();
|
|
|
|
|
for (bst_node_t i = 1; i < param.num_nodes; ++i) {
|
|
|
|
|
for (bst_node_t i = 1; i < param_.num_nodes; ++i) {
|
|
|
|
|
if (nodes_[i].IsDeleted()) {
|
|
|
|
|
deleted_nodes_.push_back(i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// easier access to [] operator
|
|
|
|
|
auto& self = *this;
|
|
|
|
|
for (auto nid = 1; nid < param.num_nodes; ++nid) {
|
|
|
|
|
for (auto nid = 1; nid < param_.num_nodes; ++nid) {
|
|
|
|
|
auto parent = self[nid].Parent();
|
|
|
|
|
CHECK_NE(parent, RegTree::kInvalidNodeId);
|
|
|
|
|
self[nid].SetParent(self[nid].Parent(), self[parent].LeftChild() == nid);
|
|
|
|
|
}
|
|
|
|
|
CHECK_EQ(static_cast<bst_node_t>(deleted_nodes_.size()), param.num_deleted);
|
|
|
|
|
CHECK_EQ(this->split_categories_segments_.size(), param.num_nodes);
|
|
|
|
|
CHECK_EQ(static_cast<bst_node_t>(deleted_nodes_.size()), param_.num_deleted);
|
|
|
|
|
CHECK_EQ(this->split_categories_segments_.size(), param_.num_nodes);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RegTree::SaveModel(Json* p_out) const {
|
|
|
|
|
auto& out = *p_out;
|
|
|
|
|
// basic properties
|
|
|
|
|
out["tree_param"] = ToJson(param);
|
|
|
|
|
out["tree_param"] = ToJson(param_);
|
|
|
|
|
// categorical splits
|
|
|
|
|
this->SaveCategoricalSplit(p_out);
|
|
|
|
|
// multi-target
|
|
|
|
|
if (this->IsMultiTarget()) {
|
|
|
|
|
CHECK_GT(param.size_leaf_vector, 1);
|
|
|
|
|
CHECK_GT(param_.size_leaf_vector, 1);
|
|
|
|
|
this->GetMultiTargetTree()->SaveModel(p_out);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
@@ -1150,11 +1150,11 @@ void RegTree::SaveModel(Json* p_out) const {
|
|
|
|
|
* pruner, and this pruner can be used inside another updater so leaf are not necessary
|
|
|
|
|
* at the end of node array.
|
|
|
|
|
*/
|
|
|
|
|
CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size()));
|
|
|
|
|
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
|
|
|
|
|
CHECK_EQ(param_.num_nodes, static_cast<int>(nodes_.size()));
|
|
|
|
|
CHECK_EQ(param_.num_nodes, static_cast<int>(stats_.size()));
|
|
|
|
|
|
|
|
|
|
CHECK_EQ(get<String>(out["tree_param"]["num_nodes"]), std::to_string(param.num_nodes));
|
|
|
|
|
auto n_nodes = param.num_nodes;
|
|
|
|
|
CHECK_EQ(get<String>(out["tree_param"]["num_nodes"]), std::to_string(param_.num_nodes));
|
|
|
|
|
auto n_nodes = param_.num_nodes;
|
|
|
|
|
|
|
|
|
|
// stats
|
|
|
|
|
F32Array loss_changes(n_nodes);
|
|
|
|
|
@@ -1168,7 +1168,7 @@ void RegTree::SaveModel(Json* p_out) const {
|
|
|
|
|
|
|
|
|
|
F32Array conds(n_nodes);
|
|
|
|
|
U8Array default_left(n_nodes);
|
|
|
|
|
CHECK_EQ(this->split_types_.size(), param.num_nodes);
|
|
|
|
|
CHECK_EQ(this->split_types_.size(), param_.num_nodes);
|
|
|
|
|
|
|
|
|
|
namespace tf = tree_field;
|
|
|
|
|
|
|
|
|
|
@@ -1189,7 +1189,7 @@ void RegTree::SaveModel(Json* p_out) const {
|
|
|
|
|
default_left.Set(i, static_cast<uint8_t>(!!n.DefaultLeft()));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
if (this->param.num_feature > static_cast<bst_feature_t>(std::numeric_limits<int32_t>::max())) {
|
|
|
|
|
if (this->param_.num_feature > static_cast<bst_feature_t>(std::numeric_limits<int32_t>::max())) {
|
|
|
|
|
I64Array indices_64(n_nodes);
|
|
|
|
|
save_tree(&indices_64);
|
|
|
|
|
out[tf::kSplitIdx] = std::move(indices_64);
|
|
|
|
|
|