[breaking] Change internal model serialization to UBJSON. (#7556)

* Use typed array for models.
* Change the memory snapshot format.
* Add new C API for saving to raw format.
This commit is contained in:
Jiaming Yuan
2022-01-16 02:11:53 +08:00
committed by GitHub
parent 13b0fa4b97
commit a1bcd33a3b
24 changed files with 566 additions and 255 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2015-2021 by Contributors
* Copyright 2015-2022 by Contributors
* \file tree_model.cc
* \brief model structure for tree
*/
@@ -893,27 +893,57 @@ void RegTree::Save(dmlc::Stream* fo) const {
}
}
}
/*!
 * \brief Element access for a typed (non-Json) array when the requested JSON type is
 *        anything other than Boolean.
 *
 * \tparam JT Requested JSON value type; this overload is disabled when it is Boolean.
 * \tparam T  Element type of the array; this overload is disabled when it is Json.
 *
 * \return The element at position `i`, by value, with the array's own element type.
 */
template <typename JT, typename T>
std::enable_if_t<!(std::is_same<T, Json>::value || std::is_same<JT, Boolean>::value), T> GetElem(
    std::vector<T> const& arr, size_t i) {
  auto const& elem = arr[i];
  return elem;
}
/*!
 * \brief Element access for a typed uint8_t array when a Boolean is requested.
 *
 * \tparam JT Requested JSON value type; must be Boolean for this overload.
 * \tparam T  Element type of the array; must be uint8_t (and not Json) for this overload.
 *
 * \return true when the stored byte equals 1, false for any other value.
 */
template <typename JT, typename T>
std::enable_if_t<std::is_same<JT, Boolean>::value && std::is_same<T, uint8_t>::value &&
                     !std::is_same<T, Json>::value,
                 bool>
GetElem(std::vector<T> const& arr, size_t i) {
  bool const is_one = (arr[i] == 1);
  return is_one;
}
/*!
 * \brief Element access for an array of generic Json values.
 *
 * The return type follows the requested JSON type: int64_t for Integer, bool for
 * Boolean and float for every other type.
 *
 * When a Boolean is requested but the stored element is not a Boolean, the element
 * is read as an Integer and compared against 1.
 *
 * \tparam JT Requested JSON value type.
 * \tparam T  Element type of the array; must be Json for this overload.
 */
