[py] split value histograms
This commit is contained in:
parent
6691d5c3f4
commit
cf607e2448
@ -3,16 +3,18 @@
|
|||||||
"""Core XGBoost Library."""
|
"""Core XGBoost Library."""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
||||||
|
import sys
|
||||||
import os
|
import os
|
||||||
import ctypes
|
import ctypes
|
||||||
import collections
|
import collections
|
||||||
|
import re
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.sparse
|
import scipy.sparse
|
||||||
|
|
||||||
from .libpath import find_lib_path
|
from .libpath import find_lib_path
|
||||||
|
|
||||||
from .compat import STRING_TYPES, PY3, DataFrame, py_str
|
from .compat import STRING_TYPES, PY3, DataFrame, py_str, PANDAS_INSTALLED
|
||||||
|
|
||||||
|
|
||||||
class XGBoostError(Exception):
|
class XGBoostError(Exception):
|
||||||
@ -1058,3 +1060,44 @@ class Booster(object):
|
|||||||
|
|
||||||
raise ValueError(msg.format(self.feature_names,
|
raise ValueError(msg.format(self.feature_names,
|
||||||
data.feature_names))
|
data.feature_names))
|
||||||
|
|
||||||
|
def get_split_value_histogram(self, feature, fmap='', bins=None, as_pandas=True):
|
||||||
|
"""Get split value histogram of a feature
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
feature: str
|
||||||
|
The name of the feature.
|
||||||
|
fmap: str (optional)
|
||||||
|
The name of feature map file.
|
||||||
|
bin: int, default None
|
||||||
|
The maximum number of bins.
|
||||||
|
Number of bins equals number of unique split values n_unique, if bins == None or bins > n_unique.
|
||||||
|
as_pandas : bool, default True
|
||||||
|
Return pd.DataFrame when pandas is installed.
|
||||||
|
If False or pandas is not installed, return numpy ndarray.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
a histogram of used splitting values for the specified feature either as numpy array or pandas DataFrame.
|
||||||
|
"""
|
||||||
|
xgdump = self.get_dump(fmap=fmap)
|
||||||
|
values = []
|
||||||
|
regexp = re.compile("\[{0}<([\d.Ee+-]+)\]".format(feature))
|
||||||
|
for i in range(len(xgdump)):
|
||||||
|
m = re.findall(regexp, xgdump[i])
|
||||||
|
values.extend(map(float, m))
|
||||||
|
|
||||||
|
n_unique = np.unique(values).shape[0]
|
||||||
|
bins = max(min(n_unique, bins) if bins is not None else n_unique, 1)
|
||||||
|
|
||||||
|
nph = np.histogram(values, bins=bins)
|
||||||
|
nph = np.column_stack((nph[1][1:], nph[0]))
|
||||||
|
nph = nph[nph[:, 1] > 0]
|
||||||
|
|
||||||
|
if as_pandas and PANDAS_INSTALLED:
|
||||||
|
return DataFrame(nph, columns=['SplitValue', 'Count'])
|
||||||
|
elif as_pandas and not PANDAS_INSTALLED:
|
||||||
|
sys.stderr.write("Returning histogram as ndarray (as_pandas == True, but pandas is not installed).")
|
||||||
|
return nph
|
||||||
|
else:
|
||||||
|
return nph
|
||||||
|
|||||||
@ -240,3 +240,22 @@ def test_sklearn_nfolds_cv():
|
|||||||
cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, stratified=True, seed=seed)
|
cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, stratified=True, seed=seed)
|
||||||
assert cv1.shape[0] == cv2.shape[0] and cv2.shape[0] == cv3.shape[0]
|
assert cv1.shape[0] == cv2.shape[0] and cv2.shape[0] == cv3.shape[0]
|
||||||
assert cv2.iloc[-1, 0] == cv3.iloc[-1, 0]
|
assert cv2.iloc[-1, 0] == cv3.iloc[-1, 0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_value_histograms():
|
||||||
|
digits_2class = load_digits(2)
|
||||||
|
|
||||||
|
X = digits_2class['data']
|
||||||
|
y = digits_2class['target']
|
||||||
|
|
||||||
|
dm = xgb.DMatrix(X, label=y)
|
||||||
|
params = {'max_depth': 6, 'eta': 0.01, 'silent': 1, 'objective': 'binary:logistic'}
|
||||||
|
|
||||||
|
gbdt = xgb.train(params, dm, num_boost_round=10)
|
||||||
|
assert gbdt.get_split_value_histogram("not_there", as_pandas=True).shape[0] == 0
|
||||||
|
assert gbdt.get_split_value_histogram("not_there", as_pandas=False).shape[0] == 0
|
||||||
|
assert gbdt.get_split_value_histogram("f28", bins=0).shape[0] == 1
|
||||||
|
assert gbdt.get_split_value_histogram("f28", bins=1).shape[0] == 1
|
||||||
|
assert gbdt.get_split_value_histogram("f28", bins=2).shape[0] == 2
|
||||||
|
assert gbdt.get_split_value_histogram("f28", bins=5).shape[0] == 2
|
||||||
|
assert gbdt.get_split_value_histogram("f28", bins=None).shape[0] == 2
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user