[breaking] Change internal model serialization to UBJSON. (#7556)
* Use typed array for models. * Change the memory snapshot format. * Add new C API for saving to raw format.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015-2021 by Contributors
|
||||
* Copyright 2015-2022 by Contributors
|
||||
* \file tree_model.cc
|
||||
* \brief model structure for tree
|
||||
*/
|
||||
@@ -893,27 +893,57 @@ void RegTree::Save(dmlc::Stream* fo) const {
|
||||
}
|
||||
}
|
||||
}
|
||||
// typed array, not boolean
|
||||
template <typename JT, typename T>
|
||||
std::enable_if_t<!std::is_same<T, Json>::value && !std::is_same<JT, Boolean>::value, T> GetElem(
|
||||
std::vector<T> const& arr, size_t i) {
|
||||
return arr[i];
|
||||
}
|
||||
// typed array boolean
|
||||
template <typename JT, typename T>
|
||||
std::enable_if_t<!std::is_same<T, Json>::value && std::is_same<T, uint8_t>::value &&
|
||||
std::is_same<JT, Boolean>::value,
|
||||
bool>
|
||||
GetElem(std::vector<T> const& arr, size_t i) {
|
||||
return arr[i] == 1;
|
||||
}
|
||||
// json array
|
||||
template <typename JT, typename T>
|
||||
std::enable_if_t<
|
||||
std::is_same<T, Json>::value,
|
||||
std::conditional_t<std::is_same<JT, Integer>::value, int64_t,
|
||||
std::conditional_t<std::is_same<Boolean, JT>::value, bool, float>>>
|
||||
GetElem(std::vector<T> const& arr, size_t i) {
|
||||
if (std::is_same<JT, Boolean>::value && !IsA<Boolean>(arr[i])) {
|
||||
return get<Integer const>(arr[i]) == 1;
|
||||
}
|
||||
return get<JT const>(arr[i]);
|
||||
}
|
||||
|
||||
template <bool typed>
|
||||
void RegTree::LoadCategoricalSplit(Json const& in) {
|
||||
auto const& categories_segments = get<Array const>(in["categories_segments"]);
|
||||
auto const& categories_sizes = get<Array const>(in["categories_sizes"]);
|
||||
auto const& categories_nodes = get<Array const>(in["categories_nodes"]);
|
||||
auto const& categories = get<Array const>(in["categories"]);
|
||||
using I64ArrayT = std::conditional_t<typed, I64Array const, Array const>;
|
||||
using I32ArrayT = std::conditional_t<typed, I32Array const, Array const>;
|
||||
|
||||
auto const& categories_segments = get<I64ArrayT>(in["categories_segments"]);
|
||||
auto const& categories_sizes = get<I64ArrayT>(in["categories_sizes"]);
|
||||
auto const& categories_nodes = get<I32ArrayT>(in["categories_nodes"]);
|
||||
auto const& categories = get<I32ArrayT>(in["categories"]);
|
||||
|
||||
size_t cnt = 0;
|
||||
bst_node_t last_cat_node = -1;
|
||||
if (!categories_nodes.empty()) {
|
||||
last_cat_node = get<Integer const>(categories_nodes[cnt]);
|
||||
last_cat_node = GetElem<Integer>(categories_nodes, cnt);
|
||||
}
|
||||
for (bst_node_t nidx = 0; nidx < param.num_nodes; ++nidx) {
|
||||
if (nidx == last_cat_node) {
|
||||
auto j_begin = get<Integer const>(categories_segments[cnt]);
|
||||
auto j_end = get<Integer const>(categories_sizes[cnt]) + j_begin;
|
||||
auto j_begin = GetElem<Integer>(categories_segments, cnt);
|
||||
auto j_end = GetElem<Integer>(categories_sizes, cnt) + j_begin;
|
||||
bst_cat_t max_cat{std::numeric_limits<bst_cat_t>::min()};
|
||||
CHECK_NE(j_end - j_begin, 0) << nidx;
|
||||
|
||||
for (auto j = j_begin; j < j_end; ++j) {
|
||||
auto const &category = get<Integer const>(categories[j]);
|
||||
auto const& category = GetElem<Integer>(categories, j);
|
||||
auto cat = common::AsCat(category);
|
||||
max_cat = std::max(max_cat, cat);
|
||||
}
|
||||
@@ -924,7 +954,7 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
|
||||
std::vector<uint32_t> cat_bits_storage(size, 0);
|
||||
common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)};
|
||||
for (auto j = j_begin; j < j_end; ++j) {
|
||||
cat_bits.Set(common::AsCat(get<Integer const>(categories[j])));
|
||||
cat_bits.Set(common::AsCat(GetElem<Integer>(categories, j)));
|
||||
}
|
||||
|
||||
auto begin = split_categories_.size();
|
||||
@@ -936,9 +966,9 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
|
||||
|
||||
++cnt;
|
||||
if (cnt == categories_nodes.size()) {
|
||||
last_cat_node = -1;
|
||||
last_cat_node = -1; // Don't break, we still need to initialize the remaining nodes.
|
||||
} else {
|
||||
last_cat_node = get<Integer const>(categories_nodes[cnt]);
|
||||
last_cat_node = GetElem<Integer>(categories_nodes, cnt);
|
||||
}
|
||||
} else {
|
||||
split_categories_segments_[nidx].beg = categories.size();
|
||||
@@ -947,104 +977,144 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
|
||||
}
|
||||
}
|
||||
|
||||
template void RegTree::LoadCategoricalSplit<true>(Json const& in);
|
||||
template void RegTree::LoadCategoricalSplit<false>(Json const& in);
|
||||
|
||||
void RegTree::SaveCategoricalSplit(Json* p_out) const {
|
||||
auto& out = *p_out;
|
||||
CHECK_EQ(this->split_types_.size(), param.num_nodes);
|
||||
CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes);
|
||||
|
||||
std::vector<Json> categories_segments;
|
||||
std::vector<Json> categories_sizes;
|
||||
std::vector<Json> categories;
|
||||
std::vector<Json> categories_nodes;
|
||||
I64Array categories_segments;
|
||||
I64Array categories_sizes;
|
||||
I32Array categories; // bst_cat_t = int32_t
|
||||
I32Array categories_nodes; // bst_note_t = int32_t
|
||||
|
||||
for (size_t i = 0; i < nodes_.size(); ++i) {
|
||||
if (this->split_types_[i] == FeatureType::kCategorical) {
|
||||
categories_nodes.emplace_back(i);
|
||||
auto begin = categories.size();
|
||||
categories_segments.emplace_back(static_cast<Integer::Int>(begin));
|
||||
categories_nodes.GetArray().emplace_back(i);
|
||||
auto begin = categories.Size();
|
||||
categories_segments.GetArray().emplace_back(begin);
|
||||
auto segment = split_categories_segments_[i];
|
||||
auto node_categories =
|
||||
this->GetSplitCategories().subspan(segment.beg, segment.size);
|
||||
auto node_categories = this->GetSplitCategories().subspan(segment.beg, segment.size);
|
||||
common::KCatBitField const cat_bits(node_categories);
|
||||
for (size_t i = 0; i < cat_bits.Size(); ++i) {
|
||||
if (cat_bits.Check(i)) {
|
||||
categories.emplace_back(static_cast<Integer::Int>(i));
|
||||
categories.GetArray().emplace_back(i);
|
||||
}
|
||||
}
|
||||
size_t size = categories.size() - begin;
|
||||
categories_sizes.emplace_back(static_cast<Integer::Int>(size));
|
||||
size_t size = categories.Size() - begin;
|
||||
categories_sizes.GetArray().emplace_back(size);
|
||||
CHECK_NE(size, 0);
|
||||
}
|
||||
}
|
||||
|
||||
out["categories_segments"] = categories_segments;
|
||||
out["categories_sizes"] = categories_sizes;
|
||||
out["categories_nodes"] = categories_nodes;
|
||||
out["categories"] = categories;
|
||||
out["categories_segments"] = std::move(categories_segments);
|
||||
out["categories_sizes"] = std::move(categories_sizes);
|
||||
out["categories_nodes"] = std::move(categories_nodes);
|
||||
out["categories"] = std::move(categories);
|
||||
}
|
||||
|
||||
void RegTree::LoadModel(Json const& in) {
|
||||
FromJson(in["tree_param"], ¶m);
|
||||
auto n_nodes = param.num_nodes;
|
||||
template <bool typed, bool feature_is_64,
|
||||
typename FloatArrayT = std::conditional_t<typed, F32Array const, Array const>,
|
||||
typename U8ArrayT = std::conditional_t<typed, U8Array const, Array const>,
|
||||
typename I32ArrayT = std::conditional_t<typed, I32Array const, Array const>,
|
||||
typename I64ArrayT = std::conditional_t<typed, I64Array const, Array const>,
|
||||
typename IndexArrayT = std::conditional_t<feature_is_64, I64ArrayT, I32ArrayT>>
|
||||
bool LoadModelImpl(Json const& in, TreeParam* param, std::vector<RTreeNodeStat>* p_stats,
|
||||
std::vector<FeatureType>* p_split_types, std::vector<RegTree::Node>* p_nodes,
|
||||
std::vector<RegTree::Segment>* p_split_categories_segments) {
|
||||
auto& stats = *p_stats;
|
||||
auto& split_types = *p_split_types;
|
||||
auto& nodes = *p_nodes;
|
||||
auto& split_categories_segments = *p_split_categories_segments;
|
||||
|
||||
FromJson(in["tree_param"], param);
|
||||
auto n_nodes = param->num_nodes;
|
||||
CHECK_NE(n_nodes, 0);
|
||||
// stats
|
||||
auto const& loss_changes = get<Array const>(in["loss_changes"]);
|
||||
auto const& loss_changes = get<FloatArrayT>(in["loss_changes"]);
|
||||
CHECK_EQ(loss_changes.size(), n_nodes);
|
||||
auto const& sum_hessian = get<Array const>(in["sum_hessian"]);
|
||||
auto const& sum_hessian = get<FloatArrayT>(in["sum_hessian"]);
|
||||
CHECK_EQ(sum_hessian.size(), n_nodes);
|
||||
auto const& base_weights = get<Array const>(in["base_weights"]);
|
||||
auto const& base_weights = get<FloatArrayT>(in["base_weights"]);
|
||||
CHECK_EQ(base_weights.size(), n_nodes);
|
||||
// nodes
|
||||
auto const& lefts = get<Array const>(in["left_children"]);
|
||||
auto const& lefts = get<I32ArrayT>(in["left_children"]);
|
||||
CHECK_EQ(lefts.size(), n_nodes);
|
||||
auto const& rights = get<Array const>(in["right_children"]);
|
||||
auto const& rights = get<I32ArrayT>(in["right_children"]);
|
||||
CHECK_EQ(rights.size(), n_nodes);
|
||||
auto const& parents = get<Array const>(in["parents"]);
|
||||
auto const& parents = get<I32ArrayT>(in["parents"]);
|
||||
CHECK_EQ(parents.size(), n_nodes);
|
||||
auto const& indices = get<Array const>(in["split_indices"]);
|
||||
auto const& indices = get<IndexArrayT>(in["split_indices"]);
|
||||
CHECK_EQ(indices.size(), n_nodes);
|
||||
auto const& conds = get<Array const>(in["split_conditions"]);
|
||||
auto const& conds = get<FloatArrayT>(in["split_conditions"]);
|
||||
CHECK_EQ(conds.size(), n_nodes);
|
||||
auto const& default_left = get<Array const>(in["default_left"]);
|
||||
auto const& default_left = get<U8ArrayT>(in["default_left"]);
|
||||
CHECK_EQ(default_left.size(), n_nodes);
|
||||
|
||||
bool has_cat = get<Object const>(in).find("split_type") != get<Object const>(in).cend();
|
||||
std::vector<Json> split_type;
|
||||
std::remove_const_t<std::remove_reference_t<decltype(get<U8ArrayT const>(in["split_type"]))>>
|
||||
split_type;
|
||||
if (has_cat) {
|
||||
split_type = get<Array const>(in["split_type"]);
|
||||
split_type = get<U8ArrayT const>(in["split_type"]);
|
||||
}
|
||||
stats_.clear();
|
||||
nodes_.clear();
|
||||
stats = std::remove_reference_t<decltype(stats)>(n_nodes);
|
||||
nodes = std::remove_reference_t<decltype(nodes)>(n_nodes);
|
||||
split_types = std::remove_reference_t<decltype(split_types)>(n_nodes);
|
||||
split_categories_segments = std::remove_reference_t<decltype(split_categories_segments)>(n_nodes);
|
||||
|
||||
stats_.resize(n_nodes);
|
||||
nodes_.resize(n_nodes);
|
||||
split_types_.resize(n_nodes);
|
||||
split_categories_segments_.resize(n_nodes);
|
||||
static_assert(std::is_integral<decltype(GetElem<Integer>(lefts, 0))>::value, "");
|
||||
static_assert(std::is_floating_point<decltype(GetElem<Number>(loss_changes, 0))>::value, "");
|
||||
CHECK_EQ(n_nodes, split_categories_segments.size());
|
||||
|
||||
CHECK_EQ(n_nodes, split_categories_segments_.size());
|
||||
for (int32_t i = 0; i < n_nodes; ++i) {
|
||||
auto& s = stats_[i];
|
||||
s.loss_chg = get<Number const>(loss_changes[i]);
|
||||
s.sum_hess = get<Number const>(sum_hessian[i]);
|
||||
s.base_weight = get<Number const>(base_weights[i]);
|
||||
auto& s = stats[i];
|
||||
s.loss_chg = GetElem<Number>(loss_changes, i);
|
||||
s.sum_hess = GetElem<Number>(sum_hessian, i);
|
||||
s.base_weight = GetElem<Number>(base_weights, i);
|
||||
|
||||
auto& n = nodes_[i];
|
||||
bst_node_t left = get<Integer const>(lefts[i]);
|
||||
bst_node_t right = get<Integer const>(rights[i]);
|
||||
bst_node_t parent = get<Integer const>(parents[i]);
|
||||
bst_feature_t ind = get<Integer const>(indices[i]);
|
||||
float cond { get<Number const>(conds[i]) };
|
||||
bool dft_left { get<Boolean const>(default_left[i]) };
|
||||
n = Node{left, right, parent, ind, cond, dft_left};
|
||||
auto& n = nodes[i];
|
||||
bst_node_t left = GetElem<Integer>(lefts, i);
|
||||
bst_node_t right = GetElem<Integer>(rights, i);
|
||||
bst_node_t parent = GetElem<Integer>(parents, i);
|
||||
bst_feature_t ind = GetElem<Integer>(indices, i);
|
||||
float cond{GetElem<Number>(conds, i)};
|
||||
bool dft_left{GetElem<Boolean>(default_left, i)};
|
||||
n = RegTree::Node{left, right, parent, ind, cond, dft_left};
|
||||
|
||||
if (has_cat) {
|
||||
split_types_[i] =
|
||||
static_cast<FeatureType>(get<Integer const>(split_type[i]));
|
||||
split_types[i] = static_cast<FeatureType>(GetElem<Integer>(split_type, i));
|
||||
}
|
||||
}
|
||||
|
||||
return has_cat;
|
||||
}
|
||||
|
||||
void RegTree::LoadModel(Json const& in) {
|
||||
bool has_cat{false};
|
||||
bool typed = IsA<F32Array>(in["loss_changes"]);
|
||||
bool feature_is_64 = IsA<I64Array>(in["split_indices"]);
|
||||
if (typed && feature_is_64) {
|
||||
has_cat = LoadModelImpl<true, true>(in, ¶m, &stats_, &split_types_, &nodes_,
|
||||
&split_categories_segments_);
|
||||
} else if (typed && !feature_is_64) {
|
||||
has_cat = LoadModelImpl<true, false>(in, ¶m, &stats_, &split_types_, &nodes_,
|
||||
&split_categories_segments_);
|
||||
} else if (!typed && feature_is_64) {
|
||||
has_cat = LoadModelImpl<false, true>(in, ¶m, &stats_, &split_types_, &nodes_,
|
||||
&split_categories_segments_);
|
||||
} else {
|
||||
has_cat = LoadModelImpl<false, false>(in, ¶m, &stats_, &split_types_, &nodes_,
|
||||
&split_categories_segments_);
|
||||
}
|
||||
|
||||
if (has_cat) {
|
||||
this->LoadCategoricalSplit(in);
|
||||
if (typed) {
|
||||
this->LoadCategoricalSplit<true>(in);
|
||||
} else {
|
||||
this->LoadCategoricalSplit<false>(in);
|
||||
}
|
||||
} else {
|
||||
this->split_categories_segments_.resize(this->param.num_nodes);
|
||||
std::fill(split_types_.begin(), split_types_.end(), FeatureType::kNumerical);
|
||||
@@ -1058,7 +1128,7 @@ void RegTree::LoadModel(Json const& in) {
|
||||
}
|
||||
// easier access to [] operator
|
||||
auto& self = *this;
|
||||
for (auto nid = 1; nid < n_nodes; ++nid) {
|
||||
for (auto nid = 1; nid < param.num_nodes; ++nid) {
|
||||
auto parent = self[nid].Parent();
|
||||
CHECK_NE(parent, RegTree::kInvalidNodeId);
|
||||
self[nid].SetParent(self[nid].Parent(), self[parent].LeftChild() == nid);
|
||||
@@ -1079,39 +1149,51 @@ void RegTree::SaveModel(Json* p_out) const {
|
||||
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
|
||||
out["tree_param"] = ToJson(param);
|
||||
CHECK_EQ(get<String>(out["tree_param"]["num_nodes"]), std::to_string(param.num_nodes));
|
||||
using I = Integer::Int;
|
||||
auto n_nodes = param.num_nodes;
|
||||
|
||||
// stats
|
||||
std::vector<Json> loss_changes(n_nodes);
|
||||
std::vector<Json> sum_hessian(n_nodes);
|
||||
std::vector<Json> base_weights(n_nodes);
|
||||
F32Array loss_changes(n_nodes);
|
||||
F32Array sum_hessian(n_nodes);
|
||||
F32Array base_weights(n_nodes);
|
||||
|
||||
// nodes
|
||||
std::vector<Json> lefts(n_nodes);
|
||||
std::vector<Json> rights(n_nodes);
|
||||
std::vector<Json> parents(n_nodes);
|
||||
std::vector<Json> indices(n_nodes);
|
||||
std::vector<Json> conds(n_nodes);
|
||||
std::vector<Json> default_left(n_nodes);
|
||||
std::vector<Json> split_type(n_nodes);
|
||||
I32Array lefts(n_nodes);
|
||||
I32Array rights(n_nodes);
|
||||
I32Array parents(n_nodes);
|
||||
|
||||
|
||||
F32Array conds(n_nodes);
|
||||
U8Array default_left(n_nodes);
|
||||
U8Array split_type(n_nodes);
|
||||
CHECK_EQ(this->split_types_.size(), param.num_nodes);
|
||||
|
||||
for (bst_node_t i = 0; i < n_nodes; ++i) {
|
||||
auto const& s = stats_[i];
|
||||
loss_changes[i] = s.loss_chg;
|
||||
sum_hessian[i] = s.sum_hess;
|
||||
base_weights[i] = s.base_weight;
|
||||
auto save_tree = [&](auto* p_indices_array) {
|
||||
auto& indices_array = *p_indices_array;
|
||||
for (bst_node_t i = 0; i < n_nodes; ++i) {
|
||||
auto const& s = stats_[i];
|
||||
loss_changes.Set(i, s.loss_chg);
|
||||
sum_hessian.Set(i, s.sum_hess);
|
||||
base_weights.Set(i, s.base_weight);
|
||||
|
||||
auto const& n = nodes_[i];
|
||||
lefts[i] = static_cast<I>(n.LeftChild());
|
||||
rights[i] = static_cast<I>(n.RightChild());
|
||||
parents[i] = static_cast<I>(n.Parent());
|
||||
indices[i] = static_cast<I>(n.SplitIndex());
|
||||
conds[i] = n.SplitCond();
|
||||
default_left[i] = n.DefaultLeft();
|
||||
auto const& n = nodes_[i];
|
||||
lefts.Set(i, n.LeftChild());
|
||||
rights.Set(i, n.RightChild());
|
||||
parents.Set(i, n.Parent());
|
||||
indices_array.Set(i, n.SplitIndex());
|
||||
conds.Set(i, n.SplitCond());
|
||||
default_left.Set(i, static_cast<uint8_t>(!!n.DefaultLeft()));
|
||||
|
||||
split_type[i] = static_cast<I>(this->NodeSplitType(i));
|
||||
split_type.Set(i, static_cast<uint8_t>(this->NodeSplitType(i)));
|
||||
}
|
||||
};
|
||||
if (this->param.num_feature > static_cast<bst_feature_t>(std::numeric_limits<int32_t>::max())) {
|
||||
I64Array indices_64(n_nodes);
|
||||
save_tree(&indices_64);
|
||||
out["split_indices"] = std::move(indices_64);
|
||||
} else {
|
||||
I32Array indices_32(n_nodes);
|
||||
save_tree(&indices_32);
|
||||
out["split_indices"] = std::move(indices_32);
|
||||
}
|
||||
|
||||
this->SaveCategoricalSplit(&out);
|
||||
@@ -1124,7 +1206,7 @@ void RegTree::SaveModel(Json* p_out) const {
|
||||
out["left_children"] = std::move(lefts);
|
||||
out["right_children"] = std::move(rights);
|
||||
out["parents"] = std::move(parents);
|
||||
out["split_indices"] = std::move(indices);
|
||||
|
||||
out["split_conditions"] = std::move(conds);
|
||||
out["default_left"] = std::move(default_left);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user