Added trees_to_df() method for Booster class (#4153)
* add test_parse_tree.py to tests/python
* Fix formatting
* Fix pylint error
* Ignore 'no member' error for Pandas dataframe
parent 1b7405f688
commit 74009afcac
@@ -913,6 +913,7 @@ class DMatrix(object):


 class Booster(object):
+    # pylint: disable=too-many-public-methods
     """A Booster of XGBoost.

     Booster is the model of xgboost, that contains low level routines for
@@ -1578,6 +1579,91 @@ class Booster(object):
         return gmap

+    def trees_to_dataframe(self, fmap=''):
+        """Parse a boosted tree model text dump into a pandas DataFrame structure.
+
+        This feature is only defined when the decision tree model is chosen as base
+        learner (`booster in {gbtree, dart}`). It is not defined for other base learner
+        types, such as linear learners (`booster=gblinear`).
+
+        Parameters
+        ----------
+        fmap: str (optional)
+            The name of feature map file.
+        """
+        # pylint: disable=too-many-locals
+        if not PANDAS_INSTALLED:
+            raise Exception(('pandas must be available to use this method. '
+                             'Install pandas before calling again.'))
+
+        if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}:
+            raise ValueError('This method is not defined for Booster type {}'
+                             .format(self.booster))
+
+        tree_ids = []
+        node_ids = []
+        fids = []
+        splits = []
+        y_directs = []
+        n_directs = []
+        missings = []
+        gains = []
+        covers = []
+
+        trees = self.get_dump(fmap, with_stats=True)
+        for i, tree in enumerate(trees):
+            for line in tree.split('\n'):
+                arr = line.split('[')
+                # Leaf node
+                if len(arr) == 1:
+                    # Last element of line.split is an empty string
+                    if arr == ['']:
+                        continue
+                    # parse string
+                    parse = arr[0].split(':')
+                    stats = re.split('=|,', parse[1])
+
+                    # append to lists
+                    tree_ids.append(i)
+                    node_ids.append(int(re.findall(r'\b\d+\b', parse[0])[0]))
+                    fids.append('Leaf')
+                    splits.append(float('NAN'))
+                    y_directs.append(float('NAN'))
+                    n_directs.append(float('NAN'))
+                    missings.append(float('NAN'))
+                    gains.append(float(stats[1]))
+                    covers.append(float(stats[3]))
+                # Not a Leaf Node
+                else:
+                    # parse string
+                    fid = arr[1].split(']')
+                    parse = fid[0].split('<')
+                    stats = re.split('=|,', fid[1])
+
+                    # append to lists
+                    tree_ids.append(i)
+                    node_ids.append(int(re.findall(r'\b\d+\b', arr[0])[0]))
+                    fids.append(parse[0])
+                    splits.append(float(parse[1]))
+                    str_i = str(i)
+                    y_directs.append(str_i + '-' + stats[1])
+                    n_directs.append(str_i + '-' + stats[3])
+                    missings.append(str_i + '-' + stats[5])
+                    gains.append(float(stats[7]))
+                    covers.append(float(stats[9]))
+
+        ids = [str(t_id) + '-' + str(n_id) for t_id, n_id in zip(tree_ids, node_ids)]
+        df = DataFrame({'Tree': tree_ids, 'Node': node_ids, 'ID': ids,
+                        'Feature': fids, 'Split': splits, 'Yes': y_directs,
+                        'No': n_directs, 'Missing': missings, 'Gain': gains,
+                        'Cover': covers})
+
+        if callable(getattr(df, 'sort_values', None)):
+            # pylint: disable=no-member
+            return df.sort_values(['Tree', 'Node']).reset_index(drop=True)
+        # pylint: disable=no-member
+        return df.sort(['Tree', 'Node']).reset_index(drop=True)
+
     def _validate_features(self, data):
         """
         Validate Booster and data's feature_names are identical.
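For reference, the strings the new method takes apart come from get_dump(fmap, with_stats=True). A rough sketch of that dump format, with invented numbers rather than output from a real run:

    # Illustrative only -- these values are made up, not produced by a real model.
    split_node = '0:[f29<2.5] yes=1,no=2,missing=1,gain=4000.53,cover=1628.25'
    leaf_node = '3:leaf=-0.0199,cover=13.75'

A split line like the first becomes a row with Feature 'f29', Split 2.5, and Yes/No/Missing stored as tree-prefixed node IDs such as '0-1'; a leaf line becomes a row with Feature 'Leaf', its value recorded under Gain and its weight under Cover.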

tests/python/test_parse_tree.py (new file, 50 lines)
@@ -0,0 +1,50 @@
+import xgboost as xgb
+import unittest
+import numpy as np
+import pytest
+import testing as tm
+
+
+pytestmark = pytest.mark.skipif(**tm.no_pandas())
+
+
+dpath = 'demo/data/'
+rng = np.random.RandomState(1994)
+
+
+class TestTreesToDataFrame(unittest.TestCase):
+
+    def build_model(self, max_depth, num_round):
+        dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
+        param = {'max_depth': max_depth, 'objective': 'binary:logistic', 'silent': False}
+        num_round = num_round
+        bst = xgb.train(param, dtrain, num_round)
+        return bst
+
+    def parse_dumped_model(self, booster, item_to_get, splitter):
+        item_to_get += '='
+        txt_dump = booster.get_dump(with_stats=True)
+        tree_list = [tree.split('/n') for tree in txt_dump]
+        split_trees = [tree[0].split(item_to_get)[1:] for tree in tree_list]
+        res = sum([float(line.split(splitter)[0])
+                   for tree in split_trees for line in tree])
+        return res
+
+    def test_trees_to_dataframe(self):
+        bst = self.build_model(max_depth=5, num_round=10)
+        gain_from_dump = self.parse_dumped_model(booster=bst,
+                                                 item_to_get='gain',
+                                                 splitter=',')
+        cover_from_dump = self.parse_dumped_model(booster=bst,
+                                                  item_to_get='cover',
+                                                  splitter='\n')
+        # method being tested
+        df = bst.trees_to_dataframe()
+
+        # test for equality of gains
+        gain_from_df = df[df.Feature != 'Leaf'][['Gain']].sum()
+        assert np.allclose(gain_from_dump, gain_from_df)
+
+        # test for equality of covers
+        cover_from_df = df.Cover.sum()
+        assert np.allclose(cover_from_dump, cover_from_df)
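Not part of the commit itself: a minimal usage sketch for the new method, assuming pandas is installed and the bundled demo data the test above relies on:

    import xgboost as xgb

    # Train a small model on the demo data (path assumed, mirroring the test above).
    dtrain = xgb.DMatrix('demo/data/agaricus.txt.train')
    bst = xgb.train({'max_depth': 5, 'objective': 'binary:logistic'}, dtrain, 10)

    # One row per tree node, with columns Tree, Node, ID, Feature, Split,
    # Yes, No, Missing, Gain and Cover (leaf values are reported under Gain).
    df = bst.trees_to_dataframe()
    print(df.head())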