add new function to read model and use it in the plot function
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
#' Plot a boosted tree model
|
||||
#'
|
||||
#' Read a xgboost model text dump.
|
||||
#' Read a tree model text dump.
|
||||
#' Plotting only works for boosted tree model (not linear model).
|
||||
#'
|
||||
#' @importFrom data.table data.table
|
||||
@@ -21,7 +21,7 @@
|
||||
#' @param n_first_tree limit the plot to the n first trees. If \code{NULL}, all trees of the model are plotted. Performance can be low for huge models.
|
||||
#' @param style a \code{character} vector storing a css style to customize the appearance of nodes. Look at the \href{https://github.com/knsv/mermaid/wiki}{Mermaid wiki} for more information.
|
||||
#'
|
||||
#' @return A \code{data.table} of the features used in the model with their average gain (and their weight for boosted tree model) in the model.
|
||||
#' @return A \code{DiagrammeR} of the model.
|
||||
#'
|
||||
#' @details
|
||||
#'
|
||||
@@ -34,7 +34,7 @@
|
||||
#' }
|
||||
#'
|
||||
#' Each branch finishes with a leaf. For each leaf, only the \code{cover} is indicated.
|
||||
#' It uses Mermaid JS library for that purpose.
|
||||
#' It uses \href{https://github.com/knsv/mermaid/}{Mermaid} library for that purpose.
|
||||
#'
|
||||
#' @examples
|
||||
#' data(agaricus.train, package='xgboost')
|
||||
@@ -51,63 +51,13 @@
|
||||
#' xgb.plot.tree(agaricus.train$data@@Dimnames[[2]], 'xgb.model.dump')
|
||||
#'
|
||||
#' @export
|
||||
xgb.plot.tree <- function(feature_names = NULL, filename_dump = NULL, n_first_tree = NULL, styles = NULL){
|
||||
|
||||
if (!class(feature_names) %in% c("character", "NULL")) {
|
||||
stop("feature_names: Has to be a vector of character or NULL if the model dump already contains feature name. Look at this function documentation to see where to get feature names.")
|
||||
}
|
||||
if (class(filename_dump) != "character" || !file.exists(filename_dump)) {
|
||||
stop("filename_dump: Has to be a path to the model dump file.")
|
||||
}
|
||||
if (!class(n_first_tree) %in% c("numeric", "NULL") | length(n_first_tree) > 1) {
|
||||
stop("n_first_tree: Has to be a numeric vector of size 1.")
|
||||
}
|
||||
xgb.plot.tree <- function(feature_names = NULL, filename_dump = NULL, n_first_tree = NULL, styles = NULL){
|
||||
|
||||
if (!class(styles) %in% c("character", "NULL") | length(styles) > 1) {
|
||||
stop("style: Has to be a character vector of size 1.")
|
||||
}
|
||||
|
||||
text <- readLines(filename_dump) %>% str_trim(side = "both")
|
||||
position <- str_match(text, "booster") %>% is.na %>% not %>% which %>% c(length(text)+1)
|
||||
|
||||
extract <- function(x, pattern) str_extract(x, pattern) %>% str_split("=") %>% lapply(function(x) x[2] %>% as.numeric) %>% unlist
|
||||
|
||||
n_round <- min(length(position) - 1, n_first_tree)
|
||||
|
||||
addTreeId <- function(x, i) paste(i,x,sep = "-")
|
||||
|
||||
allTrees <- data.table()
|
||||
|
||||
for(i in 1:n_round){
|
||||
|
||||
tree <- text[(position[i]+1):(position[i+1]-1)]
|
||||
|
||||
notLeaf <- str_match(tree, "leaf") %>% is.na
|
||||
leaf <- notLeaf %>% not %>% tree[.]
|
||||
branch <- notLeaf %>% tree[.]
|
||||
idBranch <- str_extract(branch, "\\d*:") %>% str_replace(":", "") %>% addTreeId(i)
|
||||
idLeaf <- str_extract(leaf, "\\d*:") %>% str_replace(":", "") %>% addTreeId(i)
|
||||
featureBranch <- str_extract(branch, "f\\d*<") %>% str_replace("<", "") %>% str_replace("f", "") %>% as.numeric
|
||||
if(!is.null(feature_names)){
|
||||
featureBranch <- feature_names[featureBranch + 1]
|
||||
}
|
||||
featureLeaf <- rep("Leaf", length(leaf))
|
||||
splitBranch <- str_extract(branch, "<\\d*\\.*\\d*\\]") %>% str_replace("<", "") %>% str_replace("\\]", "")
|
||||
splitLeaf <- rep(NA, length(leaf))
|
||||
yesBranch <- extract(branch, "yes=\\d*") %>% addTreeId(i)
|
||||
yesLeaf <- rep(NA, length(leaf))
|
||||
noBranch <- extract(branch, "no=\\d*") %>% addTreeId(i)
|
||||
noLeaf <- rep(NA, length(leaf))
|
||||
missingBranch <- extract(branch, "missing=\\d+") %>% addTreeId(i)
|
||||
missingLeaf <- rep(NA, length(leaf))
|
||||
qualityBranch <- extract(branch, "gain=\\d*\\.*\\d*")
|
||||
qualityLeaf <- extract(leaf, "leaf=\\-*\\d*\\.*\\d*")
|
||||
coverBranch <- extract(branch, "cover=\\d*\\.*\\d*")
|
||||
coverLeaf <- extract(leaf, "cover=\\d*\\.*\\d*")
|
||||
dt <- data.table(ID = c(idBranch, idLeaf), Feature = c(featureBranch, featureLeaf), Split = c(splitBranch, splitLeaf), Yes = c(yesBranch, yesLeaf), No = c(noBranch, noLeaf), Missing = c(missingBranch, missingLeaf), Quality = c(qualityBranch, qualityLeaf), Cover = c(coverBranch, coverLeaf))[order(ID)][,Tree:=i]
|
||||
|
||||
allTrees <- rbindlist(list(allTrees, dt), use.names = T, fill = F)
|
||||
}
|
||||
allTrees <- xgb.model.dt.tree(feature_names, filename_dump, n_first_tree)
|
||||
|
||||
set(allTrees, i = which(allTrees[,Feature]!= "Leaf"), j = "YesFeature", value = merge(copy(allTrees)[,ID:=Yes][, .(ID)], allTrees[,.(ID, Feature, Quality, Cover)], by = "ID")[,paste(Feature, "<br/>Cover: ", Cover, sep = "")])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user