template <typename JT, typename T>
auto GetElem(std::vector<T> const& arr, size_t i) -> std::enable_if_t<
    std::is_same<T, Json>::value,
    std::conditional_t<std::is_same<JT, Integer>::value, int64_t,
                       std::conditional_t<std::is_same<Boolean, JT>::value, bool, float>>> {
  auto const& elem = arr[i];
  // Common path: the element already has the requested type (or no Boolean was asked for).
  if (!std::is_same<JT, Boolean>::value || IsA<Boolean>(elem)) {
    return get<JT const>(elem);
  }
  // Boolean requested, but the value was stored as something else: interpret as Integer.
  return get<Integer const>(elem) == 1;
}
template <bool typed>
void RegTree::LoadCategoricalSplit(Json const& in) {
auto const& categories_segments = get<Array const>(in["categories_segments"]);
auto const& categories_sizes = get<Array const>(in["categories_sizes"]);
auto const& categories_nodes = get<Array const>(in["categories_nodes"]);
auto const& categories = get<Array const>(in["categories"]);
using I64ArrayT = std::conditional_t<typed, I64Array const, Array const>;
using I32ArrayT = std::conditional_t<typed, I32Array const, Array const>;
auto const& categories_segments = get<I64ArrayT>(in["categories_segments"]);
auto const& categories_sizes = get<I64ArrayT>(in["categories_sizes"]);
auto const& categories_nodes = get<I32ArrayT>(in["categories_nodes"]);
auto const& categories = get<I32ArrayT>(in["categories"]);
size_t cnt = 0;
bst_node_t last_cat_node = -1;
if (!categories_nodes.empty()) {
last_cat_node = get<Integer const>(categories_nodes[cnt]);
last_cat_node = GetElem<Integer>(categories_nodes, cnt);
}
for (bst_node_t nidx = 0; nidx < param.num_nodes; ++nidx) {
if (nidx == last_cat_node) {
auto j_begin = get<Integer const>(categories_segments[cnt]);
auto j_end = get<Integer const>(categories_sizes[cnt]) + j_begin;
auto j_begin = GetElem<Integer>(categories_segments, cnt);
auto j_end = GetElem<Integer>(categories_sizes, cnt) + j_begin;
bst_cat_t max_cat{std::numeric_limits<bst_cat_t>::min()};
CHECK_NE(j_end - j_begin, 0) << nidx;
for (auto j = j_begin; j < j_end; ++j) {
auto const &category = get<Integer const>(categories[j]);
auto const& category = GetElem<Integer>(categories, j);
auto cat = common::AsCat(category);
max_cat = std::max(max_cat, cat);
}
@@ -924,7 +954,7 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
std::vector<uint32_t> cat_bits_storage(size, 0);
common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)};
for (auto j = j_begin; j < j_end; ++j) {
cat_bits.Set(common::AsCat(get<Integer const>(categories[j])));
cat_bits.Set(common::AsCat(GetElem<Integer>(categories, j)));
}
auto begin = split_categories_.size();
@@ -936,9 +966,9 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
++cnt;
if (cnt == categories_nodes.size()) {
last_cat_node = -1;
last_cat_node = -1; // Don't break, we still need to initialize the remaining nodes.
} else {
last_cat_node = get<Integer const>(categories_nodes[cnt]);
last_cat_node = GetElem<Integer>(categories_nodes, cnt);
}
} else {
split_categories_segments_[nidx].beg = categories.size();
@@ -947,104 +977,144 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
}
}
template void RegTree::LoadCategoricalSplit<true>(Json const& in);
template void RegTree::LoadCategoricalSplit<false>(Json const& in);
/*!
 * \brief Write the categorical split data of this tree into the JSON object `*p_out`.
 *
 * For every node whose split type is FeatureType::kCategorical:
 *   - the node index is appended to "categories_nodes",
 *   - the offset at which the node's categories start is appended to "categories_segments",
 *   - every category whose bit is set in the node's bit field is appended to "categories",
 *   - the number of categories appended for the node goes to "categories_sizes".
 *
 * \param p_out Output JSON object; the four keys listed above are assigned on it.
 */
