175 lines
5.6 KiB
Python
175 lines
5.6 KiB
Python
'''Demonstration for parsing JSON tree model file generated by XGBoost. The
|
|
support is experimental, output schema is subject to change in the future.
|
|
'''
|
|
import json
|
|
import argparse
|
|
|
|
|
|
class Tree:
|
|
'''A tree built by XGBoost.'''
|
|
# Index into node array
|
|
_left = 0
|
|
_right = 1
|
|
_parent = 2
|
|
_ind = 3
|
|
_cond = 4
|
|
_default_left = 5
|
|
# Index into stat array
|
|
_loss_chg = 0
|
|
_sum_hess = 1
|
|
_base_weight = 2
|
|
|
|
def __init__(self, tree_id: int, nodes, stats):
|
|
self.tree_id = tree_id
|
|
self.nodes = nodes
|
|
self.stats = stats
|
|
|
|
def loss_change(self, node_id: int):
|
|
'''Loss gain of a node.'''
|
|
return self.stats[node_id][self._loss_chg]
|
|
|
|
def sum_hessian(self, node_id: int):
|
|
'''Sum Hessian of a node.'''
|
|
return self.stats[node_id][self._sum_hess]
|
|
|
|
def base_weight(self, node_id: int):
|
|
'''Base weight of a node.'''
|
|
return self.stats[node_id][self._base_weight]
|
|
|
|
def split_index(self, node_id: int):
|
|
'''Split feature index of node.'''
|
|
return self.nodes[node_id][self._ind]
|
|
|
|
def split_condition(self, node_id: int):
|
|
'''Split value of a node.'''
|
|
return self.nodes[node_id][self._cond]
|
|
|
|
def parent(self, node_id: int):
|
|
'''Parent ID of a node.'''
|
|
return self.nodes[node_id][self._parent]
|
|
|
|
def left_child(self, node_id: int):
|
|
'''Left child ID of a node.'''
|
|
return self.nodes[node_id][self._left]
|
|
|
|
def right_child(self, node_id: int):
|
|
'''Right child ID of a node.'''
|
|
return self.nodes[node_id][self._right]
|
|
|
|
def is_leaf(self, node_id: int):
|
|
'''Whether a node is leaf.'''
|
|
return self.nodes[node_id][self._left] == -1
|
|
|
|
def is_deleted(self, node_id: int):
|
|
'''Whether a node is deleted.'''
|
|
# std::numeric_limits<uint32_t>::max()
|
|
return self.nodes[node_id][self._ind] == 4294967295
|
|
|
|
def __str__(self):
|
|
stacks = [0]
|
|
nodes = []
|
|
while stacks:
|
|
node = {}
|
|
nid = stacks.pop()
|
|
|
|
node['node id'] = nid
|
|
node['gain'] = self.loss_change(nid)
|
|
node['cover'] = self.sum_hessian(nid)
|
|
nodes.append(node)
|
|
|
|
if not self.is_leaf(nid) and not self.is_deleted(nid):
|
|
left = self.left_child(nid)
|
|
right = self.right_child(nid)
|
|
stacks.append(left)
|
|
stacks.append(right)
|
|
|
|
string = '\n'.join(map(lambda x: ' ' + str(x), nodes))
|
|
return string
|
|
|
|
|
|
class Model:
|
|
'''Gradient boosted tree model.'''
|
|
def __init__(self, model: dict):
|
|
'''Construct the Model from JSON object.
|
|
|
|
parameters
|
|
----------
|
|
m: A dictionary loaded by json
|
|
'''
|
|
# Basic property of a model
|
|
self.learner_model_shape = model['learner']['learner_model_param']
|
|
self.num_output_group = int(self.learner_model_shape['num_class'])
|
|
self.num_feature = int(self.learner_model_shape['num_feature'])
|
|
self.base_score = float(self.learner_model_shape['base_score'])
|
|
# A field encoding which output group a tree belongs
|
|
self.tree_info = model['learner']['gradient_booster']['model'][
|
|
'tree_info']
|
|
|
|
model_shape = model['learner']['gradient_booster']['model'][
|
|
'gbtree_model_param']
|
|
|
|
# JSON representation of trees
|
|
j_trees = model['learner']['gradient_booster']['model']['trees']
|
|
|
|
# Load the trees
|
|
self.num_trees = int(model_shape['num_trees'])
|
|
self.leaf_size = int(model_shape['size_leaf_vector'])
|
|
# Right now XGBoost doesn't support vector leaf yet
|
|
assert self.leaf_size == 0, str(self.leaf_size)
|
|
|
|
trees = []
|
|
for i in range(self.num_trees):
|
|
tree = j_trees[i]
|
|
tree_id = int(tree['id'])
|
|
assert tree_id == i, (tree_id, i)
|
|
# properties
|
|
left_children = tree['left_children']
|
|
right_children = tree['right_children']
|
|
parents = tree['parents']
|
|
split_conditions = tree['split_conditions']
|
|
split_indices = tree['split_indices']
|
|
default_left = tree['default_left']
|
|
# stats
|
|
base_weights = tree['base_weights']
|
|
loss_changes = tree['loss_changes']
|
|
sum_hessian = tree['sum_hessian']
|
|
|
|
stats = []
|
|
nodes = []
|
|
# We resemble the structure used inside XGBoost, which is similar
|
|
# to adjacency list.
|
|
for node_id in range(len(left_children)):
|
|
nodes.append([
|
|
left_children[node_id], right_children[node_id],
|
|
parents[node_id], split_indices[node_id],
|
|
split_conditions[node_id], default_left[node_id]
|
|
])
|
|
stats.append([
|
|
loss_changes[node_id], sum_hessian[node_id],
|
|
base_weights[node_id]
|
|
])
|
|
|
|
tree = Tree(tree_id, nodes, stats)
|
|
trees.append(tree)
|
|
|
|
self.trees = trees
|
|
|
|
def print_model(self):
|
|
for i, tree in enumerate(self.trees):
|
|
print('tree_id:', i)
|
|
print(tree)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
parser = argparse.ArgumentParser(
|
|
description='Demonstration for loading and printing XGBoost model.')
|
|
parser.add_argument('--model',
|
|
type=str,
|
|
required=True,
|
|
help='Path to JSON model file.')
|
|
args = parser.parse_args()
|
|
with open(args.model, 'r') as fd:
|
|
model = json.load(fd)
|
|
model = Model(model)
|
|
model.print_model()
|