[R] parameter style consistency

This commit is contained in:
Vadim Khotilovich 2016-06-27 01:58:03 -05:00
parent 56bd442b31
commit e9eb34fabc
8 changed files with 63 additions and 94 deletions

View File

@ -2,12 +2,9 @@
#' #'
#' May improve the learning by adding new features to the training data based on the decision trees from a previously learned model. #' May improve the learning by adding new features to the training data based on the decision trees from a previously learned model.
#' #'
#' @importFrom magrittr %>%
#' @importFrom Matrix cBind
#' @importFrom Matrix sparse.model.matrix
#'
#' @param model decision tree boosting model learned on the original data #' @param model decision tree boosting model learned on the original data
#' @param training.data original data (usually provided as a \code{dgCMatrix} matrix) #' @param data original data (usually provided as a \code{dgCMatrix} matrix)
#' @param ... currently not used
#' #'
#' @return \code{dgCMatrix} matrix including both the original data and the new features. #' @return \code{dgCMatrix} matrix including both the original data and the new features.
#' #'
@ -54,7 +51,7 @@
#' dtrain <- xgb.DMatrix(data = agaricus.train$data, label = agaricus.train$label) #' dtrain <- xgb.DMatrix(data = agaricus.train$data, label = agaricus.train$label)
#' dtest <- xgb.DMatrix(data = agaricus.test$data, label = agaricus.test$label) #' dtest <- xgb.DMatrix(data = agaricus.test$data, label = agaricus.test$label)
#' #'
#' param <- list(max.depth=2, eta=1, silent=1, objective='binary:logistic') #' param <- list(max_depth=2, eta=1, silent=1, objective='binary:logistic')
#' nround = 4 #' nround = 4
#' #'
#' bst = xgb.train(params = param, data = dtrain, nrounds = nround, nthread = 2) #' bst = xgb.train(params = param, data = dtrain, nrounds = nround, nthread = 2)
@ -79,13 +76,14 @@
#' cat(paste("The accuracy was", accuracy.before, "before adding leaf features and it is now", accuracy.after, "!\n")) #' cat(paste("The accuracy was", accuracy.before, "before adding leaf features and it is now", accuracy.after, "!\n"))
#' #'
#' @export #' @export
xgb.create.features <- function(model, training.data){ xgb.create.features <- function(model, data, ...){
pred_with_leaf = predict(model, training.data, predleaf = TRUE) check.deprecation(...)
pred_with_leaf = predict(model, data, predleaf = TRUE)
cols <- list() cols <- list()
for(i in 1:length(trees)){ for(i in 1:length(trees)){
# max is not the real max but it s not important for the purpose of adding features # max is not the real max but it s not important for the purpose of adding features
leaf.id <- sort(unique(pred_with_leaf[,i])) leaf_id <- sort(unique(pred_with_leaf[,i]))
cols[[i]] <- factor(x = pred_with_leaf[,i], level = leaf.id) cols[[i]] <- factor(x = pred_with_leaf[,i], level = leaf_id)
} }
cBind(training.data, sparse.model.matrix( ~ . -1, as.data.frame(cols))) cBind(data, sparse.model.matrix( ~ . -1, as.data.frame(cols)))
} }

View File

