[R] xgb.importance: fix for multiclass gblinear, new 'trees' parameter (#2388)

This commit is contained in:
Vadim Khotilovich
2017-06-07 13:13:21 -05:00
committed by GitHub
parent 2ae56ca84f
commit c82276386d
4 changed files with 120 additions and 35 deletions

View File

@@ -6,6 +6,11 @@
#' contains feature names, those would be used when \code{feature_names=NULL} (default value).
#' Non-null \code{feature_names} could be provided to override those in the model.
#' @param model object of class \code{xgb.Booster}.
#' @param trees (only for the gbtree booster) an integer vector of tree indices that should be included
#' into the importance calculation. If set to \code{NULL}, all trees of the model are parsed.
#' It could be useful, e.g., in multiclass classification to get feature importances
#' for each class separately. IMPORTANT: the tree index in xgboost models
#' is zero-based (e.g., use \code{trees = 0:4} for first 5 trees).
#' @param data deprecated.
#' @param label deprecated.
#' @param target deprecated.
@@ -32,27 +37,51 @@
#' a feature have been used in trees.
#' }
#'
#' A linear model's importance \code{data.table} has only two columns:
#' A linear model's importance \code{data.table} has the following columns:
#' \itemize{
#' \item \code{Features} names of the features used in the model;
#' \item \code{Weight} the linear coefficient of this feature.
#' \item \code{Weight} the linear coefficient of this feature;
#' \item \code{Class} (only for multiclass models) class label.
#' }
#'
#' If you don't provide or \code{model} doesn't have \code{feature_names},
#' If \code{feature_names} is not provided and \code{model} doesn't have \code{feature_names},
#' index of the features will be used instead. Because the index is extracted from the model dump
#' (based on C++ code), it starts at 0 (as in C/C++ or Python) instead of 1 (usual in R).
#'
#' @examples
#'
#' # binomial classification using gbtree:
#' data(agaricus.train, package='xgboost')
#'
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 2,
#' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
#'
#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
#' xgb.importance(model = bst)
#'
#' # binomial classification using gblinear:
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, booster = "gblinear",
#' eta = 0.3, nthread = 1, nrounds = 20, objective = "binary:logistic")
#' xgb.importance(model = bst)
#'
#' # multiclass classification using gbtree:
#' nclass <- 3
#' nrounds <- 10
#' mbst <- xgboost(data = as.matrix(iris[, -5]), label = as.numeric(iris$Species) - 1,
#' max_depth = 3, eta = 0.2, nthread = 2, nrounds = nrounds,
#' objective = "multi:softprob", num_class = nclass)
#' # all classes clumped together:
#' xgb.importance(model = mbst)
#' # inspect importances separately for each class:
#' xgb.importance(model = mbst, trees = seq(from=0, by=nclass, length.out=nrounds))
#' xgb.importance(model = mbst, trees = seq(from=1, by=nclass, length.out=nrounds))
#' xgb.importance(model = mbst, trees = seq(from=2, by=nclass, length.out=nrounds))
#'
#' # multiclass classification using gblinear:
#' mbst <- xgboost(data = scale(as.matrix(iris[, -5])), label = as.numeric(iris$Species) - 1,
#' booster = "gblinear", eta = 0.2, nthread = 1, nrounds = 15,
#' objective = "multi:softprob", num_class = nclass)
#' xgb.importance(model = mbst)
#'
#' @export
xgb.importance <- function(feature_names = NULL, model = NULL,
xgb.importance <- function(feature_names = NULL, model = NULL, trees = NULL,
data = NULL, label = NULL, target = NULL){
if (!(is.null(data) && is.null(label) && is.null(target)))
@@ -74,14 +103,25 @@ xgb.importance <- function(feature_names = NULL, model = NULL,
weights <- which(model_text_dump == "weight:") %>%
{model_text_dump[(. + 1):length(model_text_dump)]} %>%
as.numeric
num_class <- NVL(model$params$num_class, 1)
if(is.null(feature_names))
feature_names <- seq(to = length(weights))
if (length(feature_names) != length(weights))
stop("feature_names has less elements than there are features used in the model")
result <- data.table(Feature = feature_names, Weight = weights)[order(-abs(Weight))]
feature_names <- seq(to = length(weights) / num_class) - 1
if (length(feature_names) * num_class != length(weights))
stop("feature_names length does not match the number of features used in the model")
result <- if (num_class == 1) {
data.table(Feature = feature_names, Weight = weights)[order(-abs(Weight))]
} else {
data.table(Feature = rep(feature_names, each = num_class),
Weight = weights,
Class = 0:(num_class - 1))[order(Class, -abs(Weight))]
}
} else {
# tree model
result <- xgb.model.dt.tree(feature_names = feature_names, text = model_text_dump)[
result <- xgb.model.dt.tree(feature_names = feature_names,
text = model_text_dump,
trees = trees)[
Feature != "Leaf", .(Gain = sum(Quality),
Cover = sum(Cover),
Frequency = .N), by = Feature][