Merge pull request #664 from pommedeterresautee/master
Support GLM in importance plot + increase tests #Rstat
This commit is contained in:
commit
88e7c6012b
@ -1,6 +1,6 @@
|
|||||||
#' Plot feature importance bar graph
|
#' Plot feature importance bar graph
|
||||||
#'
|
#'
|
||||||
#' Read a data.table containing feature importance details and plot it.
|
#' Read a data.table containing feature importance details and plot it (for both GLM and Trees).
|
||||||
#'
|
#'
|
||||||
#' @importFrom magrittr %>%
|
#' @importFrom magrittr %>%
|
||||||
#' @param importance_matrix a \code{data.table} returned by the \code{xgb.importance} function.
|
#' @param importance_matrix a \code{data.table} returned by the \code{xgb.importance} function.
|
||||||
@ -10,7 +10,7 @@
|
|||||||
#'
|
#'
|
||||||
#' @details
|
#' @details
|
||||||
#' The purpose of this function is to easily represent the importance of each feature of a model.
|
#' The purpose of this function is to easily represent the importance of each feature of a model.
|
||||||
#' The function return a ggplot graph, therefore each of its characteristic can be overriden (to customize it).
|
#' The function returns a ggplot graph, therefore each of its characteristic can be overriden (to customize it).
|
||||||
#' In particular you may want to override the title of the graph. To do so, add \code{+ ggtitle("A GRAPH NAME")} next to the value returned by this function.
|
#' In particular you may want to override the title of the graph. To do so, add \code{+ ggtitle("A GRAPH NAME")} next to the value returned by this function.
|
||||||
#'
|
#'
|
||||||
#' @examples
|
#' @examples
|
||||||
@ -40,21 +40,29 @@ xgb.plot.importance <-
|
|||||||
stop("Ckmeans.1d.dp package is required for plotting the importance", call. = FALSE)
|
stop("Ckmeans.1d.dp package is required for plotting the importance", call. = FALSE)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(isTRUE(all.equal(colnames(importance_matrix), c("Feature", "Gain", "Cover", "Frequency")))){
|
||||||
|
y.axe.name <- "Gain"
|
||||||
|
} else if(isTRUE(all.equal(colnames(importance_matrix), c("Feature", "Weight")))){
|
||||||
|
y.axe.name <- "Weight"
|
||||||
|
} else {
|
||||||
|
stop("Importance matrix is not correct (column names issue)")
|
||||||
|
}
|
||||||
|
|
||||||
# To avoid issues in clustering when co-occurences are used
|
# To avoid issues in clustering when co-occurences are used
|
||||||
importance_matrix <-
|
importance_matrix <-
|
||||||
importance_matrix[, .(Gain = sum(Gain)), by = Feature]
|
importance_matrix[, .(Gain.or.Weight = sum(get(y.axe.name))), by = Feature]
|
||||||
|
|
||||||
clusters <-
|
clusters <-
|
||||||
suppressWarnings(Ckmeans.1d.dp::Ckmeans.1d.dp(importance_matrix[,Gain], numberOfClusters))
|
suppressWarnings(Ckmeans.1d.dp::Ckmeans.1d.dp(importance_matrix[,Gain.or.Weight], numberOfClusters))
|
||||||
importance_matrix[,"Cluster":= clusters$cluster %>% as.character]
|
importance_matrix[,"Cluster":= clusters$cluster %>% as.character]
|
||||||
|
|
||||||
plot <-
|
plot <-
|
||||||
ggplot2::ggplot(
|
ggplot2::ggplot(
|
||||||
importance_matrix, ggplot2::aes(
|
importance_matrix, ggplot2::aes(
|
||||||
x = stats::reorder(Feature, Gain), y = Gain, width = 0.05
|
x = stats::reorder(Feature, Gain.or.Weight), y = Gain.or.Weight, width = 0.05
|
||||||
), environment = environment()
|
), environment = environment()
|
||||||
) + ggplot2::geom_bar(ggplot2::aes(fill = Cluster), stat = "identity", position =
|
) + ggplot2::geom_bar(ggplot2::aes(fill = Cluster), stat = "identity", position =
|
||||||
"identity") + ggplot2::coord_flip() + ggplot2::xlab("Features") + ggplot2::ylab("Gain") + ggplot2::ggtitle("Feature importance") + ggplot2::theme(
|
"identity") + ggplot2::coord_flip() + ggplot2::xlab("Features") + ggplot2::ylab(y.axe.name) + ggplot2::ggtitle("Feature importance") + ggplot2::theme(
|
||||||
plot.title = ggplot2::element_text(lineheight = .9, face = "bold"), panel.grid.major.y = ggplot2::element_blank()
|
plot.title = ggplot2::element_text(lineheight = .9, face = "bold"), panel.grid.major.y = ggplot2::element_blank()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -66,6 +74,6 @@ xgb.plot.importance <-
|
|||||||
# They are mainly column names inferred by Data.table...
|
# They are mainly column names inferred by Data.table...
|
||||||
globalVariables(
|
globalVariables(
|
||||||
c(
|
c(
|
||||||
"Feature", "Gain", "Cluster", "ggplot", "aes", "geom_bar", "coord_flip", "xlab", "ylab", "ggtitle", "theme", "element_blank", "element_text"
|
"Feature", "Gain.or.Weight", "Cluster", "ggplot", "aes", "geom_bar", "coord_flip", "xlab", "ylab", "ggtitle", "theme", "element_blank", "element_text", "Gain.or.Weight"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@ -15,11 +15,11 @@ xgb.plot.importance(importance_matrix = NULL, numberOfClusters = c(1:10))
|
|||||||
A \code{ggplot2} bar graph representing each feature by a horizontal bar. Longer is the bar, more important is the feature. Features are classified by importance and clustered by importance. The group is represented through the color of the bar.
|
A \code{ggplot2} bar graph representing each feature by a horizontal bar. Longer is the bar, more important is the feature. Features are classified by importance and clustered by importance. The group is represented through the color of the bar.
|
||||||
}
|
}
|
||||||
\description{
|
\description{
|
||||||
Read a data.table containing feature importance details and plot it.
|
Read a data.table containing feature importance details and plot it (for both GLM and Trees).
|
||||||
}
|
}
|
||||||
\details{
|
\details{
|
||||||
The purpose of this function is to easily represent the importance of each feature of a model.
|
The purpose of this function is to easily represent the importance of each feature of a model.
|
||||||
The function return a ggplot graph, therefore each of its characteristic can be overriden (to customize it).
|
The function returns a ggplot graph, therefore each of its characteristic can be overriden (to customize it).
|
||||||
In particular you may want to override the title of the graph. To do so, add \code{+ ggtitle("A GRAPH NAME")} next to the value returned by this function.
|
In particular you may want to override the title of the graph. To do so, add \code{+ ggtitle("A GRAPH NAME")} next to the value returned by this function.
|
||||||
}
|
}
|
||||||
\examples{
|
\examples{
|
||||||
|
|||||||
@ -14,50 +14,55 @@ df[,AgeCat := as.factor(ifelse(Age > 30, "Old", "Young"))]
|
|||||||
df[,ID := NULL]
|
df[,ID := NULL]
|
||||||
sparse_matrix <- sparse.model.matrix(Improved~.-1, data = df)
|
sparse_matrix <- sparse.model.matrix(Improved~.-1, data = df)
|
||||||
output_vector <- df[,Y := 0][Improved == "Marked",Y := 1][,Y]
|
output_vector <- df[,Y := 0][Improved == "Marked",Y := 1][,Y]
|
||||||
bst <- xgboost(data = sparse_matrix, label = output_vector, max.depth = 9,
|
bst.Tree <- xgboost(data = sparse_matrix, label = output_vector, max.depth = 9,
|
||||||
eta = 1, nthread = 2, nround = 10, objective = "binary:logistic")
|
eta = 1, nthread = 2, nround = 10, objective = "binary:logistic", booster = "gbtree")
|
||||||
|
|
||||||
|
bst.GLM <- xgboost(data = sparse_matrix, label = output_vector,
|
||||||
|
eta = 1, nthread = 2, nround = 10, objective = "binary:logistic", booster = "gblinear")
|
||||||
|
|
||||||
feature.names <- agaricus.train$data@Dimnames[[2]]
|
feature.names <- agaricus.train$data@Dimnames[[2]]
|
||||||
|
|
||||||
test_that("xgb.dump works", {
|
test_that("xgb.dump works", {
|
||||||
capture.output(print(xgb.dump(bst)))
|
capture.output(print(xgb.dump(bst.Tree)))
|
||||||
expect_true(xgb.dump(bst, 'xgb.model.dump', with.stats = T))
|
capture.output(print(xgb.dump(bst.GLM)))
|
||||||
|
expect_true(xgb.dump(bst.Tree, 'xgb.model.dump', with.stats = T))
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("xgb.model.dt.tree works with and without feature names", {
|
test_that("xgb.model.dt.tree works with and without feature names", {
|
||||||
names.dt.trees <- c("ID", "Feature", "Split", "Yes", "No", "Missing", "Quality", "Cover",
|
names.dt.trees <- c("ID", "Feature", "Split", "Yes", "No", "Missing", "Quality", "Cover",
|
||||||
"Tree", "Yes.Feature", "Yes.Cover", "Yes.Quality", "No.Feature", "No.Cover", "No.Quality")
|
"Tree", "Yes.Feature", "Yes.Cover", "Yes.Quality", "No.Feature", "No.Cover", "No.Quality")
|
||||||
dt.tree <- xgb.model.dt.tree(feature_names = feature.names, model = bst)
|
dt.tree <- xgb.model.dt.tree(feature_names = feature.names, model = bst.Tree)
|
||||||
expect_equal(names.dt.trees, names(dt.tree))
|
expect_equal(names.dt.trees, names(dt.tree))
|
||||||
expect_equal(dim(dt.tree), c(162, 15))
|
expect_equal(dim(dt.tree), c(162, 15))
|
||||||
xgb.model.dt.tree(model = bst)
|
xgb.model.dt.tree(model = bst.Tree)
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("xgb.importance works with and without feature names", {
|
test_that("xgb.importance works with and without feature names", {
|
||||||
importance <- xgb.importance(feature_names = sparse_matrix@Dimnames[[2]], model = bst)
|
importance.Tree <- xgb.importance(feature_names = sparse_matrix@Dimnames[[2]], model = bst.Tree)
|
||||||
expect_equal(dim(importance), c(7, 4))
|
expect_equal(dim(importance.Tree), c(7, 4))
|
||||||
expect_equal(colnames(importance), c("Feature", "Gain", "Cover", "Frequency"))
|
expect_equal(colnames(importance.Tree), c("Feature", "Gain", "Cover", "Frequency"))
|
||||||
xgb.importance(model = bst)
|
xgb.importance(model = bst.Tree)
|
||||||
|
xgb.plot.importance(importance_matrix = importance.Tree)
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("xgb.importance works with GLM model", {
|
test_that("xgb.importance works with GLM model", {
|
||||||
bst.GLM <- xgboost(data = sparse_matrix, label = output_vector,
|
|
||||||
eta = 1, nthread = 2, nround = 10, objective = "binary:logistic", booster = "gblinear")
|
|
||||||
importance.GLM <- xgb.importance(feature_names = sparse_matrix@Dimnames[[2]], model = bst.GLM)
|
importance.GLM <- xgb.importance(feature_names = sparse_matrix@Dimnames[[2]], model = bst.GLM)
|
||||||
expect_equal(dim(importance.GLM), c(10, 2))
|
expect_equal(dim(importance.GLM), c(10, 2))
|
||||||
expect_equal(colnames(importance.GLM), c("Feature", "Weight"))
|
expect_equal(colnames(importance.GLM), c("Feature", "Weight"))
|
||||||
xgb.importance(model = bst.GLM)
|
xgb.importance(model = bst.GLM)
|
||||||
|
xgb.plot.importance(importance.GLM)
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("xgb.plot.tree works with and without feature names", {
|
test_that("xgb.plot.tree works with and without feature names", {
|
||||||
xgb.plot.tree(feature_names = feature.names, model = bst)
|
xgb.plot.tree(feature_names = feature.names, model = bst.Tree)
|
||||||
xgb.plot.tree(model = bst)
|
xgb.plot.tree(model = bst.Tree)
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("xgb.plot.multi.trees works with and without feature names", {
|
test_that("xgb.plot.multi.trees works with and without feature names", {
|
||||||
xgb.plot.multi.trees(model = bst, feature_names = feature.names, features.keep = 3)
|
xgb.plot.multi.trees(model = bst.Tree, feature_names = feature.names, features.keep = 3)
|
||||||
xgb.plot.multi.trees(model = bst, features.keep = 3)
|
xgb.plot.multi.trees(model = bst.Tree, features.keep = 3)
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("xgb.plot.deepness works", {
|
test_that("xgb.plot.deepness works", {
|
||||||
xgb.plot.deepness(model = bst)
|
xgb.plot.deepness(model = bst.Tree)
|
||||||
})
|
})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user