@ -2,14 +2,6 @@
#' #'
#' Create a \code{data.table} of the most important features of a model. #' Create a \code{data.table} of the most important features of a model.
#' #'
#' @importFrom data.table data.table
#' @importFrom data.table setnames
#' @importFrom data.table :=
#' @importFrom magrittr %>%
#' @importFrom Matrix colSums
#' @importFrom Matrix cBind
#' @importFrom Matrix sparseVector
#'
#' @param feature_names names of each feature as a \code{character} vector. Can be extracted from a sparse matrix (see example). If model dump already contains feature names, this argument should be \code{NULL}. #' @param feature_names names of each feature as a \code{character} vector. Can be extracted from a sparse matrix (see example). If model dump already contains feature names, this argument should be \code{NULL}.
#' @param model generated by the \code{xgb.train} function. #' @param model generated by the \code{xgb.train} function.
#' @param data the dataset used for the training step. Will be used with \code{label} parameter for co-occurence computation. More information in \code{Detail} part. This parameter is optional. #' @param data the dataset used for the training step. Will be used with \code{label} parameter for co-occurence computation. More information in \code{Detail} part. This parameter is optional.
@ -46,14 +38,13 @@
#' @examples #' @examples
#' data(agaricus.train, package='xgboost') #' data(agaricus.train, package='xgboost')
#' #'
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.depth = 2, #' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 2,
#' eta = 1, nthread = 2, nround = 2,objective = "binary:logistic") #' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
#' #'
#' # agaricus.train$data@@Dimnames[[2]] represents the column names of the sparse matrix. #' xgb.importance(colnames(agaricus.train$data), model = bst)
#' xgb.importance(agaricus.train$data@@Dimnames[[2]], model = bst)
#' #'
#' # Same thing with co-occurence computation this time #' # Same thing with co-occurence computation this time
#' xgb.importance(agaricus.train$data@@Dimnames[[2]], model = bst, data = agaricus.train$data, label = agaricus.train$label) #' xgb.importance(colnames(agaricus.train$data), model = bst, data = agaricus.train$data, label = agaricus.train$label)
#' #'
#' @export #' @export
xgb.importance <- function(feature_names = NULL, model = NULL, data = NULL, label = NULL, target = function(x) ( (x + label) == 2)){ xgb.importance <- function(feature_names = NULL, model = NULL, data = NULL, label = NULL, target = function(x) ( (x + label) == 2)){
@ -84,7 +75,7 @@ xgb.importance <- function(feature_names = NULL, model = NULL, data = NULL, labe
data.table(Feature = feature_names, Weight = weights) data.table(Feature = feature_names, Weight = weights)
} }
model.text.dump <- xgb.dump(model = model, with.stats = T) model.text.dump <- xgb.dump(model = model, with_stats = T)
if(model.text.dump[2] == "bias:"){ if(model.text.dump[2] == "bias:"){
result <- model.text.dump %>% linearDump(feature_names, .) result <- model.text.dump %>% linearDump(feature_names, .)

View File

@ -9,8 +9,8 @@
#' data(agaricus.test, package='xgboost') #' data(agaricus.test, package='xgboost')
#' train <- agaricus.train #' train <- agaricus.train
#' test <- agaricus.test #' test <- agaricus.test
#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2, #' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
#' eta = 1, nthread = 2, nround = 2,objective = "binary:logistic") #' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
#' xgb.save(bst, 'xgb.model') #' xgb.save(bst, 'xgb.model')
#' bst <- xgb.load('xgb.model') #' bst <- xgb.load('xgb.model')
#' pred <- predict(bst, test$data) #' pred <- predict(bst, test$data)
@ -26,6 +26,6 @@ xgb.load <- function(modelfile) {
} else { } else {
bst <- xgb.handleToBooster(handle, NULL) bst <- xgb.handleToBooster(handle, NULL)
} }
bst <- xgb.Booster.check(bst) bst <- xgb.Booster.check(bst, saveraw = TRUE)
return(bst) return(bst)
} }

View File

