[py] split value histograms

This commit is contained in:
Faron 2016-03-20 15:08:15 +01:00
parent 6691d5c3f4
commit cf607e2448
2 changed files with 63 additions and 1 deletions

View File

@ -3,16 +3,18 @@
"""Core XGBoost Library.""" """Core XGBoost Library."""
from __future__ import absolute_import from __future__ import absolute_import
import sys
import os import os
import ctypes import ctypes
import collections import collections
import re
import numpy as np import numpy as np
import scipy.sparse import scipy.sparse
from .libpath import find_lib_path from .libpath import find_lib_path
from .compat import STRING_TYPES, PY3, DataFrame, py_str from .compat import STRING_TYPES, PY3, DataFrame, py_str, PANDAS_INSTALLED
class XGBoostError(Exception): class XGBoostError(Exception):
@ -1058,3 +1060,44 @@ class Booster(object):
raise ValueError(msg.format(self.feature_names, raise ValueError(msg.format(self.feature_names,
data.feature_names)) data.feature_names))
def get_split_value_histogram(self, feature, fmap='', bins=None, as_pandas=True):
    """Get the histogram of split values used for a feature.

    Parameters
    ----------
    feature : str
        The name of the feature. It is matched literally against the
        ``[<feature><<value>]`` split conditions of the model dump.
    fmap : str, optional
        The name of the feature map file.
    bins : int, default None
        The maximum number of bins.
        The number of bins equals the number of unique split values
        ``n_unique`` if ``bins is None`` or ``bins > n_unique``.
        At least one bin is always used.
    as_pandas : bool, default True
        Return a pd.DataFrame when pandas is installed.
        If False or pandas is not installed, return a numpy ndarray.

    Returns
    -------
    A histogram of the split values used for the specified feature, either
    as a numpy array or as a pandas DataFrame. Each row holds the upper edge
    of a non-empty bin and the number of split values that fell into it; the
    result has zero rows when the feature is never used for a split.
    """
    # Escape the feature name so that regex metacharacters in it ('.', '(',
    # '+', ...) are matched literally instead of changing the pattern.
    split_pattern = re.compile(r"\[" + re.escape(feature) + r"<([\d.Ee+-]+)\]")

    values = []
    for tree_dump in self.get_dump(fmap=fmap):
        values.extend(float(v) for v in split_pattern.findall(tree_dump))

    # Never use more bins than there are distinct split values, and never
    # fewer than one (np.histogram rejects bins=0).
    n_unique = np.unique(values).shape[0]
    if bins is None or bins > n_unique:
        bins = n_unique
    bins = max(bins, 1)

    counts, edges = np.histogram(values, bins=bins)
    # Pair every bin's upper edge with its count, then drop empty bins.
    hist = np.column_stack((edges[1:], counts))
    hist = hist[hist[:, 1] > 0]

    if as_pandas and PANDAS_INSTALLED:
        return DataFrame(hist, columns=['SplitValue', 'Count'])
    if as_pandas:
        # pandas was requested but is unavailable: fall back to the ndarray.
        sys.stderr.write("Returning histogram as ndarray "
                         "(as_pandas == True, but pandas is not installed).\n")
    return hist

View File

@ -240,3 +240,22 @@ def test_sklearn_nfolds_cv():
cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, stratified=True, seed=seed) cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, stratified=True, seed=seed)
assert cv1.shape[0] == cv2.shape[0] and cv2.shape[0] == cv3.shape[0] assert cv1.shape[0] == cv2.shape[0] and cv2.shape[0] == cv3.shape[0]
assert cv2.iloc[-1, 0] == cv3.iloc[-1, 0] assert cv2.iloc[-1, 0] == cv3.iloc[-1, 0]
def test_split_value_histograms():
    """Check the row counts returned by Booster.get_split_value_histogram."""
    # n_class is passed by keyword: it is keyword-only in current
    # scikit-learn releases, and the keyword form also works on older ones.
    digits_2class = load_digits(n_class=2)

    X = digits_2class['data']
    y = digits_2class['target']

    dm = xgb.DMatrix(X, label=y)
    params = {'max_depth': 6, 'eta': 0.01, 'silent': 1, 'objective': 'binary:logistic'}

    gbdt = xgb.train(params, dm, num_boost_round=10)

    # A feature that is never split on yields an empty histogram in both
    # return formats.
    assert gbdt.get_split_value_histogram("not_there", as_pandas=True).shape[0] == 0
    assert gbdt.get_split_value_histogram("not_there", as_pandas=False).shape[0] == 0

    # The bin count is clamped to at least one and to at most the number of
    # unique split values (the expected counts imply two for "f28").
    for bins, expected_rows in ((0, 1), (1, 1), (2, 2), (5, 2), (None, 2)):
        assert gbdt.get_split_value_histogram("f28", bins=bins).shape[0] == expected_rows