[Breaking] Don't save leaf child count in JSON. (#6094)
The field is deprecated and not used anywhere in XGBoost.
This commit is contained in:
parent
5994f3b14c
commit
e5d40b39cd
@ -18,7 +18,6 @@ class Tree:
|
|||||||
_loss_chg = 0
|
_loss_chg = 0
|
||||||
_sum_hess = 1
|
_sum_hess = 1
|
||||||
_base_weight = 2
|
_base_weight = 2
|
||||||
_child_cnt = 3
|
|
||||||
|
|
||||||
def __init__(self, tree_id: int, nodes, stats):
|
def __init__(self, tree_id: int, nodes, stats):
|
||||||
self.tree_id = tree_id
|
self.tree_id = tree_id
|
||||||
@ -37,10 +36,6 @@ class Tree:
|
|||||||
'''Base weight of a node.'''
|
'''Base weight of a node.'''
|
||||||
return self.stats[node_id][self._base_weight]
|
return self.stats[node_id][self._base_weight]
|
||||||
|
|
||||||
def num_children(self, node_id: int):
|
|
||||||
'''Number of children of a node.'''
|
|
||||||
return self.stats[node_id][self._child_cnt]
|
|
||||||
|
|
||||||
def split_index(self, node_id: int):
|
def split_index(self, node_id: int):
|
||||||
'''Split feature index of node.'''
|
'''Split feature index of node.'''
|
||||||
return self.nodes[node_id][self._ind]
|
return self.nodes[node_id][self._ind]
|
||||||
@ -138,7 +133,6 @@ class Model:
|
|||||||
base_weights = tree['base_weights']
|
base_weights = tree['base_weights']
|
||||||
loss_changes = tree['loss_changes']
|
loss_changes = tree['loss_changes']
|
||||||
sum_hessian = tree['sum_hessian']
|
sum_hessian = tree['sum_hessian']
|
||||||
leaf_child_counts = tree['leaf_child_counts']
|
|
||||||
|
|
||||||
stats = []
|
stats = []
|
||||||
nodes = []
|
nodes = []
|
||||||
@ -152,7 +146,7 @@ class Model:
|
|||||||
])
|
])
|
||||||
stats.append([
|
stats.append([
|
||||||
loss_changes[node_id], sum_hessian[node_id],
|
loss_changes[node_id], sum_hessian[node_id],
|
||||||
base_weights[node_id], leaf_child_counts[node_id]
|
base_weights[node_id]
|
||||||
])
|
])
|
||||||
|
|
||||||
tree = Tree(tree_id, nodes, stats)
|
tree = Tree(tree_id, nodes, stats)
|
||||||
|
|||||||
@ -58,12 +58,6 @@
|
|||||||
"type": "number"
|
"type": "number"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"leaf_child_counts": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {
|
|
||||||
"type": "integer"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"left_children": {
|
"left_children": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {
|
||||||
@ -106,7 +100,6 @@
|
|||||||
"loss_changes",
|
"loss_changes",
|
||||||
"sum_hessian",
|
"sum_hessian",
|
||||||
"base_weights",
|
"base_weights",
|
||||||
"leaf_child_counts",
|
|
||||||
"left_children",
|
"left_children",
|
||||||
"right_children",
|
"right_children",
|
||||||
"parents",
|
"parents",
|
||||||
|
|||||||
@ -783,8 +783,6 @@ void RegTree::LoadModel(Json const& in) {
|
|||||||
CHECK_EQ(sum_hessian.size(), n_nodes);
|
CHECK_EQ(sum_hessian.size(), n_nodes);
|
||||||
auto const& base_weights = get<Array const>(in["base_weights"]);
|
auto const& base_weights = get<Array const>(in["base_weights"]);
|
||||||
CHECK_EQ(base_weights.size(), n_nodes);
|
CHECK_EQ(base_weights.size(), n_nodes);
|
||||||
auto const& leaf_child_counts = get<Array const>(in["leaf_child_counts"]);
|
|
||||||
CHECK_EQ(leaf_child_counts.size(), n_nodes);
|
|
||||||
// nodes
|
// nodes
|
||||||
auto const& lefts = get<Array const>(in["left_children"]);
|
auto const& lefts = get<Array const>(in["left_children"]);
|
||||||
CHECK_EQ(lefts.size(), n_nodes);
|
CHECK_EQ(lefts.size(), n_nodes);
|
||||||
@ -822,7 +820,6 @@ void RegTree::LoadModel(Json const& in) {
|
|||||||
s.loss_chg = get<Number const>(loss_changes[i]);
|
s.loss_chg = get<Number const>(loss_changes[i]);
|
||||||
s.sum_hess = get<Number const>(sum_hessian[i]);
|
s.sum_hess = get<Number const>(sum_hessian[i]);
|
||||||
s.base_weight = get<Number const>(base_weights[i]);
|
s.base_weight = get<Number const>(base_weights[i]);
|
||||||
s.leaf_child_cnt = get<Integer const>(leaf_child_counts[i]);
|
|
||||||
|
|
||||||
auto& n = nodes_[i];
|
auto& n = nodes_[i];
|
||||||
bst_node_t left = get<Integer const>(lefts[i]);
|
bst_node_t left = get<Integer const>(lefts[i]);
|
||||||
@ -888,7 +885,6 @@ void RegTree::SaveModel(Json* p_out) const {
|
|||||||
std::vector<Json> loss_changes(n_nodes);
|
std::vector<Json> loss_changes(n_nodes);
|
||||||
std::vector<Json> sum_hessian(n_nodes);
|
std::vector<Json> sum_hessian(n_nodes);
|
||||||
std::vector<Json> base_weights(n_nodes);
|
std::vector<Json> base_weights(n_nodes);
|
||||||
std::vector<Json> leaf_child_counts(n_nodes);
|
|
||||||
|
|
||||||
// nodes
|
// nodes
|
||||||
std::vector<Json> lefts(n_nodes);
|
std::vector<Json> lefts(n_nodes);
|
||||||
@ -906,7 +902,6 @@ void RegTree::SaveModel(Json* p_out) const {
|
|||||||
loss_changes[i] = s.loss_chg;
|
loss_changes[i] = s.loss_chg;
|
||||||
sum_hessian[i] = s.sum_hess;
|
sum_hessian[i] = s.sum_hess;
|
||||||
base_weights[i] = s.base_weight;
|
base_weights[i] = s.base_weight;
|
||||||
leaf_child_counts[i] = static_cast<I>(s.leaf_child_cnt);
|
|
||||||
|
|
||||||
auto const& n = nodes_[i];
|
auto const& n = nodes_[i];
|
||||||
lefts[i] = static_cast<I>(n.LeftChild());
|
lefts[i] = static_cast<I>(n.LeftChild());
|
||||||
@ -938,7 +933,6 @@ void RegTree::SaveModel(Json* p_out) const {
|
|||||||
out["loss_changes"] = std::move(loss_changes);
|
out["loss_changes"] = std::move(loss_changes);
|
||||||
out["sum_hessian"] = std::move(sum_hessian);
|
out["sum_hessian"] = std::move(sum_hessian);
|
||||||
out["base_weights"] = std::move(base_weights);
|
out["base_weights"] = std::move(base_weights);
|
||||||
out["leaf_child_counts"] = std::move(leaf_child_counts);
|
|
||||||
|
|
||||||
out["left_children"] = std::move(lefts);
|
out["left_children"] = std::move(lefts);
|
||||||
out["right_children"] = std::move(rights);
|
out["right_children"] = std::move(rights);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user