@ -2,16 +2,11 @@
#' #'
#' Parse a boosted tree model text dump into a \code{data.table} structure. #' Parse a boosted tree model text dump into a \code{data.table} structure.
#' #'
#' @importFrom data.table data.table
#' @importFrom data.table :=
#' @importFrom magrittr %>%
#' @importFrom stringr str_match
#'
#' @param feature_names character vector of feature names. If the model already #' @param feature_names character vector of feature names. If the model already
#' contains feature names, this argument should be \code{NULL} (default value) #' contains feature names, this argument should be \code{NULL} (default value)
#' @param model object of class \code{xgb.Booster} #' @param model object of class \code{xgb.Booster}
#' @param text \code{character} vector previously generated by the \code{xgb.dump} #' @param text \code{character} vector previously generated by the \code{xgb.dump}
#' function (where parameter \code{with.stats = TRUE} should have been set). #' function (where parameter \code{with_stats = TRUE} should have been set).
#' @param n_first_tree limit the parsing to the \code{n} first trees. #' @param n_first_tree limit the parsing to the \code{n} first trees.
#' If set to \code{NULL}, all trees of the model are parsed. #' If set to \code{NULL}, all trees of the model are parsed.
#' #'
@ -40,8 +35,8 @@
#' #'
#' data(agaricus.train, package='xgboost') #' data(agaricus.train, package='xgboost')
#' #'
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.depth = 2, #' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 2,
#' eta = 1, nthread = 2, nround = 2,objective = "binary:logistic") #' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
#' #'
#' (dt <- xgb.model.dt.tree(colnames(agaricus.train$data), bst)) #' (dt <- xgb.model.dt.tree(colnames(agaricus.train$data), bst))
#' #'
@ -71,12 +66,12 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
} }
if(is.null(text)){ if(is.null(text)){
text <- xgb.dump(model = model, with.stats = T) text <- xgb.dump(model = model, with_stats = T)
} }
position <- which(!is.na(str_match(text, "booster"))) position <- which(!is.na(str_match(text, "booster")))
addTreeId <- function(x, i) paste(i,x,sep = "-") add.tree.id <- function(x, i) paste(i, x, sep = "-")
anynumber_regex <- "[-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?" anynumber_regex <- "[-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?"
@ -88,7 +83,7 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
td <- td[Tree <= n_first_tree & !grepl('^booster', t)] td <- td[Tree <= n_first_tree & !grepl('^booster', t)]
td[, Node := str_match(t, "(\\d+):")[,2] %>% as.numeric ] td[, Node := str_match(t, "(\\d+):")[,2] %>% as.numeric ]
td[, ID := addTreeId(Node, Tree)] td[, ID := add.tree.id(Node, Tree)]
td[, isLeaf := !is.na(str_match(t, "leaf"))] td[, isLeaf := !is.na(str_match(t, "leaf"))]
# parse branch lines # parse branch lines
@ -97,7 +92,7 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
"gain=(", anynumber_regex, "),cover=(", anynumber_regex, ")") "gain=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
# skip some indices with spurious capture groups from anynumber_regex # skip some indices with spurious capture groups from anynumber_regex
xtr <- str_match(t, rx)[, c(2,3,5,6,7,8,10)] xtr <- str_match(t, rx)[, c(2,3,5,6,7,8,10)]
xtr[, 3:5] <- addTreeId(xtr[, 3:5], Tree) xtr[, 3:5] <- add.tree.id(xtr[, 3:5], Tree)
lapply(1:ncol(xtr), function(i) xtr[,i]) lapply(1:ncol(xtr), function(i) xtr[,i])
}] }]
# assign feature_names when available # assign feature_names when available

View File

