[R] Improve more docstrings (#9919)
This commit is contained in:
@@ -127,22 +127,20 @@ xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL,
|
||||
p
|
||||
}
|
||||
|
||||
#' Combine and melt feature values and SHAP contributions for sample
|
||||
#' observations.
|
||||
#' Combine feature values and SHAP values
|
||||
#'
|
||||
#' Conforms to data format required for ggplot functions.
|
||||
#' Internal function used to combine and melt feature values and SHAP contributions
|
||||
#' as required for ggplot functions related to SHAP.
|
||||
#'
|
||||
#' Internal utility function.
|
||||
#' @param data_list The result of `xgb.shap.data()`.
|
||||
#' @param normalize Whether to standardize feature values to mean 0 and
|
||||
#' standard deviation 1. This is useful for comparing multiple features on the same
|
||||
#' plot. Default is \code{FALSE}.
|
||||
#'
|
||||
#' @param data_list List containing 'data' and 'shap_contrib' returned by
|
||||
#' \code{xgb.shap.data()}.
|
||||
#' @param normalize Whether to standardize feature values to have mean 0 and
|
||||
#' standard deviation 1 (useful for comparing multiple features on the same
|
||||
#' plot). Default \code{FALSE}.
|
||||
#'
|
||||
#' @return A data.table containing the observation ID, the feature name, the
|
||||
#' @return A `data.table` containing the observation ID, the feature name, the
|
||||
#' feature value (normalized if specified), and the SHAP contribution value.
|
||||
#' @noRd
|
||||
#' @keywords internal
|
||||
prepare.ggplot.shap.data <- function(data_list, normalize = FALSE) {
|
||||
data <- data_list[["data"]]
|
||||
shap_contrib <- data_list[["shap_contrib"]]
|
||||
@@ -163,15 +161,16 @@ prepare.ggplot.shap.data <- function(data_list, normalize = FALSE) {
|
||||
p_data
|
||||
}
|
||||
|
||||
#' Scale feature value to have mean 0, standard deviation 1
|
||||
#' Scale feature values
|
||||
#'
|
||||
#' This is used to compare multiple features on the same plot.
|
||||
#' Internal utility function
|
||||
#' Internal function that scales feature values to mean 0 and standard deviation 1.
|
||||
#' Useful to compare multiple features on the same plot.
|
||||
#'
|
||||
#' @param x Numeric vector
|
||||
#' @param x Numeric vector.
|
||||
#'
|
||||
#' @return Numeric vector with mean 0 and sd 1.
|
||||
#' @return Numeric vector with mean 0 and standard deviation 1.
|
||||
#' @noRd
|
||||
#' @keywords internal
|
||||
normalize <- function(x) {
|
||||
loc <- mean(x, na.rm = TRUE)
|
||||
scale <- stats::sd(x, na.rm = TRUE)
|
||||
|
||||
@@ -1,83 +1,115 @@
|
||||
#' Importance of features in a model.
|
||||
#' Feature importance
|
||||
#'
|
||||
#' Creates a \code{data.table} of feature importances in a model.
|
||||
#' Creates a `data.table` of feature importances.
|
||||
#'
|
||||
#' @param feature_names character vector of feature names. If the model already
|
||||
#' contains feature names, those would be used when \code{feature_names=NULL} (default value).
|
||||
#' Non-null \code{feature_names} could be provided to override those in the model.
|
||||
#' @param model object of class \code{xgb.Booster}.
|
||||
#' @param trees (only for the gbtree booster) an integer vector of tree indices that should be included
|
||||
#' into the importance calculation. If set to \code{NULL}, all trees of the model are parsed.
|
||||
#' It could be useful, e.g., in multiclass classification to get feature importances
|
||||
#' for each class separately. IMPORTANT: the tree index in xgboost models
|
||||
#' is zero-based (e.g., use \code{trees = 0:4} for first 5 trees).
|
||||
#' @param data deprecated.
|
||||
#' @param label deprecated.
|
||||
#' @param target deprecated.
|
||||
#' @param feature_names Character vector used to overwrite the feature names
|
||||
#' of the model. The default is `NULL` (use original feature names).
|
||||
#' @param model Object of class `xgb.Booster`.
|
||||
#' @param trees An integer vector of tree indices that should be included
|
||||
#' into the importance calculation (only for the "gbtree" booster).
|
||||
#' The default (`NULL`) parses all trees.
|
||||
#' It could be useful, e.g., in multiclass classification to get feature importances
|
||||
#' for each class separately. *Important*: the tree index in XGBoost models
|
||||
#' is zero-based (e.g., use `trees = 0:4` for the first five trees).
|
||||
#' @param data Deprecated.
|
||||
#' @param label Deprecated.
|
||||
#' @param target Deprecated.
|
||||
#'
|
||||
#' @details
|
||||
#'
|
||||
#' This function works for both linear and tree models.
|
||||
#'
|
||||
#' For linear models, the importance is the absolute magnitude of linear coefficients.
|
||||
#' For that reason, in order to obtain a meaningful ranking by importance for a linear model,
|
||||
#' the features need to be on the same scale (which you also would want to do when using either
|
||||
#' L1 or L2 regularization).
|
||||
#' To obtain a meaningful ranking by importance for linear models, the features need to
|
||||
#' be on the same scale (which is also recommended when using L1 or L2 regularization).
|
||||
#'
|
||||
#' @return
|
||||
#' @return A `data.table` with the following columns:
|
||||
#'
|
||||
#' For a tree model, a \code{data.table} with the following columns:
|
||||
#' \itemize{
|
||||
#' \item \code{Features} names of the features used in the model;
|
||||
#' \item \code{Gain} represents fractional contribution of each feature to the model based on
|
||||
#' the total gain of this feature's splits. Higher percentage means a more important
|
||||
#' predictive feature.
|
||||
#' \item \code{Cover} metric of the number of observation related to this feature;
|
||||
#' \item \code{Frequency} percentage representing the relative number of times
|
||||
#' a feature have been used in trees.
|
||||
#' }
|
||||
#' For a tree model:
|
||||
#' - `Features`: Names of the features used in the model.
|
||||
#' - `Gain`: Fractional contribution of each feature to the model based on
|
||||
#' the total gain of this feature's splits. Higher percentage means higher importance.
|
||||
#' - `Cover`: Metric of the number of observation related to this feature.
|
||||
#' - `Frequency`: Percentage of times a feature has been used in trees.
|
||||
#'
|
||||
#' A linear model's importance \code{data.table} has the following columns:
|
||||
#' \itemize{
|
||||
#' \item \code{Features} names of the features used in the model;
|
||||
#' \item \code{Weight} the linear coefficient of this feature;
|
||||
#' \item \code{Class} (only for multiclass models) class label.
|
||||
#' }
|
||||
#' For a linear model:
|
||||
#' - `Features`: Names of the features used in the model.
|
||||
#' - `Weight`: Linear coefficient of this feature.
|
||||
#' - `Class`: Class label (only for multiclass models).
|
||||
#'
|
||||
#' If \code{feature_names} is not provided and \code{model} doesn't have \code{feature_names},
|
||||
#' index of the features will be used instead. Because the index is extracted from the model dump
|
||||
#' If `feature_names` is not provided and `model` doesn't have `feature_names`,
|
||||
#' the index of the features will be used instead. Because the index is extracted from the model dump
|
||||
#' (based on C++ code), it starts at 0 (as in C/C++ or Python) instead of 1 (usual in R).
|
||||
#'
|
||||
#' @examples
|
||||
#'
|
||||
#' # binomial classification using gbtree:
|
||||
#' data(agaricus.train, package='xgboost')
|
||||
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 2,
|
||||
#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
|
||||
#' # binomial classification using "gbtree":
|
||||
#' data(agaricus.train, package = "xgboost")
|
||||
#'
|
||||
#' bst <- xgboost(
|
||||
#' data = agaricus.train$data,
|
||||
#' label = agaricus.train$label,
|
||||
#' max_depth = 2,
|
||||
#' eta = 1,
|
||||
#' nthread = 2,
|
||||
#' nrounds = 2,
|
||||
#' objective = "binary:logistic"
|
||||
#' )
|
||||
#'
|
||||
#' xgb.importance(model = bst)
|
||||
#'
|
||||
#' # binomial classification using gblinear:
|
||||
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, booster = "gblinear",
|
||||
#' eta = 0.3, nthread = 1, nrounds = 20, objective = "binary:logistic")
|
||||
#' # binomial classification using "gblinear":
|
||||
#' bst <- xgboost(
|
||||
#' data = agaricus.train$data,
|
||||
#' label = agaricus.train$label,
|
||||
#' booster = "gblinear",
|
||||
#' eta = 0.3,
|
||||
#' nthread = 1,
|
||||
#' nrounds = 20,objective = "binary:logistic"
|
||||
#' )
|
||||
#'
|
||||
#' xgb.importance(model = bst)
|
||||
#'
|
||||
#' # multiclass classification using gbtree:
|
||||
#' # multiclass classification using "gbtree":
|
||||
#' nclass <- 3
|
||||
#' nrounds <- 10
|
||||
#' mbst <- xgboost(data = as.matrix(iris[, -5]), label = as.numeric(iris$Species) - 1,
|
||||
#' max_depth = 3, eta = 0.2, nthread = 2, nrounds = nrounds,
|
||||
#' objective = "multi:softprob", num_class = nclass)
|
||||
#' mbst <- xgboost(
|
||||
#' data = as.matrix(iris[, -5]),
|
||||
#' label = as.numeric(iris$Species) - 1,
|
||||
#' max_depth = 3,
|
||||
#' eta = 0.2,
|
||||
#' nthread = 2,
|
||||
#' nrounds = nrounds,
|
||||
#' objective = "multi:softprob",
|
||||
#' num_class = nclass
|
||||
#' )
|
||||
#'
|
||||
#' # all classes clumped together:
|
||||
#' xgb.importance(model = mbst)
|
||||
#' # inspect importances separately for each class:
|
||||
#' xgb.importance(model = mbst, trees = seq(from=0, by=nclass, length.out=nrounds))
|
||||
#' xgb.importance(model = mbst, trees = seq(from=1, by=nclass, length.out=nrounds))
|
||||
#' xgb.importance(model = mbst, trees = seq(from=2, by=nclass, length.out=nrounds))
|
||||
#'
|
||||
#' # multiclass classification using gblinear:
|
||||
#' mbst <- xgboost(data = scale(as.matrix(iris[, -5])), label = as.numeric(iris$Species) - 1,
|
||||
#' booster = "gblinear", eta = 0.2, nthread = 1, nrounds = 15,
|
||||
#' objective = "multi:softprob", num_class = nclass)
|
||||
#' # inspect importances separately for each class:
|
||||
#' xgb.importance(
|
||||
#' model = mbst, trees = seq(from = 0, by = nclass, length.out = nrounds)
|
||||
#' )
|
||||
#' xgb.importance(
|
||||
#' model = mbst, trees = seq(from = 1, by = nclass, length.out = nrounds)
|
||||
#' )
|
||||
#' xgb.importance(
|
||||
#' model = mbst, trees = seq(from = 2, by = nclass, length.out = nrounds)
|
||||
#' )
|
||||
#'
|
||||
#' # multiclass classification using "gblinear":
|
||||
#' mbst <- xgboost(
|
||||
#' data = scale(as.matrix(iris[, -5])),
|
||||
#' label = as.numeric(iris$Species) - 1,
|
||||
#' booster = "gblinear",
|
||||
#' eta = 0.2,
|
||||
#' nthread = 1,
|
||||
#' nrounds = 15,
|
||||
#' objective = "multi:softprob",
|
||||
#' num_class = nclass
|
||||
#' )
|
||||
#'
|
||||
#' xgb.importance(model = mbst)
|
||||
#'
|
||||
#' @export
|
||||
|
||||
@@ -1,57 +1,58 @@
|
||||
#' Parse a boosted tree model text dump
|
||||
#' Parse model text dump
|
||||
#'
|
||||
#' Parse a boosted tree model text dump into a \code{data.table} structure.
|
||||
#' Parse a boosted tree model text dump into a `data.table` structure.
|
||||
#'
|
||||
#' @param feature_names character vector of feature names. If the model already
|
||||
#' contains feature names, those would be used when \code{feature_names=NULL} (default value).
|
||||
#' Non-null \code{feature_names} could be provided to override those in the model.
|
||||
#' @param model object of class \code{xgb.Booster}
|
||||
#' @param text \code{character} vector previously generated by the \code{xgb.dump}
|
||||
#' function (where parameter \code{with_stats = TRUE} should have been set).
|
||||
#' \code{text} takes precedence over \code{model}.
|
||||
#' @param trees an integer vector of tree indices that should be parsed.
|
||||
#' If set to \code{NULL}, all trees of the model are parsed.
|
||||
#' It could be useful, e.g., in multiclass classification to get only
|
||||
#' the trees of one certain class. IMPORTANT: the tree index in xgboost models
|
||||
#' is zero-based (e.g., use \code{trees = 0:4} for first 5 trees).
|
||||
#' @param use_int_id a logical flag indicating whether nodes in columns "Yes", "No", "Missing" should be
|
||||
#' represented as integers (when FALSE) or as "Tree-Node" character strings (when FALSE).
|
||||
#' @param ... currently not used.
|
||||
#' @param feature_names Character vector used to overwrite the feature names
|
||||
#' of the model. The default (`NULL`) uses the original feature names.
|
||||
#' @param model Object of class `xgb.Booster`.
|
||||
#' @param text Character vector previously generated by the function [xgb.dump()]
|
||||
#' (called with parameter `with_stats = TRUE`). `text` takes precedence over `model`.
|
||||
#' @param trees An integer vector of tree indices that should be used.
|
||||
#' The default (`NULL`) uses all trees.
|
||||
#' Useful, e.g., in multiclass classification to get only
|
||||
#' the trees of one class. *Important*: the tree index in XGBoost models
|
||||
#' is zero-based (e.g., use `trees = 0:4` for the first five trees).
|
||||
#' @param use_int_id A logical flag indicating whether nodes in columns "Yes", "No", and
|
||||
#' "Missing" should be represented as integers (when `TRUE`) or as "Tree-Node"
|
||||
#' character strings (when `FALSE`, default).
|
||||
#' @param ... Currently not used.
|
||||
#'
|
||||
#' @return
|
||||
#' A \code{data.table} with detailed information about model trees' nodes.
|
||||
#' A `data.table` with detailed information about tree nodes. It has the following columns:
|
||||
#' - `Tree`: integer ID of a tree in a model (zero-based index).
|
||||
#' - `Node`: integer ID of a node in a tree (zero-based index).
|
||||
#' - `ID`: character identifier of a node in a model (only when `use_int_id = FALSE`).
|
||||
#' - `Feature`: for a branch node, a feature ID or name (when available);
|
||||
#' for a leaf node, it simply labels it as `"Leaf"`.
|
||||
#' - `Split`: location of the split for a branch node (split condition is always "less than").
|
||||
#' - `Yes`: ID of the next node when the split condition is met.
|
||||
#' - `No`: ID of the next node when the split condition is not met.
|
||||
#' - `Missing`: ID of the next node when the branch value is missing.
|
||||
#' - `Quality`: either the split gain (change in loss) or the leaf value.
|
||||
#' - `Cover`: metric related to the number of observations either seen by a split
|
||||
#' or collected by a leaf during training.
|
||||
#'
|
||||
#' The columns of the \code{data.table} are:
|
||||
#'
|
||||
#' \itemize{
|
||||
#' \item \code{Tree}: integer ID of a tree in a model (zero-based index)
|
||||
#' \item \code{Node}: integer ID of a node in a tree (zero-based index)
|
||||
#' \item \code{ID}: character identifier of a node in a model (only when \code{use_int_id=FALSE})
|
||||
#' \item \code{Feature}: for a branch node, it's a feature id or name (when available);
|
||||
#' for a leaf note, it simply labels it as \code{'Leaf'}
|
||||
#' \item \code{Split}: location of the split for a branch node (split condition is always "less than")
|
||||
#' \item \code{Yes}: ID of the next node when the split condition is met
|
||||
#' \item \code{No}: ID of the next node when the split condition is not met
|
||||
#' \item \code{Missing}: ID of the next node when branch value is missing
|
||||
#' \item \code{Quality}: either the split gain (change in loss) or the leaf value
|
||||
#' \item \code{Cover}: metric related to the number of observation either seen by a split
|
||||
#' or collected by a leaf during training.
|
||||
#' }
|
||||
#'
|
||||
#' When \code{use_int_id=FALSE}, columns "Yes", "No", and "Missing" point to model-wide node identifiers
|
||||
#' in the "ID" column. When \code{use_int_id=TRUE}, those columns point to node identifiers from
|
||||
#' When `use_int_id = FALSE`, columns "Yes", "No", and "Missing" point to model-wide node identifiers
|
||||
#' in the "ID" column. When `use_int_id = TRUE`, those columns point to node identifiers from
|
||||
#' the corresponding trees in the "Node" column.
|
||||
#'
|
||||
#' @examples
|
||||
#' # Basic use:
|
||||
#'
|
||||
#' data(agaricus.train, package='xgboost')
|
||||
#' data(agaricus.train, package = "xgboost")
|
||||
#' ## Keep the number of threads to 1 for examples
|
||||
#' nthread <- 1
|
||||
#' data.table::setDTthreads(nthread)
|
||||
#'
|
||||
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 2,
|
||||
#' eta = 1, nthread = nthread, nrounds = 2,objective = "binary:logistic")
|
||||
#' bst <- xgboost(
|
||||
#' data = agaricus.train$data,
|
||||
#' label = agaricus.train$label,
|
||||
#' max_depth = 2,
|
||||
#' eta = 1,
|
||||
#' nthread = nthread,
|
||||
#' nrounds = 2,
|
||||
#' objective = "binary:logistic"
|
||||
#' )
|
||||
#'
|
||||
#' (dt <- xgb.model.dt.tree(colnames(agaricus.train$data), bst))
|
||||
#'
|
||||
@@ -60,8 +61,12 @@
|
||||
#' (dt <- xgb.model.dt.tree(model = bst))
|
||||
#'
|
||||
#' # How to match feature names of splits that are following a current 'Yes' branch:
|
||||
#'
|
||||
#' merge(dt, dt[, .(ID, Y.Feature=Feature)], by.x='Yes', by.y='ID', all.x=TRUE)[order(Tree,Node)]
|
||||
#' merge(
|
||||
#' dt,
|
||||
#' dt[, .(ID, Y.Feature = Feature)], by.x = "Yes", by.y = "ID", all.x = TRUE
|
||||
#' )[
|
||||
#' order(Tree, Node)
|
||||
#' ]
|
||||
#'
|
||||
#' @export
|
||||
xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
|
||||
|
||||
@@ -1,65 +1,74 @@
|
||||
#' Plot model trees deepness
|
||||
#' Plot model tree depth
|
||||
#'
|
||||
#' Visualizes distributions related to depth of tree leafs.
|
||||
#' \code{xgb.plot.deepness} uses base R graphics, while \code{xgb.ggplot.deepness} uses the ggplot backend.
|
||||
#' Visualizes distributions related to the depth of tree leaves.
|
||||
#' - `xgb.plot.deepness()` uses base R graphics, while
|
||||
#' - `xgb.ggplot.deepness()` uses "ggplot2".
|
||||
#'
|
||||
#' @param model either an \code{xgb.Booster} model generated by the \code{xgb.train} function
|
||||
#' or a data.table result of the \code{xgb.model.dt.tree} function.
|
||||
#' @param plot (base R barplot) whether a barplot should be produced.
|
||||
#' If FALSE, only a data.table is returned.
|
||||
#' @param which which distribution to plot (see details).
|
||||
#' @param ... other parameters passed to \code{barplot} or \code{plot}.
|
||||
#' @param model Either an `xgb.Booster` model, or the "data.table" returned by [xgb.model.dt.tree()].
|
||||
#' @param which Which distribution to plot (see details).
|
||||
#' @param plot Should the plot be shown? Default is `TRUE`.
|
||||
#' @param ... Other parameters passed to [graphics::barplot()] or [graphics::plot()].
|
||||
#'
|
||||
#' @details
|
||||
#'
|
||||
#' When \code{which="2x1"}, two distributions with respect to the leaf depth
|
||||
#' When `which = "2x1"`, two distributions with respect to the leaf depth
|
||||
#' are plotted on top of each other:
|
||||
#' \itemize{
|
||||
#' \item the distribution of the number of leafs in a tree model at a certain depth;
|
||||
#' \item the distribution of average weighted number of observations ("cover")
|
||||
#' ending up in leafs at certain depth.
|
||||
#' }
|
||||
#' Those could be helpful in determining sensible ranges of the \code{max_depth}
|
||||
#' and \code{min_child_weight} parameters.
|
||||
#' 1. The distribution of the number of leaves in a tree model at a certain depth.
|
||||
#' 2. The distribution of the average weighted number of observations ("cover")
|
||||
#' ending up in leaves at a certain depth.
|
||||
#'
|
||||
#' When \code{which="max.depth"} or \code{which="med.depth"}, plots of either maximum or median depth
|
||||
#' per tree with respect to tree number are created. And \code{which="med.weight"} allows to see how
|
||||
#' Those could be helpful in determining sensible ranges of the `max_depth`
|
||||
#' and `min_child_weight` parameters.
|
||||
#'
|
||||
#' When `which = "max.depth"` or `which = "med.depth"`, plots of either maximum or
|
||||
#' median depth per tree with respect to the tree number are created.
|
||||
#'
|
||||
#' Finally, `which = "med.weight"` allows to see how
|
||||
#' a tree's median absolute leaf weight changes through the iterations.
|
||||
#'
|
||||
#' This function was inspired by the blog post
|
||||
#' \url{https://github.com/aysent/random-forest-leaf-visualization}.
|
||||
#' These functions have been inspired by the blog post
|
||||
#' <https://github.com/aysent/random-forest-leaf-visualization>.
|
||||
#'
|
||||
#' @return
|
||||
#' The return value of the two functions is as follows:
|
||||
#' - `xgb.plot.deepness()`: A "data.table" (invisibly).
|
||||
#' Each row corresponds to a terminal leaf in the model. It contains its information
|
||||
#' about depth, cover, and weight (used in calculating predictions).
|
||||
#' If `plot = TRUE`, also a plot is shown.
|
||||
#' - `xgb.ggplot.deepness()`: When `which = "2x1"`, a list of two "ggplot" objects,
|
||||
#' and a single "ggplot" object otherwise.
|
||||
#'
|
||||
#' Other than producing plots (when \code{plot=TRUE}), the \code{xgb.plot.deepness} function
|
||||
#' silently returns a processed data.table where each row corresponds to a terminal leaf in a tree model,
|
||||
#' and contains information about leaf's depth, cover, and weight (which is used in calculating predictions).
|
||||
#'
|
||||
#' The \code{xgb.ggplot.deepness} silently returns either a list of two ggplot graphs when \code{which="2x1"}
|
||||
#' or a single ggplot graph for the other \code{which} options.
|
||||
#'
|
||||
#' @seealso
|
||||
#'
|
||||
#' \code{\link{xgb.train}}, \code{\link{xgb.model.dt.tree}}.
|
||||
#' @seealso [xgb.train()] and [xgb.model.dt.tree()].
|
||||
#'
|
||||
#' @examples
|
||||
#'
|
||||
#' data(agaricus.train, package='xgboost')
|
||||
#' data(agaricus.train, package = "xgboost")
|
||||
#' ## Keep the number of threads to 2 for examples
|
||||
#' nthread <- 2
|
||||
#' data.table::setDTthreads(nthread)
|
||||
#'
|
||||
#' ## Change max_depth to a higher number to get a more significant result
|
||||
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 6,
|
||||
#' eta = 0.1, nthread = nthread, nrounds = 50, objective = "binary:logistic",
|
||||
#' subsample = 0.5, min_child_weight = 2)
|
||||
#' bst <- xgboost(
|
||||
#' data = agaricus.train$data,
|
||||
#' label = agaricus.train$label,
|
||||
#' max_depth = 6,
|
||||
#' nthread = nthread,
|
||||
#' nrounds = 50,
|
||||
#' objective = "binary:logistic",
|
||||
#' subsample = 0.5,
|
||||
#' min_child_weight = 2
|
||||
#' )
|
||||
#'
|
||||
#' xgb.plot.deepness(bst)
|
||||
#' xgb.ggplot.deepness(bst)
|
||||
#'
|
||||
#' xgb.plot.deepness(bst, which='max.depth', pch=16, col=rgb(0,0,1,0.3), cex=2)
|
||||
#' xgb.plot.deepness(
|
||||
#' bst, which = "max.depth", pch = 16, col = rgb(0, 0, 1, 0.3), cex = 2
|
||||
#' )
|
||||
#'
|
||||
#' xgb.plot.deepness(bst, which='med.weight', pch=16, col=rgb(0,0,1,0.3), cex=2)
|
||||
#' xgb.plot.deepness(
|
||||
#' bst, which = "med.weight", pch = 16, col = rgb(0, 0, 1, 0.3), cex = 2
|
||||
#' )
|
||||
#'
|
||||
#' @rdname xgb.plot.deepness
|
||||
#' @export
|
||||
|
||||
@@ -1,64 +1,75 @@
|
||||
#' Plot feature importance as a bar graph
|
||||
#' Plot feature importance
|
||||
#'
|
||||
#' Represents previously calculated feature importance as a bar graph.
|
||||
#' \code{xgb.plot.importance} uses base R graphics, while \code{xgb.ggplot.importance} uses the ggplot backend.
|
||||
#' - `xgb.plot.importance()` uses base R graphics, while
|
||||
#' - `xgb.ggplot.importance()` uses "ggplot".
|
||||
#'
|
||||
#' @param importance_matrix a \code{data.table} returned by \code{\link{xgb.importance}}.
|
||||
#' @param top_n maximal number of top features to include into the plot.
|
||||
#' @param measure the name of importance measure to plot.
|
||||
#' When \code{NULL}, 'Gain' would be used for trees and 'Weight' would be used for gblinear.
|
||||
#' @param rel_to_first whether importance values should be represented as relative to the highest ranked feature.
|
||||
#' See Details.
|
||||
#' @param left_margin (base R barplot) allows to adjust the left margin size to fit feature names.
|
||||
#' When it is NULL, the existing \code{par('mar')} is used.
|
||||
#' @param cex (base R barplot) passed as \code{cex.names} parameter to \code{barplot}.
|
||||
#' @param plot (base R barplot) whether a barplot should be produced.
|
||||
#' If FALSE, only a data.table is returned.
|
||||
#' @param n_clusters (ggplot only) a \code{numeric} vector containing the min and the max range
|
||||
#' @param importance_matrix A `data.table` as returned by [xgb.importance()].
|
||||
#' @param top_n Maximal number of top features to include into the plot.
|
||||
#' @param measure The name of importance measure to plot.
|
||||
#' When `NULL`, 'Gain' would be used for trees and 'Weight' would be used for gblinear.
|
||||
#' @param rel_to_first Whether importance values should be represented as relative to
|
||||
#' the highest ranked feature, see Details.
|
||||
#' @param left_margin Adjust the left margin size to fit feature names.
|
||||
#' When `NULL`, the existing `par("mar")` is used.
|
||||
#' @param cex Passed as `cex.names` parameter to [graphics::barplot()].
|
||||
#' @param plot Should the barplot be shown? Default is `TRUE`.
|
||||
#' @param n_clusters A numeric vector containing the min and the max range
|
||||
#' of the possible number of clusters of bars.
|
||||
#' @param ... other parameters passed to \code{barplot} (except horiz, border, cex.names, names.arg, and las).
|
||||
#' @param ... Other parameters passed to [graphics::barplot()]
|
||||
#' (except `horiz`, `border`, `cex.names`, `names.arg`, and `las`).
|
||||
#' Only used in `xgb.plot.importance()`.
|
||||
#'
|
||||
#' @details
|
||||
#' The graph represents each feature as a horizontal bar of length proportional to the importance of a feature.
|
||||
#' Features are shown ranked in a decreasing importance order.
|
||||
#' It works for importances from both \code{gblinear} and \code{gbtree} models.
|
||||
#' Features are sorted by decreasing importance.
|
||||
#' It works for both "gblinear" and "gbtree" models.
|
||||
#'
|
||||
#' When \code{rel_to_first = FALSE}, the values would be plotted as they were in \code{importance_matrix}.
|
||||
#' For gbtree model, that would mean being normalized to the total of 1
|
||||
#' When `rel_to_first = FALSE`, the values would be plotted as in `importance_matrix`.
|
||||
#' For a "gbtree" model, that would mean being normalized to the total of 1
|
||||
#' ("what is feature's importance contribution relative to the whole model?").
|
||||
#' For linear models, \code{rel_to_first = FALSE} would show actual values of the coefficients.
|
||||
#' Setting \code{rel_to_first = TRUE} allows to see the picture from the perspective of
|
||||
#' For linear models, `rel_to_first = FALSE` would show actual values of the coefficients.
|
||||
#' Setting `rel_to_first = TRUE` allows to see the picture from the perspective of
|
||||
#' "what is feature's importance contribution relative to the most important feature?"
|
||||
#'
|
||||
#' The ggplot-backend method also performs 1-D clustering of the importance values,
|
||||
#' with bar colors corresponding to different clusters that have somewhat similar importance values.
|
||||
#' The "ggplot" backend performs 1-D clustering of the importance values,
|
||||
#' with bar colors corresponding to different clusters having similar importance values.
|
||||
#'
|
||||
#' @return
|
||||
#' The \code{xgb.plot.importance} function creates a \code{barplot} (when \code{plot=TRUE})
|
||||
#' and silently returns a processed data.table with \code{n_top} features sorted by importance.
|
||||
#' The return value depends on the function:
|
||||
#' - `xgb.plot.importance()`: Invisibly, a "data.table" with `n_top` features sorted
|
||||
#' by importance. If `plot = TRUE`, the values are also plotted as barplot.
|
||||
#' - `xgb.ggplot.importance()`: A customizable "ggplot" object.
|
||||
#' E.g., to change the title, set `+ ggtitle("A GRAPH NAME")`.
|
||||
#'
|
||||
#' The \code{xgb.ggplot.importance} function returns a ggplot graph which could be customized afterwards.
|
||||
#' E.g., to change the title of the graph, add \code{+ ggtitle("A GRAPH NAME")} to the result.
|
||||
#'
|
||||
#' @seealso
|
||||
#' \code{\link[graphics]{barplot}}.
|
||||
#' @seealso [graphics::barplot()]
|
||||
#'
|
||||
#' @examples
|
||||
#' data(agaricus.train)
|
||||
#'
|
||||
#' ## Keep the number of threads to 2 for examples
|
||||
#' nthread <- 2
|
||||
#' data.table::setDTthreads(nthread)
|
||||
#'
|
||||
#' bst <- xgboost(
|
||||
#' data = agaricus.train$data, label = agaricus.train$label, max_depth = 3,
|
||||
#' eta = 1, nthread = nthread, nrounds = 2, objective = "binary:logistic"
|
||||
#' data = agaricus.train$data,
|
||||
#' label = agaricus.train$label,
|
||||
#' max_depth = 3,
|
||||
#' eta = 1,
|
||||
#' nthread = nthread,
|
||||
#' nrounds = 2,
|
||||
#' objective = "binary:logistic"
|
||||
#' )
|
||||
#'
|
||||
#' importance_matrix <- xgb.importance(colnames(agaricus.train$data), model = bst)
|
||||
#' xgb.plot.importance(
|
||||
#' importance_matrix, rel_to_first = TRUE, xlab = "Relative importance"
|
||||
#' )
|
||||
#'
|
||||
#' xgb.plot.importance(importance_matrix, rel_to_first = TRUE, xlab = "Relative importance")
|
||||
#'
|
||||
#' (gg <- xgb.ggplot.importance(importance_matrix, measure = "Frequency", rel_to_first = TRUE))
|
||||
#' gg <- xgb.ggplot.importance(
|
||||
#' importance_matrix, measure = "Frequency", rel_to_first = TRUE
|
||||
#' )
|
||||
#' gg
|
||||
#' gg + ggplot2::ylab("Frequency")
|
||||
#'
|
||||
#' @rdname xgb.plot.importance
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
#' Project all trees on one tree and plot it
|
||||
#' Project all trees on one tree
|
||||
#'
|
||||
#' Visualization of the ensemble of trees as a single collective unit.
|
||||
#'
|
||||
#' @param model produced by the \code{xgb.train} function.
|
||||
#' @param feature_names names of each feature as a \code{character} vector.
|
||||
#' @param features_keep number of features to keep in each position of the multi trees.
|
||||
#' @param plot_width width in pixels of the graph to produce
|
||||
#' @param plot_height height in pixels of the graph to produce
|
||||
#' @param render a logical flag for whether the graph should be rendered (see Value).
|
||||
#' @param ... currently not used
|
||||
#' @inheritParams xgb.plot.tree
|
||||
#' @param features_keep Number of features to keep in each position of the multi trees,
|
||||
#' by default 5.
|
||||
#'
|
||||
#' @details
|
||||
#'
|
||||
@@ -24,33 +20,31 @@
|
||||
#' Moreover, the trees tend to reuse the same features.
|
||||
#'
|
||||
#' The function projects each tree onto one, and keeps for each position the
|
||||
#' \code{features_keep} first features (based on the Gain per feature measure).
|
||||
#' `features_keep` first features (based on the Gain per feature measure).
|
||||
#'
|
||||
#' This function is inspired by this blog post:
|
||||
#' \url{https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/}
|
||||
#' <https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/>
|
||||
#'
|
||||
#' @return
|
||||
#'
|
||||
#' When \code{render = TRUE}:
|
||||
#' returns a rendered graph object which is an \code{htmlwidget} of class \code{grViz}.
|
||||
#' Similar to ggplot objects, it needs to be printed to see it when not running from command line.
|
||||
#'
|
||||
#' When \code{render = FALSE}:
|
||||
#' silently returns a graph object which is of DiagrammeR's class \code{dgr_graph}.
|
||||
#' This could be useful if one wants to modify some of the graph attributes
|
||||
#' before rendering the graph with \code{\link[DiagrammeR]{render_graph}}.
|
||||
#' @inherit xgb.plot.tree return
|
||||
#'
|
||||
#' @examples
|
||||
#'
|
||||
#' data(agaricus.train, package='xgboost')
|
||||
#' data(agaricus.train, package = "xgboost")
|
||||
#'
|
||||
#' ## Keep the number of threads to 2 for examples
|
||||
#' nthread <- 2
|
||||
#' data.table::setDTthreads(nthread)
|
||||
#'
|
||||
#' bst <- xgboost(
|
||||
#' data = agaricus.train$data, label = agaricus.train$label, max_depth = 15,
|
||||
#' eta = 1, nthread = nthread, nrounds = 30, objective = "binary:logistic",
|
||||
#' min_child_weight = 50, verbose = 0
|
||||
#' data = agaricus.train$data,
|
||||
#' label = agaricus.train$label,
|
||||
#' max_depth = 15,
|
||||
#' eta = 1,
|
||||
#' nthread = nthread,
|
||||
#' nrounds = 30,
|
||||
#' objective = "binary:logistic",
|
||||
#' min_child_weight = 50,
|
||||
#' verbose = 0
|
||||
#' )
|
||||
#'
|
||||
#' p <- xgb.plot.multi.trees(model = bst, features_keep = 3)
|
||||
@@ -58,10 +52,13 @@
|
||||
#'
|
||||
#' \dontrun{
|
||||
#' # Below is an example of how to save this plot to a file.
|
||||
#' # Note that for `export_graph` to work, the DiagrammeRsvg and rsvg packages must also be installed.
|
||||
#' # Note that for export_graph() to work, the {DiagrammeRsvg} and {rsvg} packages
|
||||
#' # must also be installed.
|
||||
#'
|
||||
#' library(DiagrammeR)
|
||||
#' gr <- xgb.plot.multi.trees(model=bst, features_keep = 3, render=FALSE)
|
||||
#' export_graph(gr, 'tree.pdf', width=1500, height=600)
|
||||
#'
|
||||
#' gr <- xgb.plot.multi.trees(model = bst, features_keep = 3, render = FALSE)
|
||||
#' export_graph(gr, "tree.pdf", width = 1500, height = 600)
|
||||
#' }
|
||||
#'
|
||||
#' @export
|
||||
|
||||
@@ -1,110 +1,165 @@
|
||||
#' SHAP contribution dependency plots
|
||||
#' SHAP dependence plots
|
||||
#'
|
||||
#' Visualizing the SHAP feature contribution to prediction dependencies on feature value.
|
||||
#' Visualizes SHAP values against feature values to gain an impression of feature effects.
|
||||
#'
|
||||
#' @param data data as a \code{matrix} or \code{dgCMatrix}.
|
||||
#' @param shap_contrib a matrix of SHAP contributions that was computed earlier for the above
|
||||
#' \code{data}. When it is NULL, it is computed internally using \code{model} and \code{data}.
|
||||
#' @param features a vector of either column indices or of feature names to plot. When it is NULL,
|
||||
#' feature importance is calculated, and \code{top_n} high ranked features are taken.
|
||||
#' @param top_n when \code{features} is NULL, top_n `[1, 100]` most important features in a model are taken.
|
||||
#' @param model an \code{xgb.Booster} model. It has to be provided when either \code{shap_contrib}
|
||||
#' or \code{features} is missing.
|
||||
#' @param trees passed to \code{\link{xgb.importance}} when \code{features = NULL}.
|
||||
#' @param target_class is only relevant for multiclass models. When it is set to a 0-based class index,
|
||||
#' only SHAP contributions for that specific class are used.
|
||||
#' If it is not set, SHAP importances are averaged over all classes.
|
||||
#' @param approxcontrib passed to \code{\link{predict.xgb.Booster}} when \code{shap_contrib = NULL}.
|
||||
#' @param subsample a random fraction of data points to use for plotting. When it is NULL,
|
||||
#' it is set so that up to 100K data points are used.
|
||||
#' @param n_col a number of columns in a grid of plots.
|
||||
#' @param col color of the scatterplot markers.
|
||||
#' @param pch scatterplot marker.
|
||||
#' @param discrete_n_uniq a maximal number of unique values in a feature to consider it as discrete.
|
||||
#' @param discrete_jitter an \code{amount} parameter of jitter added to discrete features' positions.
|
||||
#' @param ylab a y-axis label in 1D plots.
|
||||
#' @param plot_NA whether the contributions of cases with missing values should also be plotted.
|
||||
#' @param col_NA a color of marker for missing value contributions.
|
||||
#' @param pch_NA a marker type for NA values.
|
||||
#' @param pos_NA a relative position of the x-location where NA values are shown:
|
||||
#' \code{min(x) + (max(x) - min(x)) * pos_NA}.
|
||||
#' @param plot_loess whether to plot loess-smoothed curves. The smoothing is only done for features with
|
||||
#' more than 5 distinct values.
|
||||
#' @param col_loess a color to use for the loess curves.
|
||||
#' @param span_loess the \code{span} parameter in \code{\link[stats]{loess}}'s call.
|
||||
#' @param which whether to do univariate or bivariate plotting. NOTE: only 1D is implemented so far.
|
||||
#' @param plot whether a plot should be drawn. If FALSE, only a list of matrices is returned.
|
||||
#' @param ... other parameters passed to \code{plot}.
|
||||
#' @param data The data to explain as a `matrix` or `dgCMatrix`.
|
||||
#' @param shap_contrib Matrix of SHAP contributions of `data`.
|
||||
#' The default (`NULL`) computes it from `model` and `data`.
|
||||
#' @param features Vector of column indices or feature names to plot.
|
||||
#' When `NULL` (default), the `top_n` most important features are selected
|
||||
#' by [xgb.importance()].
|
||||
#' @param top_n How many of the most important features (<= 100) should be selected?
|
||||
#' By default 1 for SHAP dependence and 10 for SHAP summary).
|
||||
#' Only used when `features = NULL`.
|
||||
#' @param model An `xgb.Booster` model. Only required when `shap_contrib = NULL` or
|
||||
#' `features = NULL`.
|
||||
#' @param trees Passed to [xgb.importance()] when `features = NULL`.
|
||||
#' @param target_class Only relevant for multiclass models. The default (`NULL`)
|
||||
#' averages the SHAP values over all classes. Pass a (0-based) class index
|
||||
#' to show only SHAP values of that class.
|
||||
#' @param approxcontrib Passed to `predict()` when `shap_contrib = NULL`.
|
||||
#' @param subsample Fraction of data points randomly picked for plotting.
|
||||
#' The default (`NULL`) will use up to 100k data points.
|
||||
#' @param n_col Number of columns in a grid of plots.
|
||||
#' @param col Color of the scatterplot markers.
|
||||
#' @param pch Scatterplot marker.
|
||||
#' @param discrete_n_uniq Maximal number of unique feature values to consider the
|
||||
#' feature as discrete.
|
||||
#' @param discrete_jitter Jitter amount added to the values of discrete features.
|
||||
#' @param ylab The y-axis label in 1D plots.
|
||||
#' @param plot_NA Should contributions of cases with missing values be plotted?
|
||||
#' Default is `TRUE`.
|
||||
#' @param col_NA Color of marker for missing value contributions.
|
||||
#' @param pch_NA Marker type for `NA` values.
|
||||
#' @param pos_NA Relative position of the x-location where `NA` values are shown:
|
||||
#' `min(x) + (max(x) - min(x)) * pos_NA`.
|
||||
#' @param plot_loess Should loess-smoothed curves be plotted? (Default is `TRUE`).
|
||||
#' The smoothing is only done for features with more than 5 distinct values.
|
||||
#' @param col_loess Color of loess curves.
|
||||
#' @param span_loess The `span` parameter of [stats::loess()].
|
||||
#' @param which Whether to do univariate or bivariate plotting. Currently, only "1d" is implemented.
|
||||
#' @param plot Should the plot be drawn? (Default is `TRUE`).
|
||||
#' If `FALSE`, only a list of matrices is returned.
|
||||
#' @param ... Other parameters passed to [graphics::plot()].
|
||||
#'
|
||||
#' @details
|
||||
#'
|
||||
#' These scatterplots represent how SHAP feature contributions depend of feature values.
|
||||
#' The similarity to partial dependency plots is that they also give an idea for how feature values
|
||||
#' affect predictions. However, in partial dependency plots, we usually see marginal dependencies
|
||||
#' of model prediction on feature value, while SHAP contribution dependency plots display the estimated
|
||||
#' contributions of a feature to model prediction for each individual case.
|
||||
#' The similarity to partial dependence plots is that they also give an idea for how feature values
|
||||
#' affect predictions. However, in partial dependence plots, we see marginal dependencies
|
||||
#' of model prediction on feature value, while SHAP dependence plots display the estimated
|
||||
#' contributions of a feature to the prediction for each individual case.
|
||||
#'
|
||||
#' When \code{plot_loess = TRUE} is set, feature values are rounded to 3 significant digits and
|
||||
#' weighted LOESS is computed and plotted, where weights are the numbers of data points
|
||||
#' When `plot_loess = TRUE`, feature values are rounded to three significant digits and
|
||||
#' weighted LOESS is computed and plotted, where the weights are the numbers of data points
|
||||
#' at each rounded value.
|
||||
#'
|
||||
#' Note: SHAP contributions are shown on the scale of model margin. E.g., for a logistic binomial objective,
|
||||
#' the margin is prediction before a sigmoidal transform into probability-like values.
|
||||
#' Note: SHAP contributions are on the scale of the model margin.
|
||||
#' E.g., for a logistic binomial objective, the margin is on log-odds scale.
|
||||
#' Also, since SHAP stands for "SHapley Additive exPlanation" (model prediction = sum of SHAP
|
||||
#' contributions for all features + bias), depending on the objective used, transforming SHAP
|
||||
#' contributions for a feature from the marginal to the prediction space is not necessarily
|
||||
#' a meaningful thing to do.
|
||||
#'
|
||||
#' @return
|
||||
#'
|
||||
#' In addition to producing plots (when \code{plot=TRUE}), it silently returns a list of two matrices:
|
||||
#' \itemize{
|
||||
#' \item \code{data} the values of selected features;
|
||||
#' \item \code{shap_contrib} the contributions of selected features.
|
||||
#' }
|
||||
#' In addition to producing plots (when `plot = TRUE`), it silently returns a list of two matrices:
|
||||
#' - `data`: Feature value matrix.
|
||||
#' - `shap_contrib`: Corresponding SHAP value matrix.
|
||||
#'
|
||||
#' @references
|
||||
#'
|
||||
#' Scott M. Lundberg, Su-In Lee, "A Unified Approach to Interpreting Model Predictions", NIPS Proceedings 2017, \url{https://arxiv.org/abs/1705.07874}
|
||||
#'
|
||||
#' Scott M. Lundberg, Su-In Lee, "Consistent feature attribution for tree ensembles", \url{https://arxiv.org/abs/1706.06060}
|
||||
#' 1. Scott M. Lundberg, Su-In Lee, "A Unified Approach to Interpreting Model Predictions",
|
||||
#' NIPS Proceedings 2017, <https://arxiv.org/abs/1705.07874>
|
||||
#' 2. Scott M. Lundberg, Su-In Lee, "Consistent feature attribution for tree ensembles",
|
||||
#' <https://arxiv.org/abs/1706.06060>
|
||||
#'
|
||||
#' @examples
|
||||
#'
|
||||
#' data(agaricus.train, package='xgboost')
|
||||
#' data(agaricus.test, package='xgboost')
|
||||
#' data(agaricus.train, package = "xgboost")
|
||||
#' data(agaricus.test, package = "xgboost")
|
||||
#'
|
||||
#' ## Keep the number of threads to 1 for examples
|
||||
#' nthread <- 1
|
||||
#' data.table::setDTthreads(nthread)
|
||||
#' nrounds <- 20
|
||||
#'
|
||||
#' bst <- xgboost(agaricus.train$data, agaricus.train$label, nrounds = nrounds,
|
||||
#' eta = 0.1, max_depth = 3, subsample = .5,
|
||||
#' method = "hist", objective = "binary:logistic", nthread = nthread, verbose = 0)
|
||||
#' bst <- xgboost(
|
||||
#' agaricus.train$data,
|
||||
#' agaricus.train$label,
|
||||
#' nrounds = nrounds,
|
||||
#' eta = 0.1,
|
||||
#' max_depth = 3,
|
||||
#' subsample = 0.5,
|
||||
#' objective = "binary:logistic",
|
||||
#' nthread = nthread,
|
||||
#' verbose = 0
|
||||
#' )
|
||||
#'
|
||||
#' xgb.plot.shap(agaricus.test$data, model = bst, features = "odor=none")
|
||||
#'
|
||||
#' contr <- predict(bst, agaricus.test$data, predcontrib = TRUE)
|
||||
#' xgb.plot.shap(agaricus.test$data, contr, model = bst, top_n = 12, n_col = 3)
|
||||
#' xgb.ggplot.shap.summary(agaricus.test$data, contr, model = bst, top_n = 12) # Summary plot
|
||||
#'
|
||||
#' # multiclass example - plots for each class separately:
|
||||
#' # Summary plot
|
||||
#' xgb.ggplot.shap.summary(agaricus.test$data, contr, model = bst, top_n = 12)
|
||||
#'
|
||||
#' # Multiclass example - plots for each class separately:
|
||||
#' nclass <- 3
|
||||
#' x <- as.matrix(iris[, -5])
|
||||
#' set.seed(123)
|
||||
#' is.na(x[sample(nrow(x) * 4, 30)]) <- TRUE # introduce some missing values
|
||||
#' mbst <- xgboost(data = x, label = as.numeric(iris$Species) - 1, nrounds = nrounds,
|
||||
#' max_depth = 2, eta = 0.3, subsample = .5, nthread = nthread,
|
||||
#' objective = "multi:softprob", num_class = nclass, verbose = 0)
|
||||
#' trees0 <- seq(from=0, by=nclass, length.out=nrounds)
|
||||
#'
|
||||
#' mbst <- xgboost(
|
||||
#' data = x,
|
||||
#' label = as.numeric(iris$Species) - 1,
|
||||
#' nrounds = nrounds,
|
||||
#' max_depth = 2,
|
||||
#' eta = 0.3,
|
||||
#' subsample = 0.5,
|
||||
#' nthread = nthread,
|
||||
#' objective = "multi:softprob",
|
||||
#' num_class = nclass,
|
||||
#' verbose = 0
|
||||
#' )
|
||||
#' trees0 <- seq(from = 0, by = nclass, length.out = nrounds)
|
||||
#' col <- rgb(0, 0, 1, 0.5)
|
||||
#' xgb.plot.shap(x, model = mbst, trees = trees0, target_class = 0, top_n = 4,
|
||||
#' n_col = 2, col = col, pch = 16, pch_NA = 17)
|
||||
#' xgb.plot.shap(x, model = mbst, trees = trees0 + 1, target_class = 1, top_n = 4,
|
||||
#' n_col = 2, col = col, pch = 16, pch_NA = 17)
|
||||
#' xgb.plot.shap(x, model = mbst, trees = trees0 + 2, target_class = 2, top_n = 4,
|
||||
#' n_col = 2, col = col, pch = 16, pch_NA = 17)
|
||||
#' xgb.ggplot.shap.summary(x, model = mbst, target_class = 0, top_n = 4) # Summary plot
|
||||
#' xgb.plot.shap(
|
||||
#' x,
|
||||
#' model = mbst,
|
||||
#' trees = trees0,
|
||||
#' target_class = 0,
|
||||
#' top_n = 4,
|
||||
#' n_col = 2,
|
||||
#' col = col,
|
||||
#' pch = 16,
|
||||
#' pch_NA = 17
|
||||
#' )
|
||||
#'
|
||||
#' xgb.plot.shap(
|
||||
#' x,
|
||||
#' model = mbst,
|
||||
#' trees = trees0 + 1,
|
||||
#' target_class = 1,
|
||||
#' top_n = 4,
|
||||
#' n_col = 2,
|
||||
#' col = col,
|
||||
#' pch = 16,
|
||||
#' pch_NA = 17
|
||||
#' )
|
||||
#'
|
||||
#' xgb.plot.shap(
|
||||
#' x,
|
||||
#' model = mbst,
|
||||
#' trees = trees0 + 2,
|
||||
#' target_class = 2,
|
||||
#' top_n = 4,
|
||||
#' n_col = 2,
|
||||
#' col = col,
|
||||
#' pch = 16,
|
||||
#' pch_NA = 17
|
||||
#' )
|
||||
#'
|
||||
#' # Summary plot
|
||||
#' xgb.ggplot.shap.summary(x, model = mbst, target_class = 0, top_n = 4)
|
||||
#'
|
||||
#' @rdname xgb.plot.shap
|
||||
#' @export
|
||||
@@ -187,41 +242,48 @@ xgb.plot.shap <- function(data, shap_contrib = NULL, features = NULL, top_n = 1,
|
||||
invisible(list(data = data, shap_contrib = shap_contrib))
|
||||
}
|
||||
|
||||
#' SHAP contribution dependency summary plot
|
||||
#' SHAP summary plot
|
||||
#'
|
||||
#' Compare SHAP contributions of different features.
|
||||
#' Visualizes SHAP contributions of different features.
|
||||
#'
|
||||
#' A point plot (each point representing one sample from \code{data}) is
|
||||
#' A point plot (each point representing one observation from `data`) is
|
||||
#' produced for each feature, with the points plotted on the SHAP value axis.
|
||||
#' Each point (observation) is coloured based on its feature value. The plot
|
||||
#' hence allows us to see which features have a negative / positive contribution
|
||||
#' Each point (observation) is coloured based on its feature value.
|
||||
#'
|
||||
#' The plot allows to see which features have a negative / positive contribution
|
||||
#' on the model prediction, and whether the contribution is different for larger
|
||||
#' or smaller values of the feature. We effectively try to replicate the
|
||||
#' \code{summary_plot} function from <https://github.com/shap/shap>.
|
||||
#' or smaller values of the feature. Inspired by the summary plot of
|
||||
#' <https://github.com/shap/shap>.
|
||||
#'
|
||||
#' @inheritParams xgb.plot.shap
|
||||
#'
|
||||
#' @return A \code{ggplot2} object.
|
||||
#' @return A `ggplot2` object.
|
||||
#' @export
|
||||
#'
|
||||
#' @examples # See \code{\link{xgb.plot.shap}}.
|
||||
#' @seealso \code{\link{xgb.plot.shap}}, \code{\link{xgb.ggplot.shap.summary}},
|
||||
#' \url{https://github.com/shap/shap}
|
||||
#' @examples
|
||||
#' # See examples in xgb.plot.shap()
|
||||
#'
|
||||
#' @seealso [xgb.plot.shap()], [xgb.ggplot.shap.summary()],
|
||||
#' and the Python library <https://github.com/shap/shap>.
|
||||
xgb.plot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, top_n = 10, model = NULL,
|
||||
trees = NULL, target_class = NULL, approxcontrib = FALSE, subsample = NULL) {
|
||||
# Only ggplot implementation is available.
|
||||
xgb.ggplot.shap.summary(data, shap_contrib, features, top_n, model, trees, target_class, approxcontrib, subsample)
|
||||
}
|
||||
|
||||
#' Prepare data for SHAP plots. To be used in xgb.plot.shap, xgb.plot.shap.summary, etc.
|
||||
#' Internal utility function.
|
||||
#' Prepare data for SHAP plots
|
||||
#'
|
||||
#' Internal function used in [xgb.plot.shap()], [xgb.plot.shap.summary()], etc.
|
||||
#'
|
||||
#' @inheritParams xgb.plot.shap
|
||||
#' @param max_observations Maximum number of observations to consider.
|
||||
#' @keywords internal
|
||||
#' @noRd
|
||||
#'
|
||||
#' @return A list containing: 'data', a matrix containing sample observations
|
||||
#' and their feature values; 'shap_contrib', a matrix containing the SHAP contribution
|
||||
#' values for these observations.
|
||||
#' @return
|
||||
#' A list containing:
|
||||
#' - `data`: The matrix of feature values.
|
||||
#' - `shap_contrib`: The matrix with corresponding SHAP values.
|
||||
xgb.shap.data <- function(data, shap_contrib = NULL, features = NULL, top_n = 1, model = NULL,
|
||||
trees = NULL, target_class = NULL, approxcontrib = FALSE,
|
||||
subsample = NULL, max_observations = 100000) {
|
||||
|
||||
@@ -1,69 +1,78 @@
|
||||
#' Plot a boosted tree model
|
||||
#' Plot boosted trees
|
||||
#'
|
||||
#' Read a tree model text dump and plot the model.
|
||||
#'
|
||||
#' @param feature_names names of each feature as a \code{character} vector.
|
||||
#' @param model produced by the \code{xgb.train} function.
|
||||
#' @param trees an integer vector of tree indices that should be visualized.
|
||||
#' If set to \code{NULL}, all trees of the model are included.
|
||||
#' IMPORTANT: the tree index in xgboost model is zero-based
|
||||
#' (e.g., use \code{trees = 0:2} for the first 3 trees in a model).
|
||||
#' @param plot_width the width of the diagram in pixels.
|
||||
#' @param plot_height the height of the diagram in pixels.
|
||||
#' @param render a logical flag for whether the graph should be rendered (see Value).
|
||||
#' @param feature_names Character vector used to overwrite the feature names
|
||||
#' of the model. The default (`NULL`) uses the original feature names.
|
||||
#' @param model Object of class `xgb.Booster`.
|
||||
#' @param trees An integer vector of tree indices that should be used.
|
||||
#' The default (`NULL`) uses all trees.
|
||||
#' Useful, e.g., in multiclass classification to get only
|
||||
#' the trees of one class. *Important*: the tree index in XGBoost models
|
||||
#' is zero-based (e.g., use `trees = 0:2` for the first three trees).
|
||||
#' @param plot_width,plot_height Width and height of the graph in pixels.
|
||||
#' The values are passed to [DiagrammeR::render_graph()].
|
||||
#' @param render Should the graph be rendered or not? The default is `TRUE`.
|
||||
#' @param show_node_id a logical flag for whether to show node id's in the graph.
|
||||
#' @param ... currently not used.
|
||||
#'
|
||||
#' @details
|
||||
#'
|
||||
#' The content of each node is organised that way:
|
||||
#'
|
||||
#' \itemize{
|
||||
#' \item Feature name.
|
||||
#' \item \code{Cover}: The sum of second order gradient of training data classified to the leaf.
|
||||
#' If it is square loss, this simply corresponds to the number of instances seen by a split
|
||||
#' or collected by a leaf during training.
|
||||
#' The deeper in the tree a node is, the lower this metric will be.
|
||||
#' \item \code{Gain} (for split nodes): the information gain metric of a split
|
||||
#' The content of each node is visualized like this:
|
||||
#' - *Feature name*.
|
||||
#' - *Cover:* The sum of second order gradients of training data.
|
||||
#' For the squared loss, this simply corresponds to the number of instances in the node.
|
||||
#' The deeper in the tree, the lower the value.
|
||||
#' - *Gain* (for split nodes): Information gain metric of a split
|
||||
#' (corresponds to the importance of the node in the model).
|
||||
#' \item \code{Value} (for leafs): the margin value that the leaf may contribute to prediction.
|
||||
#' }
|
||||
#' The tree root nodes also indicate the Tree index (0-based).
|
||||
#' - *Value* (for leaves): Margin value that the leaf may contribute to the prediction.
|
||||
#'
|
||||
#' The tree root nodes also indicate the tree index (0-based).
|
||||
#'
|
||||
#' The "Yes" branches are marked by the "< split_value" label.
|
||||
#' The branches that also used for missing values are marked as bold
|
||||
#' The branches also used for missing values are marked as bold
|
||||
#' (as in "carrying extra capacity").
|
||||
#'
|
||||
#' This function uses \href{https://www.graphviz.org/}{GraphViz} as a backend of DiagrammeR.
|
||||
#' This function uses [GraphViz](https://www.graphviz.org/) as DiagrammeR backend.
|
||||
#'
|
||||
#' @return
|
||||
#'
|
||||
#' When \code{render = TRUE}:
|
||||
#' returns a rendered graph object which is an \code{htmlwidget} of class \code{grViz}.
|
||||
#' Similar to ggplot objects, it needs to be printed to see it when not running from command line.
|
||||
#'
|
||||
#' When \code{render = FALSE}:
|
||||
#' silently returns a graph object which is of DiagrammeR's class \code{dgr_graph}.
|
||||
#' This could be useful if one wants to modify some of the graph attributes
|
||||
#' before rendering the graph with \code{\link[DiagrammeR]{render_graph}}.
|
||||
#' The value depends on the `render` parameter:
|
||||
#' - If `render = TRUE` (default): Rendered graph object which is an htmlwidget of
|
||||
#' class `grViz`. Similar to "ggplot" objects, it needs to be printed when not
|
||||
#' running from the command line.
|
||||
#' - If `render = FALSE`: Graph object which is of DiagrammeR's class `dgr_graph`.
|
||||
#' This could be useful if one wants to modify some of the graph attributes
|
||||
#' before rendering the graph with [DiagrammeR::render_graph()].
|
||||
#'
|
||||
#' @examples
|
||||
#' data(agaricus.train, package='xgboost')
|
||||
#' data(agaricus.train, package = "xgboost")
|
||||
#'
|
||||
#' bst <- xgboost(
|
||||
#' data = agaricus.train$data,
|
||||
#' label = agaricus.train$label,
|
||||
#' max_depth = 3,
|
||||
#' eta = 1,
|
||||
#' nthread = 2,
|
||||
#' nrounds = 2,
|
||||
#' objective = "binary:logistic"
|
||||
#' )
|
||||
#'
|
||||
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 3,
|
||||
#' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
|
||||
#' # plot all the trees
|
||||
#' xgb.plot.tree(model = bst)
|
||||
#'
|
||||
#' # plot only the first tree and display the node ID:
|
||||
#' xgb.plot.tree(model = bst, trees = 0, show_node_id = TRUE)
|
||||
#'
|
||||
#' \dontrun{
|
||||
#' # Below is an example of how to save this plot to a file.
|
||||
#' # Note that for `export_graph` to work, the DiagrammeRsvg and rsvg packages must also be installed.
|
||||
#' # Note that for export_graph() to work, the {DiagrammeRsvg}
|
||||
#' # and {rsvg} packages must also be installed.
|
||||
#'
|
||||
#' library(DiagrammeR)
|
||||
#' gr <- xgb.plot.tree(model=bst, trees=0:1, render=FALSE)
|
||||
#' export_graph(gr, 'tree.pdf', width=1500, height=1900)
|
||||
#' export_graph(gr, 'tree.png', width=1500, height=1900)
|
||||
#'
|
||||
#' gr <- xgb.plot.tree(model = bst, trees = 0:1, render = FALSE)
|
||||
#' export_graph(gr, "tree.pdf", width = 1500, height = 1900)
|
||||
#' export_graph(gr, "tree.png", width = 1500, height = 1900)
|
||||
#' }
|
||||
#'
|
||||
#' @export
|
||||
|
||||
Reference in New Issue
Block a user