[R] Remove unusable 'feature_names' argument and make 'model' first argument in inspection functions (#9939)

This commit is contained in:
david-cortes
2024-01-15 10:16:30 +01:00
committed by GitHub
parent 1168a68872
commit 547abb8c12
10 changed files with 30 additions and 74 deletions

View File

@@ -113,19 +113,12 @@
#' xgb.importance(model = mbst)
#'
#' @export
xgb.importance <- function(feature_names = NULL, model = NULL, trees = NULL,
xgb.importance <- function(model = NULL, feature_names = getinfo(model, "feature_name"), trees = NULL,
data = NULL, label = NULL, target = NULL) {
if (!(is.null(data) && is.null(label) && is.null(target)))
warning("xgb.importance: parameters 'data', 'label' and 'target' are deprecated")
if (is.null(feature_names)) {
model_feature_names <- xgb.feature_names(model)
if (NROW(model_feature_names)) {
feature_names <- model_feature_names
}
}
if (!(is.null(feature_names) || is.character(feature_names)))
stop("feature_names: Has to be a character vector")

View File

@@ -2,11 +2,8 @@
#'
#' Parse a boosted tree model text dump into a `data.table` structure.
#'
#' @param feature_names Character vector of feature names. If the model already
#' contains feature names, those will be used when \code{feature_names=NULL} (default value).
#'
#' Note that, if the model already contains feature names, it's \bold{not} possible to override them here.
#' @param model Object of class `xgb.Booster`.
#' @param model Object of class `xgb.Booster`. If it contains feature names (they can be set through
#' \link{setinfo}), they will be used in the output from this function.
#' @param text Character vector previously generated by the function [xgb.dump()]
#' (called with parameter `with_stats = TRUE`). `text` takes precedence over `model`.
#' @param trees An integer vector of tree indices that should be used.
@@ -58,7 +55,7 @@
#'
#' # This bst model already has feature_names stored with it, so those would be used when
#' # feature_names is not set:
#' (dt <- xgb.model.dt.tree(model = bst))
#' dt <- xgb.model.dt.tree(bst)
#'
#' # How to match feature names of splits that are following a current 'Yes' branch:
#' merge(
@@ -69,7 +66,7 @@
#' ]
#'
#' @export
xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
xgb.model.dt.tree <- function(model = NULL, text = NULL,
trees = NULL, use_int_id = FALSE, ...) {
check.deprecation(...)
@@ -79,24 +76,15 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
" (or NULL if 'model' was provided).")
}
model_feature_names <- NULL
if (inherits(model, "xgb.Booster")) {
model_feature_names <- xgb.feature_names(model)
if (NROW(model_feature_names) && !is.null(feature_names)) {
stop("'model' contains feature names. Cannot override them.")
}
}
if (is.null(feature_names) && !is.null(model) && !is.null(model_feature_names))
feature_names <- model_feature_names
if (!(is.null(feature_names) || is.character(feature_names))) {
stop("feature_names: must be a character vector")
}
if (!(is.null(trees) || is.numeric(trees))) {
stop("trees: must be a vector of integers.")
}
feature_names <- NULL
if (inherits(model, "xgb.Booster")) {
feature_names <- xgb.feature_names(model)
}
from_text <- TRUE
if (is.null(text)) {
text <- xgb.dump(model = model, with_stats = TRUE)
@@ -134,7 +122,7 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
branch_rx_w_names <- paste0("\\d+:\\[(.+)<(", anynumber_regex, ")\\] yes=(\\d+),no=(\\d+),missing=(\\d+),",
"gain=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
text_has_feature_names <- FALSE
if (NROW(model_feature_names)) {
if (NROW(feature_names)) {
branch_rx <- branch_rx_w_names
text_has_feature_names <- TRUE
} else {
@@ -148,9 +136,6 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
}
}
}
if (text_has_feature_names && is.null(model) && !is.null(feature_names)) {
stop("'text' contains feature names. Cannot override them.")
}
branch_cols <- c("Feature", "Split", "Yes", "No", "Missing", "Gain", "Cover")
td[
isLeaf == FALSE,

View File

@@ -62,13 +62,13 @@
#' }
#'
#' @export
xgb.plot.multi.trees <- function(model, feature_names = NULL, features_keep = 5, plot_width = NULL, plot_height = NULL,
xgb.plot.multi.trees <- function(model, features_keep = 5, plot_width = NULL, plot_height = NULL,
render = TRUE, ...) {
if (!requireNamespace("DiagrammeR", quietly = TRUE)) {
stop("DiagrammeR is required for xgb.plot.multi.trees")
}
check.deprecation(...)
tree.matrix <- xgb.model.dt.tree(feature_names = feature_names, model = model)
tree.matrix <- xgb.model.dt.tree(model = model)
# first number of the path represents the tree, then the following numbers are related to the path to follow
# root init

View File

@@ -2,9 +2,8 @@
#'
#' Read a tree model text dump and plot the model.
#'
#' @param feature_names Character vector used to overwrite the feature names
#' of the model. The default (`NULL`) uses the original feature names.
#' @param model Object of class `xgb.Booster`.
#' @param model Object of class `xgb.Booster`. If it contains feature names (they can be set through
#' \link{setinfo}), they will be used in the output from this function.
#' @param trees An integer vector of tree indices that should be used.
#' The default (`NULL`) uses all trees.
#' Useful, e.g., in multiclass classification to get only
@@ -103,7 +102,7 @@
#' }
#'
#' @export
xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot_width = NULL, plot_height = NULL,
xgb.plot.tree <- function(model = NULL, trees = NULL, plot_width = NULL, plot_height = NULL,
render = TRUE, show_node_id = FALSE, style = c("R", "xgboost"), ...) {
check.deprecation(...)
if (!inherits(model, "xgb.Booster")) {
@@ -120,17 +119,12 @@ xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot
if (NROW(trees) != 1L || !render || show_node_id) {
stop("style='xgboost' is only supported for single, rendered tree, without node IDs.")
}
if (!is.null(feature_names)) {
stop(
"style='xgboost' cannot override 'feature_names'. Will automatically take them from the model."
)
}
txt <- xgb.dump(model, dump_format = "dot")
return(DiagrammeR::grViz(txt[[trees + 1]], width = plot_width, height = plot_height))
}
dt <- xgb.model.dt.tree(feature_names = feature_names, model = model, trees = trees)
dt <- xgb.model.dt.tree(model = model, trees = trees)
dt[, label := paste0(Feature, "\nCover: ", Cover, ifelse(Feature == "Leaf", "\nValue: ", "\nGain: "), Gain)]
if (show_node_id)