new plot feature importance function
This commit is contained in:
parent
15dee73795
commit
e06c1da842
@ -25,4 +25,6 @@ Imports:
|
|||||||
data.table (>= 1.9.4),
|
data.table (>= 1.9.4),
|
||||||
magrittr (>= 1.5),
|
magrittr (>= 1.5),
|
||||||
stringr,
|
stringr,
|
||||||
DiagrammeR (>= 0.3)
|
DiagrammeR (>= 0.3),
|
||||||
|
ggplot2(>= 1.0.0),
|
||||||
|
Ckmeans.1d.dp
|
||||||
|
|||||||
@ -10,6 +10,7 @@ export(xgb.dump)
|
|||||||
export(xgb.importance)
|
export(xgb.importance)
|
||||||
export(xgb.load)
|
export(xgb.load)
|
||||||
export(xgb.model.dt.tree)
|
export(xgb.model.dt.tree)
|
||||||
|
export(xgb.plot.importance)
|
||||||
export(xgb.plot.tree)
|
export(xgb.plot.tree)
|
||||||
export(xgb.save)
|
export(xgb.save)
|
||||||
export(xgb.train)
|
export(xgb.train)
|
||||||
@ -18,6 +19,7 @@ exportMethods(predict)
|
|||||||
import(methods)
|
import(methods)
|
||||||
importClassesFrom(Matrix,dgCMatrix)
|
importClassesFrom(Matrix,dgCMatrix)
|
||||||
importClassesFrom(Matrix,dgeMatrix)
|
importClassesFrom(Matrix,dgeMatrix)
|
||||||
|
importFrom(Ckmeans.1d.dp,Ckmeans.1d.dp)
|
||||||
importFrom(DiagrammeR,mermaid)
|
importFrom(DiagrammeR,mermaid)
|
||||||
importFrom(data.table,":=")
|
importFrom(data.table,":=")
|
||||||
importFrom(data.table,as.data.table)
|
importFrom(data.table,as.data.table)
|
||||||
@ -26,6 +28,16 @@ importFrom(data.table,data.table)
|
|||||||
importFrom(data.table,rbindlist)
|
importFrom(data.table,rbindlist)
|
||||||
importFrom(data.table,set)
|
importFrom(data.table,set)
|
||||||
importFrom(data.table,setnames)
|
importFrom(data.table,setnames)
|
||||||
|
importFrom(ggplot2,aes)
|
||||||
|
importFrom(ggplot2,coord_flip)
|
||||||
|
importFrom(ggplot2,element_blank)
|
||||||
|
importFrom(ggplot2,element_text)
|
||||||
|
importFrom(ggplot2,geom_bar)
|
||||||
|
importFrom(ggplot2,ggplot)
|
||||||
|
importFrom(ggplot2,ggtitle)
|
||||||
|
importFrom(ggplot2,theme)
|
||||||
|
importFrom(ggplot2,xlab)
|
||||||
|
importFrom(ggplot2,ylab)
|
||||||
importFrom(magrittr,"%>%")
|
importFrom(magrittr,"%>%")
|
||||||
importFrom(magrittr,add)
|
importFrom(magrittr,add)
|
||||||
importFrom(magrittr,not)
|
importFrom(magrittr,not)
|
||||||
|
|||||||
60
R-package/R/xgb.plot.importance.R
Normal file
60
R-package/R/xgb.plot.importance.R
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
#' Plot feature importance bar graph
|
||||||
|
#'
|
||||||
|
#' Read a data.table containing feature importance details and plot it.
|
||||||
|
#'
|
||||||
|
|
||||||
|
#' @importFrom ggplot2 ggplot
|
||||||
|
#' @importFrom ggplot2 aes
|
||||||
|
#' @importFrom ggplot2 geom_bar
|
||||||
|
#' @importFrom ggplot2 coord_flip
|
||||||
|
#' @importFrom ggplot2 xlab
|
||||||
|
#' @importFrom ggplot2 ylab
|
||||||
|
#' @importFrom ggplot2 ggtitle
|
||||||
|
#' @importFrom ggplot2 theme
|
||||||
|
#' @importFrom ggplot2 element_text
|
||||||
|
#' @importFrom ggplot2 element_blank
|
||||||
|
#' @importFrom Ckmeans.1d.dp Ckmeans.1d.dp
|
||||||
|
#' @importFrom magrittr %>%
|
||||||
|
#' @param importance_matrix a \code{data.table} returned by the \code{xgb.importance} function.
|
||||||
|
#' @param numberOfClusters a \code{numeric} vector containing the min and the max range of the possible number of clusters of bars.
|
||||||
|
#'
|
||||||
|
#' @return A \code{ggplot2} bar graph representing each feature by a horizontal bar. Longer is the bar, more important is the feature. Features are classified by importance and clustered by importance. The group is represented through the color of the bar.
|
||||||
|
#'
|
||||||
|
#' @details
|
||||||
|
#' The purpose of this function is to easily represent the importance of each feature of a model.
|
||||||
|
#' The function return a ggplot graph, therefore each of its characteristic can be overriden (to customize it).
|
||||||
|
#' In particular you may want to override the title of the graph. To do so, add \code{+ ggtitle("A GRAPH NAME")} next to the value returned by this function.
|
||||||
|
#'
|
||||||
|
#' @examples
|
||||||
|
#' data(agaricus.train, package='xgboost')
|
||||||
|
#'
|
||||||
|
#' #Both dataset are list with two items, a sparse matrix and labels
|
||||||
|
#' #(labels = outcome column which will be learned).
|
||||||
|
#' #Each column of the sparse Matrix is a feature in one hot encoding format.
|
||||||
|
#' train <- agaricus.train
|
||||||
|
#'
|
||||||
|
#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
|
||||||
|
#' eta = 1, nround = 2,objective = "binary:logistic")
|
||||||
|
#'
|
||||||
|
#' #train$data@@Dimnames[[2]] represents the column names of the sparse matrix.
|
||||||
|
#' importance_matrix <- xgb.importance(train$data@@Dimnames[[2]], model = bst)
|
||||||
|
#' xgb.plot.importance(importance_matrix)
|
||||||
|
#'
|
||||||
|
#' @export
|
||||||
|
xgb.plot.importance <- function(importance_matrix = NULL, numberOfClusters = c(1:10)){
|
||||||
|
if (!"data.table" %in% class(importance_matrix)) {
|
||||||
|
stop("importance_matrix: Should be a data.table.")
|
||||||
|
}
|
||||||
|
|
||||||
|
clusters <- suppressWarnings(Ckmeans.1d.dp(importance_matrix[,Gain], numberOfClusters))
|
||||||
|
importance_matrix[,"Cluster":=clusters$cluster %>% as.character]
|
||||||
|
|
||||||
|
plot <- ggplot(importance_matrix, aes(x=reorder(Feature, Gain), y = Gain, width= 0.05), environment = environment())+ geom_bar(aes(fill=Cluster), stat="identity", position="identity") + coord_flip() + xlab("Features") + ylab("Gain") + ggtitle("Feature importance") + theme(plot.title = element_text(lineheight=.9, face="bold"), panel.grid.major.y = element_blank() )
|
||||||
|
|
||||||
|
return(plot)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Avoid error messages during CRAN check.
|
||||||
|
# The reason is that these variables are never declared
|
||||||
|
# They are mainly column names inferred by Data.table...
|
||||||
|
globalVariables(c("Feature","Gain"))
|
||||||
40
R-package/man/xgb.plot.importance.Rd
Normal file
40
R-package/man/xgb.plot.importance.Rd
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
% Generated by roxygen2 (4.1.0): do not edit by hand
|
||||||
|
% Please edit documentation in R/xgb.plot.importance.R
|
||||||
|
\name{xgb.plot.importance}
|
||||||
|
\alias{xgb.plot.importance}
|
||||||
|
\title{Plot feature importance bar graph}
|
||||||
|
\usage{
|
||||||
|
xgb.plot.importance(importance_matrix = NULL, numberOfClusters = c(1:10))
|
||||||
|
}
|
||||||
|
\arguments{
|
||||||
|
\item{importance_matrix}{a \code{data.table} returned by the \code{xgb.importance} function.}
|
||||||
|
|
||||||
|
\item{numberOfClusters}{a \code{numeric} vector containing the min and the max range of the possible number of clusters of bars.}
|
||||||
|
}
|
||||||
|
\value{
|
||||||
|
A \code{ggplot2} bar graph representing each feature by a horizontal bar. Longer is the bar, more important is the feature. Features are classified by importance and clustered by importance. The group is represented through the color of the bar.
|
||||||
|
}
|
||||||
|
\description{
|
||||||
|
Read a data.table containing feature importance details and plot it.
|
||||||
|
}
|
||||||
|
\details{
|
||||||
|
The purpose of this function is to easily represent the importance of each feature of a model.
|
||||||
|
The function return a ggplot graph, therefore each of its characteristic can be overriden (to customize it).
|
||||||
|
In particular you may want to override the title of the graph. To do so, add \code{+ ggtitle("A GRAPH NAME")} next to the value returned by this function.
|
||||||
|
}
|
||||||
|
\examples{
|
||||||
|
data(agaricus.train, package='xgboost')
|
||||||
|
|
||||||
|
#Both dataset are list with two items, a sparse matrix and labels
|
||||||
|
#(labels = outcome column which will be learned).
|
||||||
|
#Each column of the sparse Matrix is a feature in one hot encoding format.
|
||||||
|
train <- agaricus.train
|
||||||
|
|
||||||
|
bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
|
||||||
|
eta = 1, nround = 2,objective = "binary:logistic")
|
||||||
|
|
||||||
|
#train$data@Dimnames[[2]] represents the column names of the sparse matrix.
|
||||||
|
importance_matrix <- xgb.importance(train$data@Dimnames[[2]], model = bst)
|
||||||
|
xgb.plot.importance(importance_matrix)
|
||||||
|
}
|
||||||
|
|
||||||
Loading…
x
Reference in New Issue
Block a user