Added trees_to_df() method for Booster class (#4153)
* add test_parse_tree.py to tests/python * Fix formatting * Fix pylint error * Ignore 'no member' error for Pandas dataframe
This commit is contained in:
parent
1b7405f688
commit
74009afcac
@ -913,6 +913,7 @@ class DMatrix(object):
|
||||
|
||||
|
||||
class Booster(object):
|
||||
# pylint: disable=too-many-public-methods
|
||||
"""A Booster of XGBoost.
|
||||
|
||||
Booster is the model of xgboost, that contains low level routines for
|
||||
@ -1578,6 +1579,91 @@ class Booster(object):
|
||||
|
||||
return gmap
|
||||
|
||||
def trees_to_dataframe(self, fmap=''):
|
||||
"""Parse a boosted tree model text dump into a pandas DataFrame structure.
|
||||
|
||||
This feature is only defined when the decision tree model is chosen as base
|
||||
learner (`booster in {gbtree, dart}`). It is not defined for other base learner
|
||||
types, such as linear learners (`booster=gblinear`).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fmap: str (optional)
|
||||
The name of feature map file.
|
||||
"""
|
||||
# pylint: disable=too-many-locals
|
||||
if not PANDAS_INSTALLED:
|
||||
raise Exception(('pandas must be available to use this method.'
|
||||
'Install pandas before calling again.'))
|
||||
|
||||
if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}:
|
||||
raise ValueError('This method is not defined for Booster type {}'
|
||||
.format(self.booster))
|
||||
|
||||
tree_ids = []
|
||||
node_ids = []
|
||||
fids = []
|
||||
splits = []
|
||||
y_directs = []
|
||||
n_directs = []
|
||||
missings = []
|
||||
gains = []
|
||||
covers = []
|
||||
|
||||
trees = self.get_dump(fmap, with_stats=True)
|
||||
for i, tree in enumerate(trees):
|
||||
for line in tree.split('\n'):
|
||||
arr = line.split('[')
|
||||
# Leaf node
|
||||
if len(arr) == 1:
|
||||
# Last element of line.split is an empy string
|
||||
if arr == ['']:
|
||||
continue
|
||||
# parse string
|
||||
parse = arr[0].split(':')
|
||||
stats = re.split('=|,', parse[1])
|
||||
|
||||
# append to lists
|
||||
tree_ids.append(i)
|
||||
node_ids.append(int(re.findall(r'\b\d+\b', parse[0])[0]))
|
||||
fids.append('Leaf')
|
||||
splits.append(float('NAN'))
|
||||
y_directs.append(float('NAN'))
|
||||
n_directs.append(float('NAN'))
|
||||
missings.append(float('NAN'))
|
||||
gains.append(float(stats[1]))
|
||||
covers.append(float(stats[3]))
|
||||
# Not a Leaf Node
|
||||
else:
|
||||
# parse string
|
||||
fid = arr[1].split(']')
|
||||
parse = fid[0].split('<')
|
||||
stats = re.split('=|,', fid[1])
|
||||
|
||||
# append to lists
|
||||
tree_ids.append(i)
|
||||
node_ids.append(int(re.findall(r'\b\d+\b', arr[0])[0]))
|
||||
fids.append(parse[0])
|
||||
splits.append(float(parse[1]))
|
||||
str_i = str(i)
|
||||
y_directs.append(str_i + '-' + stats[1])
|
||||
n_directs.append(str_i + '-' + stats[3])
|
||||
missings.append(str_i + '-' + stats[5])
|
||||
gains.append(float(stats[7]))
|
||||
covers.append(float(stats[9]))
|
||||
|
||||
ids = [str(t_id) + '-' + str(n_id) for t_id, n_id in zip(tree_ids, node_ids)]
|
||||
df = DataFrame({'Tree': tree_ids, 'Node': node_ids, 'ID': ids,
|
||||
'Feature': fids, 'Split': splits, 'Yes': y_directs,
|
||||
'No': n_directs, 'Missing': missings, 'Gain': gains,
|
||||
'Cover': covers})
|
||||
|
||||
if callable(getattr(df, 'sort_values', None)):
|
||||
# pylint: disable=no-member
|
||||
return df.sort_values(['Tree', 'Node']).reset_index(drop=True)
|
||||
# pylint: disable=no-member
|
||||
return df.sort(['Tree', 'Node']).reset_index(drop=True)
|
||||
|
||||
def _validate_features(self, data):
|
||||
"""
|
||||
Validate Booster and data's feature_names are identical.
|
||||
|
||||
50
tests/python/test_parse_tree.py
Normal file
50
tests/python/test_parse_tree.py
Normal file
@ -0,0 +1,50 @@
|
||||
import xgboost as xgb
|
||||
import unittest
|
||||
import numpy as np
|
||||
import pytest
|
||||
import testing as tm
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(**tm.no_pandas())
|
||||
|
||||
|
||||
dpath = 'demo/data/'
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
|
||||
class TestTreesToDataFrame(unittest.TestCase):
|
||||
|
||||
def build_model(self, max_depth, num_round):
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
param = {'max_depth': max_depth, 'objective': 'binary:logistic', 'silent': False}
|
||||
num_round = num_round
|
||||
bst = xgb.train(param, dtrain, num_round)
|
||||
return bst
|
||||
|
||||
def parse_dumped_model(self, booster, item_to_get, splitter):
|
||||
item_to_get += '='
|
||||
txt_dump = booster.get_dump(with_stats=True)
|
||||
tree_list = [tree.split('/n') for tree in txt_dump]
|
||||
split_trees = [tree[0].split(item_to_get)[1:] for tree in tree_list]
|
||||
res = sum([float(line.split(splitter)[0])
|
||||
for tree in split_trees for line in tree])
|
||||
return res
|
||||
|
||||
def test_trees_to_dataframe(self):
|
||||
bst = self.build_model(max_depth=5, num_round=10)
|
||||
gain_from_dump = self.parse_dumped_model(booster=bst,
|
||||
item_to_get='gain',
|
||||
splitter=',')
|
||||
cover_from_dump = self.parse_dumped_model(booster=bst,
|
||||
item_to_get='cover',
|
||||
splitter='\n')
|
||||
# method being tested
|
||||
df = bst.trees_to_dataframe()
|
||||
|
||||
# test for equality of gains
|
||||
gain_from_df = df[df.Feature != 'Leaf'][['Gain']].sum()
|
||||
assert np.allclose(gain_from_dump, gain_from_df)
|
||||
|
||||
# test for equality of covers
|
||||
cover_from_df = df.Cover.sum()
|
||||
assert np.allclose(cover_from_dump, cover_from_df)
|
||||
Loading…
x
Reference in New Issue
Block a user