xgboost/R-package/R/xgb.plot.importance.R
2015-02-18 19:41:13 +01:00

62 lines
3.1 KiB
R

#' Plot feature importance bar graph
#'
#' Read a data.table containing feature importance details and plot it.
#'
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 aes
#' @importFrom ggplot2 geom_bar
#' @importFrom ggplot2 coord_flip
#' @importFrom ggplot2 xlab
#' @importFrom ggplot2 ylab
#' @importFrom ggplot2 ggtitle
#' @importFrom ggplot2 theme
#' @importFrom ggplot2 element_text
#' @importFrom ggplot2 element_blank
#' @importFrom Ckmeans.1d.dp Ckmeans.1d.dp
#' @importFrom magrittr %>%
#' @param importance_matrix a \code{data.table} returned by the \code{xgb.importance} function.
#' @param numberOfClusters a \code{numeric} vector containing the min and the max range of the possible number of clusters of bars.
#'
#' @return A \code{ggplot2} bar graph representing each feature by a horizontal bar. The longer the bar, the more important the feature. Features are ordered by importance and clustered by importance. The cluster is represented through the color of the bar.
#'
#' @details
#' The purpose of this function is to easily represent the importance of each feature of a model.
#' The function returns a ggplot graph, therefore each of its characteristics can be overridden (to customize it).
#' In particular you may want to override the title of the graph. To do so, add \code{+ ggtitle("A GRAPH NAME")} next to the value returned by this function.
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#'
#' #The dataset is a list with two items, a sparse matrix and labels
#' #(labels = outcome column which will be learned).
#' #Each column of the sparse Matrix is a feature in one hot encoding format.
#' train <- agaricus.train
#'
#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
#' eta = 1, nround = 2,objective = "binary:logistic")
#'
#' #train$data@@Dimnames[[2]] represents the column names of the sparse matrix.
#' importance_matrix <- xgb.importance(train$data@@Dimnames[[2]], model = bst)
#' xgb.plot.importance(importance_matrix)
#'
#' @export
xgb.plot.importance <- function(importance_matrix = NULL, numberOfClusters = c(1:10)){
  # Plot the per-feature gain as a horizontal bar chart, with bars coloured by
  # the cluster their gain falls into.
  #
  # importance_matrix: a data.table with (at least) the columns Feature and Gain.
  # numberOfClusters:  passed unchanged to Ckmeans.1d.dp as the number of
  #                    clusters to consider.
  # Returns the ggplot object; stops if the input is not a usable data.table.

  # inherits() is the robust S3 class test (also rejects the NULL default).
  if (!inherits(importance_matrix, "data.table")) {
    stop("importance_matrix: Should be a data.table.")
  }

  # Fail early with a readable message instead of an obscure error raised from
  # inside the aggregation or clustering steps below.
  missing_columns <- setdiff(c("Feature", "Gain"), names(importance_matrix))
  if (length(missing_columns) > 0) {
    stop("importance_matrix: Missing required column(s): ",
         paste(missing_columns, collapse = ", "), ".")
  }

  # To avoid issues in clustering when co-occurrences are used: collapse the
  # table to one row per feature by summing its gains. The aggregation yields a
  # new data.table, so the by-reference `:=` below does not touch the caller's
  # object.
  importance_matrix <- importance_matrix[, .(Gain = sum(Gain)), by = Feature]

  # Cluster the one-dimensional gain values; warnings emitted by the clustering
  # routine are deliberately silenced, as in the original implementation.
  clusters <- suppressWarnings(
    Ckmeans.1d.dp(importance_matrix[, Gain], numberOfClusters)
  )

  # Store the cluster id as character so ggplot treats it as a discrete fill.
  importance_matrix[, "Cluster" := as.character(clusters$cluster)]

  # reorder() sorts the bars by gain; it is namespace-qualified because it
  # belongs to stats and is not listed among this file's imports.
  importance_plot <-
    ggplot(importance_matrix,
           aes(x = stats::reorder(Feature, Gain), y = Gain, width = 0.05),
           environment = environment()) +
    geom_bar(aes(fill = Cluster), stat = "identity", position = "identity") +
    coord_flip() +
    xlab("Features") +
    ylab("Gain") +
    ggtitle("Feature importance") +
    theme(plot.title = element_text(lineheight = .9, face = "bold"),
          panel.grid.major.y = element_blank())

  importance_plot
}
# Register the names below as known globals so the CRAN check stays quiet.
# They are never declared as ordinary variables: they are mainly column
# names that data.table resolves inside its own expressions.
utils::globalVariables(
  names = c("Feature", "Gain", "Cluster")
)