diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE index 23de90d28..d29ad7a18 100644 --- a/R-package/NAMESPACE +++ b/R-package/NAMESPACE @@ -25,6 +25,7 @@ importFrom(data.table,copy) importFrom(data.table,data.table) importFrom(data.table,rbindlist) importFrom(data.table,set) +importFrom(data.table,setnames) importFrom(magrittr,"%>%") importFrom(magrittr,add) importFrom(magrittr,not) diff --git a/R-package/R/xgb.dump.R b/R-package/R/xgb.dump.R index f73850883..61bfe412e 100644 --- a/R-package/R/xgb.dump.R +++ b/R-package/R/xgb.dump.R @@ -47,7 +47,7 @@ xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with.stats=FALSE) { if(is.null(fname)) { return(str_split(result, "\n") %>% unlist %>% str_replace("^\t+","") %>% Filter(function(x) x != "", .)) } else { - writeLines(result, fname) + result %>% str_split("\n") %>% unlist %>% Filter(function(x) x != "", .) %>% writeLines(fname) return(TRUE) } } \ No newline at end of file diff --git a/R-package/R/xgb.importance.R b/R-package/R/xgb.importance.R index eaaad9ab8..189ee03b4 100644 --- a/R-package/R/xgb.importance.R +++ b/R-package/R/xgb.importance.R @@ -4,8 +4,9 @@ #' Can be tree or linear model (text dump of linear model are only supported in dev version of \code{Xgboost} for now). #' #' @importFrom data.table data.table -#' @importFrom magrittr %>% +#' @importFrom data.table setnames #' @importFrom data.table := +#' @importFrom magrittr %>% #' @param feature_names names of each feature as a character vector. Can be extracted from a sparse matrix (see example). If model dump already contains feature names, this argument should be \code{NULL}. #' @param filename_dump the path to the text file storing the model. Model dump must include the gain per feature and per tree (\code{with.stats = T} in function \code{xgb.dump}). #' diff --git a/R-package/R/xgb.model.dt.tree.R b/R-package/R/xgb.model.dt.tree.R index 1fc104cce..3e0723c61 100644 --- a/R-package/R/xgb.model.dt.tree.R +++ b/R-package/R/xgb.model.dt.tree.R @@ -5,6 +5,7 @@ #' @importFrom data.table data.table #' @importFrom data.table set #' @importFrom data.table rbindlist +#' @importFrom data.table copy #' @importFrom data.table := #' @importFrom magrittr %>% #' @importFrom magrittr not @@ -88,11 +89,13 @@ xgb.model.dt.tree <- function(feature_names = NULL, filename_dump = NULL, text = tree <- text[(position[i]+1):(position[i+1]-1)] + treeID <- i-1 + notLeaf <- str_match(tree, "leaf") %>% is.na leaf <- notLeaf %>% not %>% tree[.] branch <- notLeaf %>% tree[.] - idBranch <- str_extract(branch, "\\d*:") %>% str_replace(":", "") %>% addTreeId(i) - idLeaf <- str_extract(leaf, "\\d*:") %>% str_replace(":", "") %>% addTreeId(i) + idBranch <- str_extract(branch, "\\d*:") %>% str_replace(":", "") %>% addTreeId(treeID) + idLeaf <- str_extract(leaf, "\\d*:") %>% str_replace(":", "") %>% addTreeId(treeID) featureBranch <- str_extract(branch, "f\\d*<") %>% str_replace("<", "") %>% str_replace("f", "") %>% as.numeric if(!is.null(feature_names)){ featureBranch <- feature_names[featureBranch + 1] @@ -100,20 +103,48 @@ xgb.model.dt.tree <- function(feature_names = NULL, filename_dump = NULL, text = featureLeaf <- rep("Leaf", length(leaf)) splitBranch <- str_extract(branch, "<\\d*\\.*\\d*\\]") %>% str_replace("<", "") %>% str_replace("\\]", "") splitLeaf <- rep(NA, length(leaf)) - yesBranch <- extract(branch, "yes=\\d*") %>% addTreeId(i) + yesBranch <- extract(branch, "yes=\\d*") %>% addTreeId(treeID) yesLeaf <- rep(NA, length(leaf)) - noBranch <- extract(branch, "no=\\d*") %>% addTreeId(i) + noBranch <- extract(branch, "no=\\d*") %>% addTreeId(treeID) noLeaf <- rep(NA, length(leaf)) - missingBranch <- extract(branch, "missing=\\d+") %>% addTreeId(i) + missingBranch <- extract(branch, "missing=\\d+") %>% addTreeId(treeID) missingLeaf <- rep(NA, length(leaf)) qualityBranch <- extract(branch, "gain=\\d*\\.*\\d*") qualityLeaf <- extract(leaf, "leaf=\\-*\\d*\\.*\\d*") coverBranch <- extract(branch, "cover=\\d*\\.*\\d*") coverLeaf <- extract(leaf, "cover=\\d*\\.*\\d*") - dt <- data.table(ID = c(idBranch, idLeaf), Feature = c(featureBranch, featureLeaf), Split = c(splitBranch, splitLeaf), Yes = c(yesBranch, yesLeaf), No = c(noBranch, noLeaf), Missing = c(missingBranch, missingLeaf), Quality = c(qualityBranch, qualityLeaf), Cover = c(coverBranch, coverLeaf))[order(ID)][,Tree:=i] + dt <- data.table(ID = c(idBranch, idLeaf), Feature = c(featureBranch, featureLeaf), Split = c(splitBranch, splitLeaf), Yes = c(yesBranch, yesLeaf), No = c(noBranch, noLeaf), Missing = c(missingBranch, missingLeaf), Quality = c(qualityBranch, qualityLeaf), Cover = c(coverBranch, coverLeaf))[order(ID)][,Tree:=treeID] allTrees <- rbindlist(list(allTrees, dt), use.names = T, fill = F) } + yes <- allTrees[!is.na(Yes),Yes] + + set(allTrees, i = which(allTrees[,Feature]!= "Leaf"), + j = "Yes.Feature", + value = allTrees[ID == yes,Feature]) + + set(allTrees, i = which(allTrees[,Feature]!= "Leaf"), + j = "Yes.Cover", + value = allTrees[ID == yes,Cover]) + + set(allTrees, i = which(allTrees[,Feature]!= "Leaf"), + j = "Yes.Quality", + value = allTrees[ID == yes,Quality]) + + no <- allTrees[!is.na(No),No] + + set(allTrees, i = which(allTrees[,Feature]!= "Leaf"), + j = "No.Feature", + value = allTrees[ID == no,Feature]) + + set(allTrees, i = which(allTrees[,Feature]!= "Leaf"), + j = "No.Cover", + value = allTrees[ID == no,Cover]) + + set(allTrees, i = which(allTrees[,Feature]!= "Leaf"), + j = "No.Quality", + value = allTrees[ID == no,Quality]) + allTrees } diff --git a/R-package/R/xgb.plot.tree.R b/R-package/R/xgb.plot.tree.R index b980671b0..1a8a04e8a 100644 --- a/R-package/R/xgb.plot.tree.R +++ b/R-package/R/xgb.plot.tree.R @@ -59,13 +59,9 @@ xgb.plot.tree <- function(feature_names = NULL, filename_dump = NULL, n_first_tr allTrees <- xgb.model.dt.tree(feature_names, filename_dump, n_first_tree) - set(allTrees, i = which(allTrees[,Feature]!= "Leaf"), j = "YesFeature", value = merge(copy(allTrees)[,ID:=Yes][, .(ID)], allTrees[,.(ID, Feature, Quality, Cover)], by = "ID")[,paste(Feature, "
Cover: ", Cover, sep = "")]) + allTrees[Feature!="Leaf" ,yesPath:= paste(ID,"(", Feature, "
Cover: ", Cover, "
Gain: ", Quality, ")-->|< ", Split, "|", Yes, ">", Yes.Feature, "]", sep = "")] - set(allTrees, i = which(allTrees[,Feature]!= "Leaf"), j = "NoFeature", value = merge(copy(allTrees)[,ID:=No][, .(ID)], allTrees[,.(ID, Feature, Quality, Cover)], by = "ID")[,paste(Feature, "
Cover: ", Cover, sep = "")]) - - allTrees[Feature!="Leaf" ,yesPath:= paste(ID,"(", Feature, "
Cover: ", Cover, "
Gain: ", Quality, ")-->|< ", Split, "|", Yes, ">", YesFeature, "]", sep = "")] - - allTrees[Feature!="Leaf" ,noPath:= paste(ID,"(", Feature, ")-->|>= ", Split, "|", No, ">", NoFeature, "]", sep = "")] + allTrees[Feature!="Leaf" ,noPath:= paste(ID,"(", Feature, ")-->|>= ", Split, "|", No, ">", No.Feature, "]", sep = "")] if(is.null(styles)){