void RegTree::SaveCategoricalSplit(Json* p_out) const {
auto& out = *p_out;
// Both per-node containers are expected to hold exactly one entry per node.
CHECK_EQ(this->split_types_.size(), param.num_nodes);
CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes);
std::vector<Json> categories_segments;
std::vector<Json> categories_sizes;
std::vector<Json> categories;
std::vector<Json> categories_nodes;
I64Array categories_segments;
I64Array categories_sizes;
I32Array categories; // bst_cat_t = int32_t
I32Array categories_nodes; // bst_node_t = int32_t
for (size_t i = 0; i < nodes_.size(); ++i) {
if (this->split_types_[i] == FeatureType::kCategorical) {
categories_nodes.emplace_back(i);
// Offset into `categories` at which this node's categories begin.
auto begin = categories.size();
categories_segments.emplace_back(static_cast<Integer::Int>(begin));
categories_nodes.GetArray().emplace_back(i);
auto begin = categories.Size();
categories_segments.GetArray().emplace_back(begin);
auto segment = split_categories_segments_[i];
auto node_categories =
this->GetSplitCategories().subspan(segment.beg, segment.size);
auto node_categories = this->GetSplitCategories().subspan(segment.beg, segment.size);
// View the node's slice of category storage as a bit field.
common::KCatBitField const cat_bits(node_categories);
// NOTE: this `i` shadows the outer node index; inside the loop it is the bit
// position, which is recorded as the category value when the bit is set.
for (size_t i = 0; i < cat_bits.Size(); ++i) {
if (cat_bits.Check(i)) {
categories.emplace_back(static_cast<Integer::Int>(i));
categories.GetArray().emplace_back(i);
}
}
// Number of categories appended for this node.
size_t size = categories.size() - begin;
categories_sizes.emplace_back(static_cast<Integer::Int>(size));
size_t size = categories.Size() - begin;
categories_sizes.GetArray().emplace_back(size);
// A categorical node must have contributed at least one category.
CHECK_NE(size, 0);
}
}
out["categories_segments"] = categories_segments;
out["categories_sizes"] = categories_sizes;
out["categories_nodes"] = categories_nodes;
out["categories"] = categories;
out["categories_segments"] = std::move(categories_segments);
out["categories_sizes"] = std::move(categories_sizes);
out["categories_nodes"] = std::move(categories_nodes);
out["categories"] = std::move(categories);
}
void RegTree::LoadModel(Json const& in) {
FromJson(in["tree_param"], &param);
auto n_nodes = param.num_nodes;
/*!
 * \brief Load the tree parameter, per-node statistics and node structure from a JSON document.
 *
 * \tparam typed         When true the per-node arrays are read as typed arrays
 *                       (F32Array / U8Array / I32Array / I64Array); otherwise as generic Array.
 * \tparam feature_is_64 Selects the 64-bit integer array type for "split_indices"
 *                       instead of the 32-bit one.
 *
 * \param in                          Input JSON object for one tree.
 * \param param                       Filled from in["tree_param"].
 * \param p_stats                     Receives loss change, hessian sum and base weight per node.
 * \param p_split_types               Receives the split type per node when "split_type" is present.
 * \param p_nodes                     Receives the node structure.
 * \param p_split_categories_segments Resized to one entry per node.
 *
 * \return true when the input object contains the key "split_type".
 */
template <bool typed, bool feature_is_64,
typename FloatArrayT = std::conditional_t<typed, F32Array const, Array const>,
typename U8ArrayT = std::conditional_t<typed, U8Array const, Array const>,
typename I32ArrayT = std::conditional_t<typed, I32Array const, Array const>,
typename I64ArrayT = std::conditional_t<typed, I64Array const, Array const>,
typename IndexArrayT = std::conditional_t<feature_is_64, I64ArrayT, I32ArrayT>>
bool LoadModelImpl(Json const& in, TreeParam* param, std::vector<RTreeNodeStat>* p_stats,
std::vector<FeatureType>* p_split_types, std::vector<RegTree::Node>* p_nodes,
std::vector<RegTree::Segment>* p_split_categories_segments) {
auto& stats = *p_stats;
auto& split_types = *p_split_types;
auto& nodes = *p_nodes;
auto& split_categories_segments = *p_split_categories_segments;
FromJson(in["tree_param"], param);
auto n_nodes = param->num_nodes;
// An empty tree is rejected.
CHECK_NE(n_nodes, 0);
// stats
// Every per-node array below is checked to contain exactly n_nodes entries.
auto const& loss_changes = get<Array const>(in["loss_changes"]);
auto const& loss_changes = get<FloatArrayT>(in["loss_changes"]);
CHECK_EQ(loss_changes.size(), n_nodes);
auto const& sum_hessian = get<Array const>(in["sum_hessian"]);
auto const& sum_hessian = get<FloatArrayT>(in["sum_hessian"]);
CHECK_EQ(sum_hessian.size(), n_nodes);
auto const& base_weights = get<Array const>(in["base_weights"]);
auto const& base_weights = get<FloatArrayT>(in["base_weights"]);
CHECK_EQ(base_weights.size(), n_nodes);
// nodes
auto const& lefts = get<Array const>(in["left_children"]);
auto const& lefts = get<I32ArrayT>(in["left_children"]);
CHECK_EQ(lefts.size(), n_nodes);
auto const& rights = get<Array const>(in["right_children"]);
auto const& rights = get<I32ArrayT>(in["right_children"]);
CHECK_EQ(rights.size(), n_nodes);
auto const& parents = get<Array const>(in["parents"]);
auto const& parents = get<I32ArrayT>(in["parents"]);
CHECK_EQ(parents.size(), n_nodes);
auto const& indices = get<Array const>(in["split_indices"]);
auto const& indices = get<IndexArrayT>(in["split_indices"]);
CHECK_EQ(indices.size(), n_nodes);
auto const& conds = get<Array const>(in["split_conditions"]);
auto const& conds = get<FloatArrayT>(in["split_conditions"]);
CHECK_EQ(conds.size(), n_nodes);
auto const& default_left = get<Array const>(in["default_left"]);
auto const& default_left = get<U8ArrayT>(in["default_left"]);
CHECK_EQ(default_left.size(), n_nodes);
// The "split_type" key is optional; its presence is what this function reports back.
bool has_cat = get<Object const>(in).find("split_type") != get<Object const>(in).cend();
std::vector<Json> split_type;
// Local, non-const copy of whatever container type get<U8ArrayT const> yields.
std::remove_const_t<std::remove_reference_t<decltype(get<U8ArrayT const>(in["split_type"]))>>
split_type;
if (has_cat) {
split_type = get<Array const>(in["split_type"]);
split_type = get<U8ArrayT const>(in["split_type"]);
}
stats_.clear();
nodes_.clear();
// Replace the output containers with freshly constructed ones holding n_nodes elements.
stats = std::remove_reference_t<decltype(stats)>(n_nodes);
nodes = std::remove_reference_t<decltype(nodes)>(n_nodes);
split_types = std::remove_reference_t<decltype(split_types)>(n_nodes);
split_categories_segments = std::remove_reference_t<decltype(split_categories_segments)>(n_nodes);
stats_.resize(n_nodes);
nodes_.resize(n_nodes);
split_types_.resize(n_nodes);
split_categories_segments_.resize(n_nodes);
// Compile-time check that GetElem yields an integral type for Integer requests and a
// floating-point type for Number requests on these arrays.
static_assert(std::is_integral<decltype(GetElem<Integer>(lefts, 0))>::value, "");
static_assert(std::is_floating_point<decltype(GetElem<Number>(loss_changes, 0))>::value, "");
CHECK_EQ(n_nodes, split_categories_segments.size());
CHECK_EQ(n_nodes, split_categories_segments_.size());
for (int32_t i = 0; i < n_nodes; ++i) {
auto& s = stats_[i];
s.loss_chg = get<Number const>(loss_changes[i]);
s.sum_hess = get<Number const>(sum_hessian[i]);
s.base_weight = get<Number const>(base_weights[i]);
auto& s = stats[i];
s.loss_chg = GetElem<Number>(loss_changes, i);
s.sum_hess = GetElem<Number>(sum_hessian, i);
s.base_weight = GetElem<Number>(base_weights, i);
auto& n = nodes_[i];
bst_node_t left = get<Integer const>(lefts[i]);
bst_node_t right = get<Integer const>(rights[i]);
bst_node_t parent = get<Integer const>(parents[i]);
bst_feature_t ind = get<Integer const>(indices[i]);
float cond { get<Number const>(conds[i]) };
bool dft_left { get<Boolean const>(default_left[i]) };
n = Node{left, right, parent, ind, cond, dft_left};
auto& n = nodes[i];
bst_node_t left = GetElem<Integer>(lefts, i);
bst_node_t right = GetElem<Integer>(rights, i);
bst_node_t parent = GetElem<Integer>(parents, i);
bst_feature_t ind = GetElem<Integer>(indices, i);
float cond{GetElem<Number>(conds, i)};
bool dft_left{GetElem<Boolean>(default_left, i)};
n = RegTree::Node{left, right, parent, ind, cond, dft_left};
// Split types are only filled in when the input carried a "split_type" array.
if (has_cat) {
split_types_[i] =
static_cast<FeatureType>(get<Integer const>(split_type[i]));
split_types[i] = static_cast<FeatureType>(GetElem<Integer>(split_type, i));
}
}
return has_cat;
}
void RegTree::LoadModel(Json const& in) {
bool has_cat{false};
bool typed = IsA<F32Array>(in["loss_changes"]);
bool feature_is_64 = IsA<I64Array>(in["split_indices"]);
if (typed && feature_is_64) {
has_cat = LoadModelImpl<true, true>(in, &param, &stats_, &split_types_, &nodes_,
&split_categories_segments_);
} else if (typed && !feature_is_64) {
has_cat = LoadModelImpl<true, false>(in, &param, &stats_, &split_types_, &nodes_,
&split_categories_segments_);
} else if (!typed && feature_is_64) {
has_cat = LoadModelImpl<false, true>(in, &param, &stats_, &split_types_, &nodes_,
&split_categories_segments_);
} else {
has_cat = LoadModelImpl<false, false>(in, &param, &stats_, &split_types_, &nodes_,
&split_categories_segments_);
}
if (has_cat) {
this->LoadCategoricalSplit(in);
if (typed) {
this->LoadCategoricalSplit<true>(in);
} else {
this->LoadCategoricalSplit<false>(in);
}
} else {
this->split_categories_segments_.resize(this->param.num_nodes);
std::fill(split_types_.begin(), split_types_.end(), FeatureType::kNumerical);
@@ -1058,7 +1128,7 @@ void RegTree::LoadModel(Json const& in) {
}
// easier access to [] operator
auto& self = *this;
for (auto nid = 1; nid < n_nodes; ++nid) {
for (auto nid = 1; nid < param.num_nodes; ++nid) {
auto parent = self[nid].Parent();
CHECK_NE(parent, RegTree::kInvalidNodeId);
self[nid].SetParent(self[nid].Parent(), self[parent].LeftChild() == nid);
@@ -1079,39 +1149,51 @@ void RegTree::SaveModel(Json* p_out) const {
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
out["tree_param"] = ToJson(param);
CHECK_EQ(get<String>(out["tree_param"]["num_nodes"]), std::to_string(param.num_nodes));
using I = Integer::Int;
auto n_nodes = param.num_nodes;
// stats
std::vector<Json> loss_changes(n_nodes);
std::vector<Json> sum_hessian(n_nodes);
std::vector<Json> base_weights(n_nodes);
F32Array loss_changes(n_nodes);
F32Array sum_hessian(n_nodes);
F32Array base_weights(n_nodes);
// nodes
std::vector<Json> lefts(n_nodes);
std::vector<Json> rights(n_nodes);
std::vector<Json> parents(n_nodes);
std::vector<Json> indices(n_nodes);
std::vector<Json> conds(n_nodes);
std::vector<Json> default_left(n_nodes);
std::vector<Json> split_type(n_nodes);
I32Array lefts(n_nodes);
I32Array rights(n_nodes);
I32Array parents(n_nodes);
F32Array conds(n_nodes);
U8Array default_left(n_nodes);
U8Array split_type(n_nodes);
CHECK_EQ(this->split_types_.size(), param.num_nodes);
for (bst_node_t i = 0; i < n_nodes; ++i) {
auto const& s = stats_[i];
loss_changes[i] = s.loss_chg;
sum_hessian[i] = s.sum_hess;
base_weights[i] = s.base_weight;
auto save_tree = [&](auto* p_indices_array) {
auto& indices_array = *p_indices_array;
for (bst_node_t i = 0; i < n_nodes; ++i) {
auto const& s = stats_[i];
loss_changes.Set(i, s.loss_chg);
sum_hessian.Set(i, s.sum_hess);
base_weights.Set(i, s.base_weight);
auto const& n = nodes_[i];
lefts[i] = static_cast<I>(n.LeftChild());
rights[i] = static_cast<I>(n.RightChild());
parents[i] = static_cast<I>(n.Parent());
indices[i] = static_cast<I>(n.SplitIndex());
conds[i] = n.SplitCond();
default_left[i] = n.DefaultLeft();
auto const& n = nodes_[i];
lefts.Set(i, n.LeftChild());
rights.Set(i, n.RightChild());
parents.Set(i, n.Parent());
indices_array.Set(i, n.SplitIndex());
conds.Set(i, n.SplitCond());
default_left.Set(i, static_cast<uint8_t>(!!n.DefaultLeft()));
split_type[i] = static_cast<I>(this->NodeSplitType(i));
split_type.Set(i, static_cast<uint8_t>(this->NodeSplitType(i)));
}
};
if (this->param.num_feature > static_cast<bst_feature_t>(std::numeric_limits<int32_t>::max())) {
I64Array indices_64(n_nodes);
save_tree(&indices_64);
out["split_indices"] = std::move(indices_64);
} else {
I32Array indices_32(n_nodes);
save_tree(&indices_32);
out["split_indices"] = std::move(indices_32);
}
this->SaveCategoricalSplit(&out);
@@ -1124,7 +1206,7 @@ void RegTree::SaveModel(Json* p_out) const {
out["left_children"] = std::move(lefts);
out["right_children"] = std::move(rights);
out["parents"] = std::move(parents);
out["split_indices"] = std::move(indices);
out["split_conditions"] = std::move(conds);
out["default_left"] = std::move(default_left);
}