[py] split value histograms
This commit is contained in:
parent
6691d5c3f4
commit
cf607e2448
@ -3,16 +3,18 @@
|
||||
"""Core XGBoost Library."""
|
||||
from __future__ import absolute_import
|
||||
|
||||
import sys
|
||||
import os
|
||||
import ctypes
|
||||
import collections
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import scipy.sparse
|
||||
|
||||
from .libpath import find_lib_path
|
||||
|
||||
from .compat import STRING_TYPES, PY3, DataFrame, py_str
|
||||
from .compat import STRING_TYPES, PY3, DataFrame, py_str, PANDAS_INSTALLED
|
||||
|
||||
|
||||
class XGBoostError(Exception):
|
||||
@ -1058,3 +1060,44 @@ class Booster(object):
|
||||
|
||||
raise ValueError(msg.format(self.feature_names,
|
||||
data.feature_names))
|
||||
|
||||
def get_split_value_histogram(self, feature, fmap='', bins=None, as_pandas=True):
|
||||
"""Get split value histogram of a feature
|
||||
Parameters
|
||||
----------
|
||||
feature: str
|
||||
The name of the feature.
|
||||
fmap: str (optional)
|
||||
The name of feature map file.
|
||||
bin: int, default None
|
||||
The maximum number of bins.
|
||||
Number of bins equals number of unique split values n_unique, if bins == None or bins > n_unique.
|
||||
as_pandas : bool, default True
|
||||
Return pd.DataFrame when pandas is installed.
|
||||
If False or pandas is not installed, return numpy ndarray.
|
||||
|
||||
Returns
|
||||
-------
|
||||
a histogram of used splitting values for the specified feature either as numpy array or pandas DataFrame.
|
||||
"""
|
||||
xgdump = self.get_dump(fmap=fmap)
|
||||
values = []
|
||||
regexp = re.compile("\[{0}<([\d.Ee+-]+)\]".format(feature))
|
||||
for i in range(len(xgdump)):
|
||||
m = re.findall(regexp, xgdump[i])
|
||||
values.extend(map(float, m))
|
||||
|
||||
n_unique = np.unique(values).shape[0]
|
||||
bins = max(min(n_unique, bins) if bins is not None else n_unique, 1)
|
||||
|
||||
nph = np.histogram(values, bins=bins)
|
||||
nph = np.column_stack((nph[1][1:], nph[0]))
|
||||
nph = nph[nph[:, 1] > 0]
|
||||
|
||||
if as_pandas and PANDAS_INSTALLED:
|
||||
return DataFrame(nph, columns=['SplitValue', 'Count'])
|
||||
elif as_pandas and not PANDAS_INSTALLED:
|
||||
sys.stderr.write("Returning histogram as ndarray (as_pandas == True, but pandas is not installed).")
|
||||
return nph
|
||||
else:
|
||||
return nph
|
||||
|
||||
@ -240,3 +240,22 @@ def test_sklearn_nfolds_cv():
|
||||
cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, stratified=True, seed=seed)
|
||||
assert cv1.shape[0] == cv2.shape[0] and cv2.shape[0] == cv3.shape[0]
|
||||
assert cv2.iloc[-1, 0] == cv3.iloc[-1, 0]
|
||||
|
||||
|
||||
def test_split_value_histograms():
|
||||
digits_2class = load_digits(2)
|
||||
|
||||
X = digits_2class['data']
|
||||
y = digits_2class['target']
|
||||
|
||||
dm = xgb.DMatrix(X, label=y)
|
||||
params = {'max_depth': 6, 'eta': 0.01, 'silent': 1, 'objective': 'binary:logistic'}
|
||||
|
||||
gbdt = xgb.train(params, dm, num_boost_round=10)
|
||||
assert gbdt.get_split_value_histogram("not_there", as_pandas=True).shape[0] == 0
|
||||
assert gbdt.get_split_value_histogram("not_there", as_pandas=False).shape[0] == 0
|
||||
assert gbdt.get_split_value_histogram("f28", bins=0).shape[0] == 1
|
||||
assert gbdt.get_split_value_histogram("f28", bins=1).shape[0] == 1
|
||||
assert gbdt.get_split_value_histogram("f28", bins=2).shape[0] == 2
|
||||
assert gbdt.get_split_value_histogram("f28", bins=5).shape[0] == 2
|
||||
assert gbdt.get_split_value_histogram("f28", bins=None).shape[0] == 2
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user