[R] maintenance Nov 2017; SHAP plots (#2888)
* [R] fix predict contributions for data with no colnames * [R] add a render parameter for xgb.plot.multi.trees; fixes #2628 * [R] update Rd's * [R] remove unnecessary dep-package from R cmake install * silence type warnings; readability * [R] silence complaint about incomplete line at the end * [R] initial version of xgb.plot.shap() * [R] more work on xgb.plot.shap * [R] enforce black font in xgb.plot.tree; fixes #2640 * [R] if feature names are available, check in predict that they are the same; fixes #2857 * [R] cran check and lint fixes * remove tabs * [R] add references; a test for plot.shap
This commit is contained in:
committed by
Tong He
parent
1b77903eeb
commit
e8a6597957
@@ -150,7 +150,7 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
|
||||
#' Setting \code{predcontrib = TRUE} allows to calculate contributions of each feature to
|
||||
#' individual predictions. For "gblinear" booster, feature contributions are simply linear terms
|
||||
#' (feature_beta * feature_value). For "gbtree" booster, feature contributions are SHAP
|
||||
#' values (https://arxiv.org/abs/1706.06060) that sum to the difference between the expected output
|
||||
#' values (Lundberg 2017) that sum to the difference between the expected output
|
||||
#' of the model and the current prediction (where the hessian weights are used to compute the expectations).
|
||||
#' Setting \code{approxcontrib = TRUE} approximates these values following the idea explained
|
||||
#' in \url{http://blog.datadive.net/interpreting-random-forests/}.
|
||||
@@ -172,6 +172,12 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
|
||||
#'
|
||||
#' @seealso
|
||||
#' \code{\link{xgb.train}}.
|
||||
#'
|
||||
#' @references
|
||||
#'
|
||||
#' Scott M. Lundberg, Su-In Lee, "A Unified Approach to Interpreting Model Predictions", NIPS Proceedings 2017, \url{https://arxiv.org/abs/1705.07874}
|
||||
#'
|
||||
#' Scott M. Lundberg, Su-In Lee, "Consistent feature attribution for tree ensembles", \url{https://arxiv.org/abs/1706.06060}
|
||||
#'
|
||||
#' @examples
|
||||
#' ## binary classification:
|
||||
@@ -265,6 +271,10 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
||||
object <- xgb.Booster.complete(object, saveraw = FALSE)
|
||||
if (!inherits(newdata, "xgb.DMatrix"))
|
||||
newdata <- xgb.DMatrix(newdata, missing = missing)
|
||||
if (!is.null(object[["feature_names"]]) &&
|
||||
!is.null(colnames(newdata)) &&
|
||||
!identical(object[["feature_names"]], colnames(newdata)))
|
||||
stop("Feature names stored in `object` and `newdata` are different!")
|
||||
if (is.null(ntreelimit))
|
||||
ntreelimit <- NVL(object$best_ntreelimit, 0)
|
||||
if (NVL(object$params[['booster']], '') == 'gblinear')
|
||||
@@ -292,7 +302,7 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
||||
} else if (predcontrib) {
|
||||
n_col1 <- ncol(newdata) + 1
|
||||
n_group <- npred_per_case / n_col1
|
||||
dnames <- list(NULL, c(colnames(newdata), "BIAS"))
|
||||
dnames <- if (!is.null(colnames(newdata))) list(NULL, c(colnames(newdata), "BIAS")) else NULL
|
||||
ret <- if (n_ret == n_row) {
|
||||
matrix(ret, ncol = 1, dimnames = dnames)
|
||||
} else if (n_group == 1) {
|
||||
|
||||
@@ -136,4 +136,4 @@ xgb.importance <- function(feature_names = NULL, model = NULL, trees = NULL,
|
||||
# Avoid error messages during CRAN check.
|
||||
# The reason is that these variables are never declared
|
||||
# They are mainly column names inferred by Data.table...
|
||||
globalVariables(c(".", ".N", "Gain", "Cover", "Frequency", "Feature"))
|
||||
globalVariables(c(".", ".N", "Gain", "Cover", "Frequency", "Feature", "Class"))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#' Project all trees on one tree and plot it
|
||||
#'
|
||||
#'
|
||||
#' Visualization of the ensemble of trees as a single collective unit.
|
||||
#'
|
||||
#' @param model produced by the \code{xgb.train} function.
|
||||
@@ -7,52 +7,71 @@
|
||||
#' @param features_keep number of features to keep in each position of the multi trees.
|
||||
#' @param plot_width width in pixels of the graph to produce
|
||||
#' @param plot_height height in pixels of the graph to produce
|
||||
#' @param render a logical flag for whether the graph should be rendered (see Value).
|
||||
#' @param ... currently not used
|
||||
#'
|
||||
#' @return Two graphs showing the distribution of the model deepness.
|
||||
#'
|
||||
#' @details
|
||||
#'
|
||||
#' This function tries to capture the complexity of a gradient boosted tree model
|
||||
#'
|
||||
#' This function tries to capture the complexity of a gradient boosted tree model
|
||||
#' in a cohesive way by compressing an ensemble of trees into a single tree-graph representation.
|
||||
#' The goal is to improve the interpretability of a model generally seen as black box.
|
||||
#'
|
||||
#'
|
||||
#' Note: this function is applicable to tree booster-based models only.
|
||||
#'
|
||||
#' It takes advantage of the fact that the shape of a binary tree is only defined by
|
||||
#' its depth (therefore, in a boosting model, all trees have similar shape).
|
||||
#'
|
||||
#'
|
||||
#' It takes advantage of the fact that the shape of a binary tree is only defined by
|
||||
#' its depth (therefore, in a boosting model, all trees have similar shape).
|
||||
#'
|
||||
#' Moreover, the trees tend to reuse the same features.
|
||||
#'
|
||||
#' The function projects each tree onto one, and keeps for each position the
|
||||
#'
|
||||
#' The function projects each tree onto one, and keeps for each position the
|
||||
#' \code{features_keep} first features (based on the Gain per feature measure).
|
||||
#'
|
||||
#'
|
||||
#' This function is inspired by this blog post:
|
||||
#' \url{https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/}
|
||||
#'
|
||||
#' @return
|
||||
#'
|
||||
#' When \code{render = TRUE}:
|
||||
#' returns a rendered graph object which is an \code{htmlwidget} of class \code{grViz}.
|
||||
#' Similar to ggplot objects, it needs to be printed to see it when not running from command line.
|
||||
#'
|
||||
#' When \code{render = FALSE}:
|
||||
#' silently returns a graph object which is of DiagrammeR's class \code{dgr_graph}.
|
||||
#' This could be useful if one wants to modify some of the graph attributes
|
||||
#' before rendering the graph with \code{\link[DiagrammeR]{render_graph}}.
|
||||
#'
|
||||
#' @examples
|
||||
#'
|
||||
#' data(agaricus.train, package='xgboost')
|
||||
#'
|
||||
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 15,
|
||||
#' eta = 1, nthread = 2, nrounds = 30, objective = "binary:logistic",
|
||||
#' min_child_weight = 50)
|
||||
#' eta = 1, nthread = 2, nrounds = 30, objective = "binary:logistic",
|
||||
#' min_child_weight = 50, verbose = 0)
|
||||
#'
|
||||
#' p <- xgb.plot.multi.trees(model = bst, feature_names = colnames(agaricus.train$data),
|
||||
#' features_keep = 3)
|
||||
#' p <- xgb.plot.multi.trees(model = bst, features_keep = 3)
|
||||
#' print(p)
|
||||
#'
|
||||
#' \dontrun{
|
||||
#' # Below is an example of how to save this plot to a file.
|
||||
#' # Note that for `export_graph` to work, the DiagrammeRsvg and rsvg packages must also be installed.
|
||||
#' library(DiagrammeR)
|
||||
#' gr <- xgb.plot.multi.trees(model=bst, features_keep = 3, render=FALSE)
|
||||
#' export_graph(gr, 'tree.pdf', width=1500, height=600)
|
||||
#' }
|
||||
#'
|
||||
#' @export
|
||||
xgb.plot.multi.trees <- function(model, feature_names = NULL, features_keep = 5, plot_width = NULL, plot_height = NULL, ...){
|
||||
xgb.plot.multi.trees <- function(model, feature_names = NULL, features_keep = 5, plot_width = NULL, plot_height = NULL,
|
||||
render = TRUE, ...){
|
||||
check.deprecation(...)
|
||||
tree.matrix <- xgb.model.dt.tree(feature_names = feature_names, model = model)
|
||||
|
||||
|
||||
# first number of the path represents the tree, then the following numbers are related to the path to follow
|
||||
# root init
|
||||
root.nodes <- tree.matrix[stri_detect_regex(ID, "\\d+-0"), ID]
|
||||
tree.matrix[ID %in% root.nodes, abs.node.position:=root.nodes]
|
||||
|
||||
tree.matrix[ID %in% root.nodes, abs.node.position := root.nodes]
|
||||
|
||||
precedent.nodes <- root.nodes
|
||||
|
||||
|
||||
while(tree.matrix[,sum(is.na(abs.node.position))] > 0) {
|
||||
yes.row.nodes <- tree.matrix[abs.node.position %in% precedent.nodes & !is.na(Yes)]
|
||||
no.row.nodes <- tree.matrix[abs.node.position %in% precedent.nodes & !is.na(No)]
|
||||
@@ -64,9 +83,8 @@ xgb.plot.multi.trees <- function(model, feature_names = NULL, features_keep = 5,
|
||||
precedent.nodes <- c(yes.nodes.abs.pos, no.nodes.abs.pos)
|
||||
}
|
||||
|
||||
tree.matrix[!is.na(Yes),Yes:= paste0(abs.node.position, "_0")]
|
||||
tree.matrix[!is.na(No),No:= paste0(abs.node.position, "_1")]
|
||||
|
||||
tree.matrix[!is.na(Yes), Yes := paste0(abs.node.position, "_0")]
|
||||
tree.matrix[!is.na(No), No := paste0(abs.node.position, "_1")]
|
||||
|
||||
remove.tree <- . %>% stri_replace_first_regex(pattern = "^\\d+-", replacement = "")
|
||||
|
||||
@@ -120,8 +138,10 @@ xgb.plot.multi.trees <- function(model, feature_names = NULL, features_keep = 5,
|
||||
attr_type = "edge",
|
||||
attr = c("color", "arrowsize", "arrowhead", "fontname"),
|
||||
value = c("DimGray", "1.5", "vee", "Helvetica"))
|
||||
|
||||
DiagrammeR::render_graph(graph, width = plot_width, height = plot_height)
|
||||
|
||||
if (!render) return(invisible(graph))
|
||||
|
||||
DiagrammeR::render_graph(graph, width = plot_width, height = plot_height)
|
||||
}
|
||||
|
||||
globalVariables(c(".N", "N", "From", "To", "Text", "Feature", "no.nodes.abs.pos",
|
||||
|
||||
214
R-package/R/xgb.plot.shap.R
Normal file
214
R-package/R/xgb.plot.shap.R
Normal file
@@ -0,0 +1,214 @@
|
||||
#' SHAP contribution dependency plots
|
||||
#'
|
||||
#' Visualizing the SHAP feature contribution to prediction dependencies on feature value.
|
||||
#'
|
||||
#' @param data data as a \code{matrix} or \code{dgCMatrix}.
|
||||
#' @param shap_contrib a matrix of SHAP contributions that was computed earlier for the above
|
||||
#' \code{data}. When it is NULL, it is computed internally using \code{model} and \code{data}.
|
||||
#' @param features a vector of either column indices or of feature names to plot. When it is NULL,
|
||||
#' feature importance is calculated, and \code{top_n} high ranked features are taken.
|
||||
#' @param top_n when \code{features} is NULL, top_n [1, 100] most important features in a model are taken.
|
||||
#' @param model an \code{xgb.Booster} model. It has to be provided when either \code{shap_contrib}
|
||||
#' or \code{features} is missing.
|
||||
#' @param trees passed to \code{\link{xgb.importance}} when \code{features = NULL}.
|
||||
#' @param target_class is only relevant for multiclass models. When it is set to a 0-based class index,
|
||||
#' only SHAP contributions for that specific class are used.
|
||||
#' If it is not set, SHAP importances are averaged over all classes.
|
||||
#' @param approxcontrib passed to \code{\link{predict.xgb.Booster}} when \code{shap_contrib = NULL}.
|
||||
#' @param subsample a random fraction of data points to use for plotting. When it is NULL,
|
||||
#' it is set so that up to 100K data points are used.
|
||||
#' @param n_col a number of columns in a grid of plots.
|
||||
#' @param col color of the scatterplot markers.
|
||||
#' @param pch scatterplot marker.
|
||||
#' @param discrete_n_uniq a maximal number of unique values in a feature to consider it as discrete.
|
||||
#' @param discrete_jitter an \code{amount} parameter of jitter added to discrete features' positions.
|
||||
#' @param ylab a y-axis label in 1D plots.
|
||||
#' @param plot_NA whether the contributions of cases with missing values should also be plotted.
|
||||
#' @param col_NA a color of marker for missing value contributions.
|
||||
#' @param pch_NA a marker type for NA values.
|
||||
#' @param pos_NA a relative position of the x-location where NA values are shown:
|
||||
#' \code{min(x) + (max(x) - min(x)) * pos_NA}.
|
||||
#' @param plot_loess whether to plot loess-smoothed curves. The smoothing is only done for features with
|
||||
#' more than 5 distinct values.
|
||||
#' @param col_loess a color to use for the loess curves.
|
||||
#' @param span_loess the \code{span} paramerer in \code{\link[stats]{loess}}'s call.
|
||||
#' @param which whether to do univariate or bivariate plotting. NOTE: only 1D is implemented so far.
|
||||
#' @param plot whether a plot should be drawn. If FALSE, only a lits of matrices is returned.
|
||||
#' @param ... other parameters passed to \code{plot}.
|
||||
#'
|
||||
#' @details
|
||||
#'
|
||||
#' These scatterplots represent how SHAP feature contributions depend of feature values.
|
||||
#' The similarity to partial dependency plots is that they also give an idea for how feature values
|
||||
#' affect predictions. However, in partial dependency plots, we usually see marginal dependencies
|
||||
#' of model prediction on feature value, while SHAP contribution dependency plots display the estimated
|
||||
#' contributions of a feature to model prediction for each individual case.
|
||||
#'
|
||||
#' When \code{plot_loess = TRUE} is set, feature values are rounded to 3 significant digits and
|
||||
#' weighted LOESS is computed and plotted, where weights are the numbers of data points
|
||||
#' at each rounded value.
|
||||
#'
|
||||
#' Note: SHAP contributions are shown on the scale of model margin. E.g., for a logistic binomial objective,
|
||||
#' the margin is prediction before a sigmoidal transform into probability-like values.
|
||||
#' Also, since SHAP stands for "SHapley Additive exPlanation" (model prediction = sum of SHAP
|
||||
#' contributions for all features + bias), depending on the objective used, transforming SHAP
|
||||
#' contributions for a feature from the marginal to the prediction space is not necessarily
|
||||
#' a meaningful thing to do.
|
||||
#'
|
||||
#' @return
|
||||
#'
|
||||
#' In addition to producing plots (when \code{plot=TRUE}), it silently returns a list of two matrices:
|
||||
#' \itemize{
|
||||
#' \item \code{data} the values of selected features;
|
||||
#' \item \code{shap_contrib} the contributions of selected features.
|
||||
#' }
|
||||
#'
|
||||
#' @references
|
||||
#'
|
||||
#' Scott M. Lundberg, Su-In Lee, "A Unified Approach to Interpreting Model Predictions", NIPS Proceedings 2017, \url{https://arxiv.org/abs/1705.07874}
|
||||
#'
|
||||
#' Scott M. Lundberg, Su-In Lee, "Consistent feature attribution for tree ensembles", \url{https://arxiv.org/abs/1706.06060}
|
||||
#'
|
||||
#' @examples
|
||||
#'
|
||||
#' data(agaricus.train, package='xgboost')
|
||||
#' data(agaricus.test, package='xgboost')
|
||||
#'
|
||||
#' bst <- xgboost(agaricus.train$data, agaricus.train$label, nrounds = 50,
|
||||
#' eta = 0.1, max_depth = 3, subsample = .5,
|
||||
#' method = "hist", objective = "binary:logistic", nthread = 2, verbose = 0)
|
||||
#'
|
||||
#' xgb.plot.shap(agaricus.test$data, model = bst, features = "odor=none")
|
||||
#' contr <- predict(bst, agaricus.test$data, predcontrib = TRUE)
|
||||
#' xgb.plot.shap(agaricus.test$data, contr, model = bst, top_n = 12, n_col = 3)
|
||||
#'
|
||||
#' # multiclass example - plots for each class separately:
|
||||
#' nclass <- 3
|
||||
#' nrounds <- 20
|
||||
#' x <- as.matrix(iris[, -5])
|
||||
#' set.seed(123)
|
||||
#' is.na(x[sample(nrow(x) * 4, 30)]) <- TRUE # introduce some missing values
|
||||
#' mbst <- xgboost(data = x, label = as.numeric(iris$Species) - 1, nrounds = nrounds,
|
||||
#' max_depth = 2, eta = 0.3, subsample = .5, nthread = 2,
|
||||
#' objective = "multi:softprob", num_class = nclass, verbose = 0)
|
||||
#' trees0 <- seq(from=0, by=nclass, length.out=nrounds)
|
||||
#' col <- rgb(0, 0, 1, 0.5)
|
||||
#' xgb.plot.shap(x, model = mbst, trees = trees0, target_class = 0, top_n = 4, n_col = 2, col = col, pch = 16, pch_NA = 17)
|
||||
#' xgb.plot.shap(x, model = mbst, trees = trees0 + 1, target_class = 1, top_n = 4, n_col = 2, col = col, pch = 16, pch_NA = 17)
|
||||
#' xgb.plot.shap(x, model = mbst, trees = trees0 + 2, target_class = 2, top_n = 4, n_col = 2, col = col, pch = 16, pch_NA = 17)
|
||||
#'
|
||||
#' @rdname xgb.plot.shap
|
||||
#' @export
|
||||
xgb.plot.shap <- function(data, shap_contrib = NULL, features = NULL, top_n = 1, model = NULL,
|
||||
trees = NULL, target_class = NULL, approxcontrib = FALSE,
|
||||
subsample = NULL, n_col = 1, col = rgb(0, 0, 1, 0.2), pch = '.',
|
||||
discrete_n_uniq = 5, discrete_jitter = 0.01, ylab = "SHAP",
|
||||
plot_NA = TRUE, col_NA = rgb(0.7, 0, 1, 0.6), pch_NA = '.', pos_NA = 1.07,
|
||||
plot_loess = TRUE, col_loess = 2, span_loess = 0.5,
|
||||
which = c("1d", "2d"), plot = TRUE, ...) {
|
||||
|
||||
if (!is.matrix(data) && !inherits(data, "dgCMatrix"))
|
||||
stop("data: must be either matrix or dgCMatrix")
|
||||
|
||||
if (is.null(shap_contrib) && (is.null(model) || !inherits(model, "xgb.Booster")))
|
||||
stop("when shap_contrib is not provided, one must provide an xgb.Booster model")
|
||||
|
||||
if (is.null(features) && (is.null(model) || !inherits(model, "xgb.Booster")))
|
||||
stop("when features are not provided, one must provide an xgb.Booster model to rank the features")
|
||||
|
||||
if (!is.null(shap_contrib) &&
|
||||
(!is.matrix(shap_contrib) || nrow(shap_contrib) != nrow(data) || ncol(shap_contrib) != ncol(data) + 1))
|
||||
stop("shap_contrib is not compatible with the provided data")
|
||||
|
||||
nsample <- if (is.null(subsample)) min(100000, nrow(data)) else as.integer(subsample * nrow(data))
|
||||
idx <- sample(1:nrow(data), nsample)
|
||||
data <- data[idx,]
|
||||
|
||||
if (is.null(shap_contrib)) {
|
||||
shap_contrib <- predict(model, data, predcontrib = TRUE, approxcontrib = approxcontrib)
|
||||
} else {
|
||||
shap_contrib <- shap_contrib[idx,]
|
||||
}
|
||||
|
||||
which <- match.arg(which)
|
||||
if (which == "2d")
|
||||
stop("2D plots are not implemented yet")
|
||||
|
||||
if (is.null(features)) {
|
||||
imp <- xgb.importance(model = model, trees = trees)
|
||||
top_n <- as.integer(top_n[1])
|
||||
if (top_n < 1 && top_n > 100)
|
||||
stop("top_n: must be an integer within [1, 100]")
|
||||
features <- imp$Feature[1:min(top_n, NROW(imp))]
|
||||
}
|
||||
|
||||
if (is.character(features)) {
|
||||
if (is.null(colnames(data)))
|
||||
stop("Either provide `data` with column names or provide `features` as column indices")
|
||||
features <- match(features, colnames(data))
|
||||
}
|
||||
|
||||
if (n_col > length(features)) n_col <- length(features)
|
||||
|
||||
if (is.list(shap_contrib)) { # multiclass: either choose a class or merge
|
||||
shap_contrib <- if (!is.null(target_class)) shap_contrib[[target_class + 1]]
|
||||
else Reduce("+", lapply(shap_contrib, abs))
|
||||
}
|
||||
|
||||
shap_contrib <- shap_contrib[, features, drop = FALSE]
|
||||
data <- data[, features, drop = FALSE]
|
||||
cols <- colnames(data)
|
||||
if (is.null(cols)) cols <- colnames(shap_contrib)
|
||||
if (is.null(cols)) cols <- paste0('X', 1:ncol(data))
|
||||
colnames(data) <- cols
|
||||
colnames(shap_contrib) <- cols
|
||||
|
||||
if (plot && which == "1d") {
|
||||
op <- par(mfrow = c(ceiling(length(features) / n_col), n_col),
|
||||
oma = c(0,0,0,0) + 0.2,
|
||||
mar = c(3.5,3.5,0,0) + 0.1,
|
||||
mgp = c(1.7, 0.6, 0))
|
||||
for (f in cols) {
|
||||
ord <- order(data[, f])
|
||||
x <- data[, f][ord]
|
||||
y <- shap_contrib[, f][ord]
|
||||
x_lim <- range(x, na.rm = TRUE)
|
||||
y_lim <- range(y, na.rm = TRUE)
|
||||
do_na <- plot_NA && any(is.na(x))
|
||||
if (do_na) {
|
||||
x_range <- diff(x_lim)
|
||||
loc_na <- min(x, na.rm = TRUE) + x_range * pos_NA
|
||||
x_lim <- range(c(x_lim, loc_na))
|
||||
}
|
||||
x_uniq <- unique(x)
|
||||
x2plot <- x
|
||||
# add small jitter for discrete features with <= 5 distinct values
|
||||
if (length(x_uniq) <= discrete_n_uniq)
|
||||
x2plot <- jitter(x, amount = discrete_jitter * min(diff(x_uniq), na.rm = TRUE))
|
||||
plot(x2plot, y, pch = pch, xlab = f, col = col, xlim = x_lim, ylim = y_lim, ylab = ylab, ...)
|
||||
grid()
|
||||
if (plot_loess) {
|
||||
# compress x to 3 digits, and mean-aggredate y
|
||||
zz <- data.table(x = signif(x, 3), y)[, .(.N, y=mean(y)), x]
|
||||
if (nrow(zz) <= 5) {
|
||||
lines(zz$x, zz$y, col = col_loess)
|
||||
} else {
|
||||
lo <- stats::loess(y ~ x, data = zz, weights = zz$N, span = span_loess)
|
||||
zz$y_lo <- predict(lo, zz, type = "link")
|
||||
lines(zz$x, zz$y_lo, col = col_loess)
|
||||
}
|
||||
}
|
||||
if (do_na) {
|
||||
i_na <- which(is.na(x))
|
||||
x_na <- rep(loc_na, length(i_na))
|
||||
x_na <- jitter(x_na, amount = x_range * 0.01)
|
||||
points(x_na, y[i_na], pch = pch_NA, col = col_NA)
|
||||
}
|
||||
}
|
||||
par(op)
|
||||
}
|
||||
if (plot && which == "2d") {
|
||||
# TODO
|
||||
}
|
||||
invisible(list(data = data, shap_contrib = shap_contrib))
|
||||
}
|
||||
@@ -95,7 +95,8 @@ xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot
|
||||
label = dt$label,
|
||||
fillcolor = dt$filledcolor,
|
||||
shape = dt$shape,
|
||||
data = dt$Feature)
|
||||
data = dt$Feature,
|
||||
fontcolor = "black")
|
||||
|
||||
edges <- DiagrammeR::create_edge_df(
|
||||
from = match(dt[Feature != "Leaf", c(ID)] %>% rep(2), dt$ID),
|
||||
|
||||
@@ -169,6 +169,11 @@
|
||||
#' \code{\link{predict.xgb.Booster}},
|
||||
#' \code{\link{xgb.cv}}
|
||||
#'
|
||||
#' @references
|
||||
#'
|
||||
#' Tianqi Chen and Carlos Guestrin, "XGBoost: A Scalable Tree Boosting System",
|
||||
#' 22nd SIGKDD Conference on Knowledge Discovery and Data Mining, 2016, \url{https://arxiv.org/abs/1603.02754}
|
||||
#'
|
||||
#' @examples
|
||||
#' data(agaricus.train, package='xgboost')
|
||||
#' data(agaricus.test, package='xgboost')
|
||||
|
||||
@@ -100,9 +100,12 @@ NULL
|
||||
#' @importFrom stats median
|
||||
#' @importFrom utils head
|
||||
#' @importFrom graphics barplot
|
||||
#' @importFrom graphics lines
|
||||
#' @importFrom graphics points
|
||||
#' @importFrom graphics grid
|
||||
#' @importFrom graphics par
|
||||
#' @importFrom graphics title
|
||||
#' @importFrom grDevices rgb
|
||||
#'
|
||||
#' @import methods
|
||||
#' @useDynLib xgboost, .registration = TRUE
|
||||
|
||||
Reference in New Issue
Block a user