[py] split value histograms

This commit is contained in:
Faron 2016-03-20 15:08:15 +01:00
parent 6691d5c3f4
commit cf607e2448
2 changed files with 63 additions and 1 deletions

View File

@ -3,16 +3,18 @@
"""Core XGBoost Library."""
from __future__ import absolute_import
import sys
import os
import ctypes
import collections
import re
import numpy as np
import scipy.sparse
from .libpath import find_lib_path
from .compat import STRING_TYPES, PY3, DataFrame, py_str
from .compat import STRING_TYPES, PY3, DataFrame, py_str, PANDAS_INSTALLED
class XGBoostError(Exception):
@ -1058,3 +1060,44 @@ class Booster(object):
raise ValueError(msg.format(self.feature_names,
data.feature_names))
def get_split_value_histogram(self, feature, fmap='', bins=None, as_pandas=True):
"""Get split value histogram of a feature
Parameters
----------
feature: str
The name of the feature.
fmap: str (optional)
The name of feature map file.
bin: int, default None
The maximum number of bins.
Number of bins equals number of unique split values n_unique, if bins == None or bins > n_unique.
as_pandas : bool, default True
Return pd.DataFrame when pandas is installed.
If False or pandas is not installed, return numpy ndarray.
Returns
-------
a histogram of used splitting values for the specified feature either as numpy array or pandas DataFrame.
"""
xgdump = self.get_dump(fmap=fmap)
values = []
regexp = re.compile("\[{0}<([\d.Ee+-]+)\]".format(feature))
for i in range(len(xgdump)):
m = re.findall(regexp, xgdump[i])
values.extend(map(float, m))
n_unique = np.unique(values).shape[0]
bins = max(min(n_unique, bins) if bins is not None else n_unique, 1)
nph = np.histogram(values, bins=bins)
nph = np.column_stack((nph[1][1:], nph[0]))
nph = nph[nph[:, 1] > 0]
if as_pandas and PANDAS_INSTALLED:
return DataFrame(nph, columns=['SplitValue', 'Count'])
elif as_pandas and not PANDAS_INSTALLED:
sys.stderr.write("Returning histogram as ndarray (as_pandas == True, but pandas is not installed).")
return nph
else:
return nph

View File

@ -240,3 +240,22 @@ def test_sklearn_nfolds_cv():
cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, stratified=True, seed=seed)
assert cv1.shape[0] == cv2.shape[0] and cv2.shape[0] == cv3.shape[0]
assert cv2.iloc[-1, 0] == cv3.iloc[-1, 0]
def test_split_value_histograms():
digits_2class = load_digits(2)
X = digits_2class['data']
y = digits_2class['target']
dm = xgb.DMatrix(X, label=y)
params = {'max_depth': 6, 'eta': 0.01, 'silent': 1, 'objective': 'binary:logistic'}
gbdt = xgb.train(params, dm, num_boost_round=10)
assert gbdt.get_split_value_histogram("not_there", as_pandas=True).shape[0] == 0
assert gbdt.get_split_value_histogram("not_there", as_pandas=False).shape[0] == 0
assert gbdt.get_split_value_histogram("f28", bins=0).shape[0] == 1
assert gbdt.get_split_value_histogram("f28", bins=1).shape[0] == 1
assert gbdt.get_split_value_histogram("f28", bins=2).shape[0] == 2
assert gbdt.get_split_value_histogram("f28", bins=5).shape[0] == 2
assert gbdt.get_split_value_histogram("f28", bins=None).shape[0] == 2