@ -2,7 +2,6 @@
#' #'
#' Plot multiple graph aligned by rows and columns. #' Plot multiple graph aligned by rows and columns.
#' #'
#' @importFrom data.table data.table
#' @param cols number of columns #' @param cols number of columns
#' @return NULL #' @return NULL
multiplot <- function(..., cols = 1) { multiplot <- function(..., cols = 1) {
@ -42,18 +41,18 @@ edge.parser <- function(element) {
#' Extract path from root to leaf from data.table #' Extract path from root to leaf from data.table
#' @param dt.tree data.table containing the nodes and edges of the trees #' @param dt.tree data.table containing the nodes and edges of the trees
get.paths.to.leaf <- function(dt.tree) { get.paths.to.leaf <- function(dt_tree) {
dt.not.leaf.edges <- dt.not.leaf.edges <-
dt.tree[Feature != "Leaf",.(ID, Yes, Tree)] %>% list(dt.tree[Feature != "Leaf",.(ID, No, Tree)]) %>% rbindlist(use.names = F) dt_tree[Feature != "Leaf",.(ID, Yes, Tree)] %>% list(dt_tree[Feature != "Leaf",.(ID, No, Tree)]) %>% rbindlist(use.names = F)
trees <- dt.tree[,unique(Tree)] trees <- dt_tree[,unique(Tree)]
paths <- list() paths <- list()
for (tree in trees) { for (tree in trees) {
graph <- graph <-
igraph::graph_from_data_frame(dt.not.leaf.edges[Tree == tree]) igraph::graph_from_data_frame(dt.not.leaf.edges[Tree == tree])
paths.tmp <- paths.tmp <-
igraph::shortest_paths(graph, from = paste0(tree, "-0"), to = dt.tree[Tree == tree & igraph::shortest_paths(graph, from = paste0(tree, "-0"), to = dt_tree[Tree == tree &
Feature == "Leaf", c(ID)]) Feature == "Leaf", c(ID)])
paths <- c(paths, paths.tmp$vpath) paths <- c(paths, paths.tmp$vpath)
} }
@ -64,11 +63,6 @@ get.paths.to.leaf <- function(dt.tree) {
#' #'
#' Generate a graph to plot the distribution of deepness among trees. #' Generate a graph to plot the distribution of deepness among trees.
#' #'
#' @importFrom data.table data.table
#' @importFrom data.table rbindlist
#' @importFrom data.table setnames
#' @importFrom data.table :=
#' @importFrom magrittr %>%
#' @param model dump generated by the \code{xgb.train} function. #' @param model dump generated by the \code{xgb.train} function.
#' #'
#' @return Two graphs showing the distribution of the model deepness. #' @return Two graphs showing the distribution of the model deepness.
@ -78,7 +72,7 @@ get.paths.to.leaf <- function(dt.tree) {
#' by tree deepness level. #' by tree deepness level.
#' #'
#' The purpose of this function is to help the user to find the best trade-off to set #' The purpose of this function is to help the user to find the best trade-off to set
#' the \code{max.depth} and \code{min_child_weight} parameters according to the bias / variance trade-off. #' the \code{max_depth} and \code{min_child_weight} parameters according to the bias / variance trade-off.
#' #'
#' See \link{xgb.train} for more information about these parameters. #' See \link{xgb.train} for more information about these parameters.
#' #'
@ -94,8 +88,8 @@ get.paths.to.leaf <- function(dt.tree) {
#' @examples #' @examples
#' data(agaricus.train, package='xgboost') #' data(agaricus.train, package='xgboost')
#' #'
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.depth = 15, #' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 15,
#' eta = 1, nthread = 2, nround = 30, objective = "binary:logistic", #' eta = 1, nthread = 2, nrounds = 30, objective = "binary:logistic",
#' min_child_weight = 50) #' min_child_weight = 50)
#' #'
#' xgb.plot.deepness(model = bst) #' xgb.plot.deepness(model = bst)

View File

@ -2,9 +2,9 @@
#' #'
#' Read a data.table containing feature importance details and plot it (for both GLM and Trees). #' Read a data.table containing feature importance details and plot it (for both GLM and Trees).
#' #'
#' @importFrom magrittr %>%
#' @param importance_matrix a \code{data.table} returned by the \code{xgb.importance} function. #' @param importance_matrix a \code{data.table} returned by the \code{xgb.importance} function.
#' @param numberOfClusters a \code{numeric} vector containing the min and the max range of the possible number of clusters of bars. #' @param n_clusters a \code{numeric} vector containing the min and the max range of the possible number of clusters of bars.
#' @param ... currently not used
#' #'
#' @return A \code{ggplot2} bar graph representing each feature by a horizontal bar. Longer is the bar, more important is the feature. Features are classified by importance and clustered by importance. The group is represented through the color of the bar. #' @return A \code{ggplot2} bar graph representing each feature by a horizontal bar. Longer is the bar, more important is the feature. Features are classified by importance and clustered by importance. The group is represented through the color of the bar.
#' #'
@ -20,16 +20,16 @@
#' #(labels = outcome column which will be learned). #' #(labels = outcome column which will be learned).
#' #Each column of the sparse Matrix is a feature in one hot encoding format. #' #Each column of the sparse Matrix is a feature in one hot encoding format.
#' #'
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.depth = 2, #' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 2,
#' eta = 1, nthread = 2, nround = 2,objective = "binary:logistic") #' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
#' #'
#' #agaricus.train$data@@Dimnames[[2]] represents the column names of the sparse matrix. #' importance_matrix <- xgb.importance(colnames(agaricus.train$data), model = bst)
#' importance_matrix <- xgb.importance(agaricus.train$data@@Dimnames[[2]], model = bst)
#' xgb.plot.importance(importance_matrix) #' xgb.plot.importance(importance_matrix)
#' #'
#' @export #' @export
xgb.plot.importance <- xgb.plot.importance <-
function(importance_matrix = NULL, numberOfClusters = c(1:10)) { function(importance_matrix = NULL, n_clusters = c(1:10), ...) {
check.deprecation(...)
if (!"data.table" %in% class(importance_matrix)) { if (!"data.table" %in% class(importance_matrix)) {
stop("importance_matrix: Should be a data.table.") stop("importance_matrix: Should be a data.table.")
} }
@ -53,7 +53,7 @@ xgb.plot.importance <-
importance_matrix[, .(Gain.or.Weight = sum(get(y.axe.name))), by = Feature] importance_matrix[, .(Gain.or.Weight = sum(get(y.axe.name))), by = Feature]
clusters <- clusters <-
suppressWarnings(Ckmeans.1d.dp::Ckmeans.1d.dp(importance_matrix[,Gain.or.Weight], numberOfClusters)) suppressWarnings(Ckmeans.1d.dp::Ckmeans.1d.dp(importance_matrix[,Gain.or.Weight], n_clusters))
importance_matrix[,"Cluster":= clusters$cluster %>% as.character] importance_matrix[,"Cluster":= clusters$cluster %>% as.character]
plot <- plot <-

View File

@ -2,19 +2,12 @@
#' #'
#' Visualization of the ensemble of trees as a single collective unit. #' Visualization of the ensemble of trees as a single collective unit.
#' #'
#' @importFrom data.table data.table
#' @importFrom data.table rbindlist
#' @importFrom data.table setnames
#' @importFrom data.table :=
#' @importFrom magrittr %>%
#' @importFrom stringr str_detect
#' @importFrom stringr str_extract
#'
#' @param model dump generated by the \code{xgb.train} function. #' @param model dump generated by the \code{xgb.train} function.
#' @param feature_names names of each feature as a \code{character} vector. Can be extracted from a sparse matrix (see example). If model dump already contains feature names, this argument should be \code{NULL}. #' @param feature_names names of each feature as a \code{character} vector. Can be extracted from a sparse matrix (see example). If model dump already contains feature names, this argument should be \code{NULL}.
#' @param features.keep number of features to keep in each position of the multi trees. #' @param features_keep number of features to keep in each position of the multi trees.
#' @param plot.width width in pixels of the graph to produce #' @param plot_width width in pixels of the graph to produce
#' @param plot.height height in pixels of the graph to produce #' @param plot_height height in pixels of the graph to produce
#' @param ... currently not used
#' #'
#' @return Two graphs showing the distribution of the model deepness. #' @return Two graphs showing the distribution of the model deepness.
#' #'
@ -34,7 +27,7 @@
#' Moreover, the trees tend to reuse the same features. #' Moreover, the trees tend to reuse the same features.
#' #'
#' The function will project each tree on one, and keep for each position the #' The function will project each tree on one, and keep for each position the
#' \code{features.keep} first features (based on Gain per feature measure). #' \code{features_keep} first features (based on Gain per feature measure).
#' #'
#' This function is inspired by this blog post: #' This function is inspired by this blog post:
#' \url{https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/} #' \url{https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/}
@ -42,15 +35,16 @@
#' @examples #' @examples
#' data(agaricus.train, package='xgboost') #' data(agaricus.train, package='xgboost')
#' #'
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.depth = 15, #' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 15,
#' eta = 1, nthread = 2, nround = 30, objective = "binary:logistic", #' eta = 1, nthread = 2, nrounds = 30, objective = "binary:logistic",
#' min_child_weight = 50) #' min_child_weight = 50)
#' #'
#' p <- xgb.plot.multi.trees(model = bst, feature_names = agaricus.train$data@Dimnames[[2]], features.keep = 3) #' p <- xgb.plot.multi.trees(model = bst, feature_names = colnames(agaricus.train$data), features_keep = 3)
#' print(p) #' print(p)
#' #'
#' @export #' @export
xgb.plot.multi.trees <- function(model, feature_names = NULL, features.keep = 5, plot.width = NULL, plot.height = NULL){ xgb.plot.multi.trees <- function(model, feature_names = NULL, features_keep = 5, plot_width = NULL, plot_height = NULL, ...){
check.deprecation(...)
tree.matrix <- xgb.model.dt.tree(feature_names = feature_names, model = model) tree.matrix <- xgb.model.dt.tree(feature_names = feature_names, model = model)
# first number of the path represents the tree, then the following numbers are related to the path to follow # first number of the path represents the tree, then the following numbers are related to the path to follow
@ -80,7 +74,7 @@ xgb.plot.multi.trees <- function(model, feature_names = NULL, features.keep = 5,
tree.matrix[,`:=`(abs.node.position=remove.tree(abs.node.position), Yes=remove.tree(Yes), No=remove.tree(No))] tree.matrix[,`:=`(abs.node.position=remove.tree(abs.node.position), Yes=remove.tree(Yes), No=remove.tree(No))]
nodes.dt <- tree.matrix[,.(Quality = sum(Quality)),by = .(abs.node.position, Feature)][,.(Text =paste0(Feature[1:min(length(Feature), features.keep)], " (", Quality[1:min(length(Quality), features.keep)], ")") %>% paste0(collapse = "\n")), by=abs.node.position] nodes.dt <- tree.matrix[,.(Quality = sum(Quality)),by = .(abs.node.position, Feature)][,.(Text =paste0(Feature[1:min(length(Feature), features_keep)], " (", Quality[1:min(length(Quality), features_keep)], ")") %>% paste0(collapse = "\n")), by=abs.node.position]
edges.dt <- tree.matrix[Feature != "Leaf",.(abs.node.position, Yes)] %>% list(tree.matrix[Feature != "Leaf",.(abs.node.position, No)]) %>% rbindlist() %>% setnames(c("From", "To")) %>% .[,.N,.(From, To)] %>% .[,N:=NULL] edges.dt <- tree.matrix[Feature != "Leaf",.(abs.node.position, Yes)] %>% list(tree.matrix[Feature != "Leaf",.(abs.node.position, No)]) %>% rbindlist() %>% setnames(c("From", "To")) %>% .[,.N,.(From, To)] %>% .[,N:=NULL]
nodes <- DiagrammeR::create_nodes(nodes = nodes.dt[,abs.node.position], nodes <- DiagrammeR::create_nodes(nodes = nodes.dt[,abs.node.position],
@ -104,7 +98,7 @@ xgb.plot.multi.trees <- function(model, feature_names = NULL, features.keep = 5,
edges_df = edges, edges_df = edges,
graph_attrs = "rankdir = LR") graph_attrs = "rankdir = LR")
DiagrammeR::render_graph(graph, width = plot.width, height = plot.height) DiagrammeR::render_graph(graph, width = plot_width, height = plot_height)
} }
globalVariables( globalVariables(

View File

@ -2,14 +2,12 @@
#' #'
#' Read a tree model text dump and plot the model. #' Read a tree model text dump and plot the model.
#' #'
#' @importFrom data.table data.table
#' @importFrom data.table :=
#' @importFrom magrittr %>%
#' @param feature_names names of each feature as a \code{character} vector. Can be extracted from a sparse matrix (see example). If model dump already contains feature names, this argument should be \code{NULL}. #' @param feature_names names of each feature as a \code{character} vector. Can be extracted from a sparse matrix (see example). If model dump already contains feature names, this argument should be \code{NULL}.
#' @param model generated by the \code{xgb.train} function. Avoid the creation of a dump file. #' @param model generated by the \code{xgb.train} function. Avoid the creation of a dump file.
#' @param n_first_tree limit the plot to the n first trees. If \code{NULL}, all trees of the model are plotted. Performance can be low for huge models. #' @param n_first_tree limit the plot to the n first trees. If \code{NULL}, all trees of the model are plotted. Performance can be low for huge models.
#' @param plot.width the width of the diagram in pixels. #' @param plot_width the width of the diagram in pixels.
#' @param plot.height the height of the diagram in pixels. #' @param plot_height the height of the diagram in pixels.
#' @param ... currently not used.
#' #'
#' @return A \code{DiagrammeR} of the model. #' @return A \code{DiagrammeR} of the model.
#' #'
@ -28,15 +26,14 @@
#' @examples #' @examples
#' data(agaricus.train, package='xgboost') #' data(agaricus.train, package='xgboost')
#' #'
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.depth = 2, #' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 2,
#' eta = 1, nthread = 2, nround = 2,objective = "binary:logistic") #' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
#' #'
#' # agaricus.train$data@@Dimnames[[2]] represents the column names of the sparse matrix. #' xgb.plot.tree(feature_names = colnames(agaricus.train$data), model = bst)
#' xgb.plot.tree(feature_names = agaricus.train$data@@Dimnames[[2]], model = bst)
#' #'
#' @export #' @export
xgb.plot.tree <- function(feature_names = NULL, model = NULL, n_first_tree = NULL, plot.width = NULL, plot.height = NULL){ xgb.plot.tree <- function(feature_names = NULL, model = NULL, n_first_tree = NULL, plot_width = NULL, plot_height = NULL, ...){
check.deprecation(...)
if (class(model) != "xgb.Booster") { if (class(model) != "xgb.Booster") {
stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.") stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.")
} }
@ -75,7 +72,7 @@ xgb.plot.tree <- function(feature_names = NULL, model = NULL, n_first_tree = NUL
edges_df = edges, edges_df = edges,
graph_attrs = "rankdir = LR") graph_attrs = "rankdir = LR")
DiagrammeR::render_graph(graph, width = plot.width, height = plot.height) DiagrammeR::render_graph(graph, width = plot_width, height = plot_height)
} }
# Avoid error messages during CRAN check. # Avoid error messages during CRAN check.