From 7cb34e3ad678200d8b2dc47b702d70601b41c6f6 Mon Sep 17 00:00:00 2001 From: unknown Date: Sat, 7 Nov 2015 22:24:37 +0100 Subject: [PATCH] Fix some bug + improve display + code clean --- R-package/R/xgb.plot.tree.R | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/R-package/R/xgb.plot.tree.R b/R-package/R/xgb.plot.tree.R index 10ca42bc7..63bebf6cf 100644 --- a/R-package/R/xgb.plot.tree.R +++ b/R-package/R/xgb.plot.tree.R @@ -4,16 +4,13 @@ #' Plotting only works for boosted tree model (not linear model). #' #' @importFrom data.table data.table -#' @importFrom data.table set -#' @importFrom data.table rbindlist #' @importFrom data.table := -#' @importFrom data.table copy #' @importFrom magrittr %>% #' @param feature_names names of each feature as a character vector. Can be extracted from a sparse matrix (see example). If model dump already contains feature names, this argument should be \code{NULL}. #' @param filename_dump the path to the text file storing the model. Model dump must include the gain per feature and per tree (parameter \code{with.stats = T} in function \code{xgb.dump}). Possible to provide a model directly (see \code{model} argument). #' @param model generated by the \code{xgb.train} function. Avoid the creation of a dump file. #' @param n_first_tree limit the plot to the n first trees. If \code{NULL}, all trees of the model are plotted. Performance can be low for huge models. -#' @param width the width of the diagram in pixels. +#' @param width the width of the diagram in pixels. #' @param height the height of the diagram in pixels. #' #' @return A \code{DiagrammeR} of the model. @@ -62,22 +59,18 @@ xgb.plot.tree <- function(feature_names = NULL, filename_dump = NULL, model = NU allTrees <- xgb.model.dt.tree(feature_names = feature_names, model = model, n_first_tree = n_first_tree) } - allTrees[Feature != "Leaf" ,yesPath := paste(ID,"(", Feature, "
Cover: ", Cover, "
Gain: ", Quality, ")-->|< ", Split, "|", Yes, ">", Yes.Feature, "]", sep = "")] - - allTrees[Feature != "Leaf" ,noPath := paste(ID,"(", Feature, ")-->|>= ", Split, "|", No, ">", No.Feature, "]", sep = "")] - allTrees[, label:= paste0(Feature, "\nCover: ", Cover, "\nGain: ", Quality)] allTrees[, shape:= "rectangle"][Feature == "Leaf", shape:= "oval"] allTrees[, filledcolor:= "Beige"][Feature == "Leaf", filledcolor:= "Khaki"] - nodes <- DiagrammeR::create_nodes(nodes = allTrees[,ID], - label = allTrees[,label], - #type = c("lower", "lower", "upper", "upper"), + # rev is used to put the first tree on top. + nodes <- DiagrammeR::create_nodes(nodes = allTrees[,ID] %>% rev, + label = allTrees[,label] %>% rev, style = "filled", color = "DimGray", - fillcolor= allTrees[,filledcolor], - shape = allTrees[,shape], - data = allTrees[,Feature], + fillcolor= allTrees[,filledcolor] %>% rev, + shape = allTrees[,shape] %>% rev, + data = allTrees[,Feature] %>% rev, fontname = "Helvetica" ) @@ -100,4 +93,4 @@ xgb.plot.tree <- function(feature_names = NULL, filename_dump = NULL, model = NU # Avoid error messages during CRAN check. # The reason is that these variables are never declared # They are mainly column names inferred by Data.table... -globalVariables(c("Feature", "ID", "Cover", "Quality", "Split", "Yes", "No", ".", "shape", "filledcolor")) +globalVariables(c("Feature", "ID", "Cover", "Quality", "Split", "Yes", "No", ".", "shape", "filledcolor", "label"))