[py] split value histograms
This commit is contained in:
@@ -3,16 +3,18 @@
|
||||
"""Core XGBoost Library."""
|
||||
from __future__ import absolute_import
|
||||
|
||||
import sys
|
||||
import os
|
||||
import ctypes
|
||||
import collections
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import scipy.sparse
|
||||
|
||||
from .libpath import find_lib_path
|
||||
|
||||
from .compat import STRING_TYPES, PY3, DataFrame, py_str
|
||||
from .compat import STRING_TYPES, PY3, DataFrame, py_str, PANDAS_INSTALLED
|
||||
|
||||
|
||||
class XGBoostError(Exception):
|
||||
@@ -1058,3 +1060,44 @@ class Booster(object):
|
||||
|
||||
raise ValueError(msg.format(self.feature_names,
|
||||
data.feature_names))
|
||||
|
||||
def get_split_value_histogram(self, feature, fmap='', bins=None, as_pandas=True):
    """Get the histogram of split values used for a feature.

    The text dump of every tree (``self.get_dump``) is scanned for split
    conditions of the form ``[<feature><<value>]`` and the collected
    threshold values are binned with ``numpy.histogram``.

    Parameters
    ----------
    feature: str
        The name of the feature. It is matched literally, so names that
        contain regular-expression metacharacters (``.``, ``(``, ``[``,
        ``+`` ...) are handled correctly.
    fmap: str (optional)
        The name of feature map file.
    bins: int, default None
        The maximum number of bins.
        Number of bins equals number of unique split values n_unique,
        if bins == None or bins > n_unique. At least one bin is always used.
    as_pandas : bool, default True
        Return pd.DataFrame when pandas is installed.
        If False or pandas is not installed, return numpy ndarray.

    Returns
    -------
    A histogram of used splitting values for the specified feature, either
    as a two-column numpy array or as a pandas DataFrame with the columns
    ``SplitValue`` (upper edge of the bin) and ``Count``. Bins with a zero
    count are dropped, so the result is empty when the feature is never
    used in a split.
    """
    # Escape the feature name: it is user data, not a pattern. Without this a
    # name such as "f(1)" adds a capture group (findall then yields tuples)
    # and a name such as "a.b" also matches the unrelated feature "axb".
    regexp = re.compile(r"\[{0}<([\d.Ee+-]+)\]".format(re.escape(feature)))

    values = []
    for tree_dump in self.get_dump(fmap=fmap):
        values.extend(float(v) for v in regexp.findall(tree_dump))

    n_unique = np.unique(values).shape[0]
    # Never ask for more bins than there are distinct values, and never for
    # fewer than one (np.histogram rejects bins == 0).
    bins = max(min(n_unique, bins) if bins is not None else n_unique, 1)

    counts, edges = np.histogram(values, bins=bins)
    # Each row pairs the upper edge of a bin with the number of splits in it.
    nph = np.column_stack((edges[1:], counts))
    nph = nph[nph[:, 1] > 0]

    if as_pandas and PANDAS_INSTALLED:
        return DataFrame(nph, columns=['SplitValue', 'Count'])
    if as_pandas:
        # pandas was requested but is unavailable: warn and fall back.
        sys.stderr.write("Returning histogram as ndarray (as_pandas == True, but pandas is not installed).")
    return nph
|
||||
|
||||
Reference in New Issue
Block a user