Implement feature score for linear model. (#7048)
* Add feature score support for linear model.
* Port the R interface to the new implementation.
* Add linear model support in Python.

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -2132,47 +2132,18 @@ class Booster(object):
|
||||
fmap = os.fspath(os.path.expanduser(fmap))
|
||||
length = c_bst_ulong()
|
||||
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
||||
if self.feature_names is not None and fmap == '':
|
||||
flen = len(self.feature_names)
|
||||
|
||||
fname = from_pystr_to_cstr(self.feature_names)
|
||||
|
||||
if self.feature_types is None:
|
||||
# use quantitative as default
|
||||
# {'q': quantitative, 'i': indicator}
|
||||
ftype = from_pystr_to_cstr(['q'] * flen)
|
||||
else:
|
||||
ftype = from_pystr_to_cstr(self.feature_types)
|
||||
_check_call(_LIB.XGBoosterDumpModelExWithFeatures(
|
||||
self.handle,
|
||||
ctypes.c_int(flen),
|
||||
fname,
|
||||
ftype,
|
||||
ctypes.c_int(with_stats),
|
||||
c_str(dump_format),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr)))
|
||||
else:
|
||||
if fmap != '' and not os.path.exists(fmap):
|
||||
raise ValueError("No such file: {0}".format(fmap))
|
||||
_check_call(_LIB.XGBoosterDumpModelEx(self.handle,
|
||||
c_str(fmap),
|
||||
ctypes.c_int(with_stats),
|
||||
c_str(dump_format),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr)))
|
||||
_check_call(_LIB.XGBoosterDumpModelEx(self.handle,
|
||||
c_str(fmap),
|
||||
ctypes.c_int(with_stats),
|
||||
c_str(dump_format),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr)))
|
||||
res = from_cstr_to_pystr(sarr, length)
|
||||
return res
|
||||
|
||||
def get_fscore(self, fmap=''):
|
||||
"""Get feature importance of each feature.
|
||||
|
||||
.. note:: Feature importance is defined only for tree boosters
|
||||
|
||||
Feature importance is only defined when the decision tree model is chosen as base
|
||||
learner (`booster=gbtree`). It is not defined for other base learner types, such
|
||||
as linear learners (`booster=gblinear`).
|
||||
|
||||
.. note:: Zero-importance features will not be included
|
||||
|
||||
Keep in mind that this function does not include zero-importance feature, i.e.
|
||||
@@ -2190,7 +2161,7 @@ class Booster(object):
|
||||
self, fmap: os.PathLike = '', importance_type: str = 'weight'
|
||||
) -> Dict[str, float]:
|
||||
"""Get feature importance of each feature.
|
||||
Importance type can be defined as:
|
||||
For tree model Importance type can be defined as:
|
||||
|
||||
* 'weight': the number of times a feature is used to split the data across all trees.
|
||||
* 'gain': the average gain across all splits the feature is used in.
|
||||
@@ -2198,11 +2169,15 @@ class Booster(object):
|
||||
* 'total_gain': the total gain across all splits the feature is used in.
|
||||
* 'total_cover': the total coverage across all splits the feature is used in.
|
||||
|
||||
.. note:: Feature importance is defined only for tree boosters
|
||||
.. note::
|
||||
|
||||
Feature importance is only defined when the decision tree model is chosen as
|
||||
base learner (`booster=gbtree` or `booster=dart`). It is not defined for other
|
||||
base learner types, such as linear learners (`booster=gblinear`).
|
||||
For linear model, only "weight" is defined and it's the normalized coefficients
|
||||
without bias.
|
||||
|
||||
.. note:: Zero-importance features will not be included
|
||||
|
||||
Keep in mind that this function does not include zero-importance feature, i.e.
|
||||
those features that have not been used in any split conditions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -2213,7 +2188,9 @@ class Booster(object):
|
||||
|
||||
Returns
|
||||
-------
|
||||
A map between feature names and their scores.
|
||||
A map between feature names and their scores. When `gblinear` is used for
|
||||
multi-class classification the scores for each feature is a list with length
|
||||
`n_classes`, otherwise they're scalars.
|
||||
"""
|
||||
fmap = os.fspath(os.path.expanduser(fmap))
|
||||
args = from_pystr_to_cstr(
|
||||
@@ -2221,21 +2198,31 @@ class Booster(object):
|
||||
)
|
||||
features = ctypes.POINTER(ctypes.c_char_p)()
|
||||
scores = ctypes.POINTER(ctypes.c_float)()
|
||||
length = c_bst_ulong()
|
||||
n_out_features = c_bst_ulong()
|
||||
out_dim = c_bst_ulong()
|
||||
shape = ctypes.POINTER(c_bst_ulong)()
|
||||
|
||||
_check_call(
|
||||
_LIB.XGBoosterFeatureScore(
|
||||
self.handle,
|
||||
args,
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(n_out_features),
|
||||
ctypes.byref(features),
|
||||
ctypes.byref(scores)
|
||||
ctypes.byref(out_dim),
|
||||
ctypes.byref(shape),
|
||||
ctypes.byref(scores),
|
||||
)
|
||||
)
|
||||
features_arr = from_cstr_to_pystr(features, length)
|
||||
scores_arr = ctypes2numpy(scores, length.value, np.float32)
|
||||
features_arr = from_cstr_to_pystr(features, n_out_features)
|
||||
scores_arr = _prediction_output(shape, out_dim, scores, False)
|
||||
|
||||
results = {}
|
||||
for feat, score in zip(features_arr, scores_arr):
|
||||
results[feat] = float(score)
|
||||
if len(scores_arr.shape) > 1 and scores_arr.shape[1] > 1:
|
||||
for feat, score in zip(features_arr, scores_arr):
|
||||
results[feat] = [float(s) for s in score]
|
||||
else:
|
||||
for feat, score in zip(features_arr, scores_arr):
|
||||
results[feat] = float(score)
|
||||
return results
|
||||
|
||||
def trees_to_dataframe(self, fmap=''):
|
||||
|
||||
@@ -156,9 +156,14 @@ __model_doc = f'''
|
||||
[2, 3, 4]], where each inner list is a group of indices of features
|
||||
that are allowed to interact with each other. See tutorial for more
|
||||
information
|
||||
importance_type: string, default "gain"
|
||||
importance_type: Optional[str]
|
||||
The feature importance type for the feature_importances\\_ property:
|
||||
either "gain", "weight", "cover", "total_gain" or "total_cover".
|
||||
|
||||
* For tree model, it's either "gain", "weight", "cover", "total_gain" or
|
||||
"total_cover".
|
||||
* For linear model, only "weight" is defined and it's the normalized coefficients
|
||||
without bias.
|
||||
|
||||
gpu_id : Optional[int]
|
||||
Device ordinal.
|
||||
validate_parameters : Optional[bool]
|
||||
@@ -382,7 +387,7 @@ class XGBModel(XGBModelBase):
|
||||
num_parallel_tree: Optional[int] = None,
|
||||
monotone_constraints: Optional[Union[Dict[str, int], str]] = None,
|
||||
interaction_constraints: Optional[Union[str, List[Tuple[str]]]] = None,
|
||||
importance_type: str = "gain",
|
||||
importance_type: Optional[str] = None,
|
||||
gpu_id: Optional[int] = None,
|
||||
validate_parameters: Optional[bool] = None,
|
||||
predictor: Optional[str] = None,
|
||||
@@ -991,29 +996,26 @@ class XGBModel(XGBModelBase):
|
||||
@property
|
||||
def feature_importances_(self) -> np.ndarray:
|
||||
"""
|
||||
Feature importances property
|
||||
|
||||
.. note:: Feature importance is defined only for tree boosters
|
||||
|
||||
Feature importance is only defined when the decision tree model is chosen as base
|
||||
learner (`booster=gbtree`). It is not defined for other base learner types, such
|
||||
as linear learners (`booster=gblinear`).
|
||||
Feature importances property, return depends on `importance_type` parameter.
|
||||
|
||||
Returns
|
||||
-------
|
||||
feature_importances_ : array of shape ``[n_features]``
|
||||
feature_importances_ : array of shape ``[n_features]`` except for multi-class
|
||||
linear model, which returns an array with shape `(n_features, n_classes)`
|
||||
|
||||
"""
|
||||
if self.get_params()['booster'] not in {'gbtree', 'dart'}:
|
||||
raise AttributeError(
|
||||
'Feature importance is not defined for Booster type {}'
|
||||
.format(self.booster))
|
||||
b: Booster = self.get_booster()
|
||||
score = b.get_score(importance_type=self.importance_type)
|
||||
|
||||
def dft() -> str:
|
||||
return "weight" if self.booster == "gblinear" else "gain"
|
||||
score = b.get_score(
|
||||
importance_type=self.importance_type if self.importance_type else dft()
|
||||
)
|
||||
if b.feature_names is None:
|
||||
feature_names = ["f{0}".format(i) for i in range(self.n_features_in_)]
|
||||
else:
|
||||
feature_names = b.feature_names
|
||||
# gblinear returns all features so the `get` in next line is only for gbtree.
|
||||
all_features = [score.get(f, 0.) for f in feature_names]
|
||||
all_features_arr = np.array(all_features, dtype=np.float32)
|
||||
total = all_features_arr.sum()
|
||||
|
||||
Reference in New Issue
Block a user