Added trees_to_df() method for Booster class (#4153)
* add test_parse_tree.py to tests/python * Fix formatting * Fix pylint error * Ignore 'no member' error for Pandas dataframe
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
1b7405f688
commit
74009afcac
@@ -913,6 +913,7 @@ class DMatrix(object):
|
||||
|
||||
|
||||
class Booster(object):
|
||||
# pylint: disable=too-many-public-methods
|
||||
"""A Booster of XGBoost.
|
||||
|
||||
Booster is the model of xgboost, that contains low level routines for
|
||||
@@ -1578,6 +1579,91 @@ class Booster(object):
|
||||
|
||||
return gmap
|
||||
|
||||
def trees_to_dataframe(self, fmap=''):
|
||||
"""Parse a boosted tree model text dump into a pandas DataFrame structure.
|
||||
|
||||
This feature is only defined when the decision tree model is chosen as base
|
||||
learner (`booster in {gbtree, dart}`). It is not defined for other base learner
|
||||
types, such as linear learners (`booster=gblinear`).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fmap: str (optional)
|
||||
The name of feature map file.
|
||||
"""
|
||||
# pylint: disable=too-many-locals
|
||||
if not PANDAS_INSTALLED:
|
||||
raise Exception(('pandas must be available to use this method.'
|
||||
'Install pandas before calling again.'))
|
||||
|
||||
if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}:
|
||||
raise ValueError('This method is not defined for Booster type {}'
|
||||
.format(self.booster))
|
||||
|
||||
tree_ids = []
|
||||
node_ids = []
|
||||
fids = []
|
||||
splits = []
|
||||
y_directs = []
|
||||
n_directs = []
|
||||
missings = []
|
||||
gains = []
|
||||
covers = []
|
||||
|
||||
trees = self.get_dump(fmap, with_stats=True)
|
||||
for i, tree in enumerate(trees):
|
||||
for line in tree.split('\n'):
|
||||
arr = line.split('[')
|
||||
# Leaf node
|
||||
if len(arr) == 1:
|
||||
# Last element of line.split is an empy string
|
||||
if arr == ['']:
|
||||
continue
|
||||
# parse string
|
||||
parse = arr[0].split(':')
|
||||
stats = re.split('=|,', parse[1])
|
||||
|
||||
# append to lists
|
||||
tree_ids.append(i)
|
||||
node_ids.append(int(re.findall(r'\b\d+\b', parse[0])[0]))
|
||||
fids.append('Leaf')
|
||||
splits.append(float('NAN'))
|
||||
y_directs.append(float('NAN'))
|
||||
n_directs.append(float('NAN'))
|
||||
missings.append(float('NAN'))
|
||||
gains.append(float(stats[1]))
|
||||
covers.append(float(stats[3]))
|
||||
# Not a Leaf Node
|
||||
else:
|
||||
# parse string
|
||||
fid = arr[1].split(']')
|
||||
parse = fid[0].split('<')
|
||||
stats = re.split('=|,', fid[1])
|
||||
|
||||
# append to lists
|
||||
tree_ids.append(i)
|
||||
node_ids.append(int(re.findall(r'\b\d+\b', arr[0])[0]))
|
||||
fids.append(parse[0])
|
||||
splits.append(float(parse[1]))
|
||||
str_i = str(i)
|
||||
y_directs.append(str_i + '-' + stats[1])
|
||||
n_directs.append(str_i + '-' + stats[3])
|
||||
missings.append(str_i + '-' + stats[5])
|
||||
gains.append(float(stats[7]))
|
||||
covers.append(float(stats[9]))
|
||||
|
||||
ids = [str(t_id) + '-' + str(n_id) for t_id, n_id in zip(tree_ids, node_ids)]
|
||||
df = DataFrame({'Tree': tree_ids, 'Node': node_ids, 'ID': ids,
|
||||
'Feature': fids, 'Split': splits, 'Yes': y_directs,
|
||||
'No': n_directs, 'Missing': missings, 'Gain': gains,
|
||||
'Cover': covers})
|
||||
|
||||
if callable(getattr(df, 'sort_values', None)):
|
||||
# pylint: disable=no-member
|
||||
return df.sort_values(['Tree', 'Node']).reset_index(drop=True)
|
||||
# pylint: disable=no-member
|
||||
return df.sort(['Tree', 'Node']).reset_index(drop=True)
|
||||
|
||||
def _validate_features(self, data):
|
||||
"""
|
||||
Validate Booster and data's feature_names are identical.
|
||||
|
||||
Reference in New Issue
Block a user