202 lines
7.6 KiB
R
202 lines
7.6 KiB
R
#' Parse model text dump
|
|
#'
|
|
#' Parse a boosted tree model text dump into a `data.table` structure.
|
|
#'
|
|
#' @param model Object of class `xgb.Booster`. If it contains feature names (they can be set through
|
|
#' \link{setinfo}), they will be used in the output from this function.
|
|
#' @param text Character vector previously generated by the function [xgb.dump()]
|
|
#' (called with parameter `with_stats = TRUE`). `text` takes precedence over `model`.
|
|
#' @param trees An integer vector of tree indices that should be used.
|
|
#' The default (`NULL`) uses all trees.
|
|
#' Useful, e.g., in multiclass classification to get only
|
|
#' the trees of one class. *Important*: the tree index in XGBoost models
|
|
#' is zero-based (e.g., use `trees = 0:4` for the first five trees).
|
|
#' @param use_int_id A logical flag indicating whether nodes in columns "Yes", "No", and
|
|
#' "Missing" should be represented as integers (when `TRUE`) or as "Tree-Node"
|
|
#' character strings (when `FALSE`, default).
|
|
#' @param ... Currently not used.
|
|
#'
|
|
#' @return
|
|
#' A `data.table` with detailed information about tree nodes. It has the following columns:
|
|
#' - `Tree`: integer ID of a tree in a model (zero-based index).
|
|
#' - `Node`: integer ID of a node in a tree (zero-based index).
|
|
#' - `ID`: character identifier of a node in a model (only when `use_int_id = FALSE`).
|
|
#' - `Feature`: for a branch node, a feature ID or name (when available);
|
|
#' for a leaf node, it simply labels it as `"Leaf"`.
|
|
#' - `Split`: location of the split for a branch node (split condition is always "less than").
|
|
#' - `Yes`: ID of the next node when the split condition is met.
|
|
#' - `No`: ID of the next node when the split condition is not met.
|
|
#' - `Missing`: ID of the next node when the branch value is missing.
|
|
#' - `Gain`: either the split gain (change in loss) or the leaf value.
|
|
#' - `Cover`: metric related to the number of observations either seen by a split
|
|
#' or collected by a leaf during training.
|
|
#'
|
|
#' When `use_int_id = FALSE`, columns "Yes", "No", and "Missing" point to model-wide node identifiers
|
|
#' in the "ID" column. When `use_int_id = TRUE`, those columns point to node identifiers from
|
|
#' the corresponding trees in the "Node" column.
|
|
#'
|
|
#' @examples
|
|
#' # Basic use:
|
|
#'
|
|
#' data(agaricus.train, package = "xgboost")
|
|
#' ## Keep the number of threads to 1 for examples
|
|
#' nthread <- 1
|
|
#' data.table::setDTthreads(nthread)
|
|
#'
|
|
#' bst <- xgb.train(
|
|
#' data = xgb.DMatrix(agaricus.train$data, label = agaricus.train$label),
|
|
#' max_depth = 2,
|
|
#' eta = 1,
|
|
#' nthread = nthread,
|
|
#' nrounds = 2,
|
|
#' objective = "binary:logistic"
|
|
#' )
|
|
#'
|
|
#' # This bst model already has feature_names stored with it, so those would be used when
|
|
#' # feature_names is not set:
|
|
#' dt <- xgb.model.dt.tree(bst)
|
|
#'
|
|
#' # How to match feature names of splits that are following a current 'Yes' branch:
|
|
#' merge(
|
|
#' dt,
|
|
#' dt[, .(ID, Y.Feature = Feature)], by.x = "Yes", by.y = "ID", all.x = TRUE
|
|
#' )[
|
|
#' order(Tree, Node)
|
|
#' ]
|
|
#'
|
|
#' @export
|
|
xgb.model.dt.tree <- function(model = NULL, text = NULL,
|
|
trees = NULL, use_int_id = FALSE, ...) {
|
|
check.deprecation(...)
|
|
|
|
if (!inherits(model, "xgb.Booster") && !is.character(text)) {
|
|
stop("Either 'model' must be an object of class xgb.Booster\n",
|
|
" or 'text' must be a character vector with the result of xgb.dump\n",
|
|
" (or NULL if 'model' was provided).")
|
|
}
|
|
|
|
if (!(is.null(trees) || is.numeric(trees))) {
|
|
stop("trees: must be a vector of integers.")
|
|
}
|
|
|
|
feature_names <- NULL
|
|
if (inherits(model, "xgb.Booster")) {
|
|
feature_names <- xgb.feature_names(model)
|
|
}
|
|
|
|
from_text <- TRUE
|
|
if (is.null(text)) {
|
|
text <- xgb.dump(model = model, with_stats = TRUE)
|
|
from_text <- FALSE
|
|
}
|
|
|
|
if (length(text) < 2 || !any(grepl('leaf=(\\d+)', text))) {
|
|
stop("Non-tree model detected! This function can only be used with tree models.")
|
|
}
|
|
|
|
position <- which(grepl("booster", text, fixed = TRUE))
|
|
|
|
add.tree.id <- function(node, tree) if (use_int_id) node else paste(tree, node, sep = "-")
|
|
|
|
anynumber_regex <- "[-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?"
|
|
|
|
td <- data.table(t = text)
|
|
td[position, Tree := 1L]
|
|
td[, Tree := cumsum(ifelse(is.na(Tree), 0L, Tree)) - 1L]
|
|
|
|
if (is.null(trees)) {
|
|
trees <- 0:max(td$Tree)
|
|
} else {
|
|
trees <- trees[trees >= 0 & trees <= max(td$Tree)]
|
|
}
|
|
td <- td[Tree %in% trees & !grepl('^booster', t)]
|
|
|
|
td[, Node := as.integer(sub("^([0-9]+):.*", "\\1", t))]
|
|
if (!use_int_id) td[, ID := add.tree.id(Node, Tree)]
|
|
td[, isLeaf := grepl("leaf", t, fixed = TRUE)]
|
|
|
|
# parse branch lines
|
|
branch_rx_nonames <- paste0("f(\\d+)<(", anynumber_regex, ")\\] yes=(\\d+),no=(\\d+),missing=(\\d+),",
|
|
"gain=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
|
|
branch_rx_w_names <- paste0("\\d+:\\[(.+)<(", anynumber_regex, ")\\] yes=(\\d+),no=(\\d+),missing=(\\d+),",
|
|
"gain=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
|
|
text_has_feature_names <- FALSE
|
|
if (NROW(feature_names)) {
|
|
branch_rx <- branch_rx_w_names
|
|
text_has_feature_names <- TRUE
|
|
} else {
|
|
# Note: when passing a text dump, it might or might not have feature names,
|
|
# but that aspect is unknown from just the text attributes
|
|
branch_rx <- branch_rx_nonames
|
|
if (from_text) {
|
|
if (sum(grepl(branch_rx_w_names, text)) > sum(grepl(branch_rx_nonames, text))) {
|
|
branch_rx <- branch_rx_w_names
|
|
text_has_feature_names <- TRUE
|
|
}
|
|
}
|
|
}
|
|
branch_cols <- c("Feature", "Split", "Yes", "No", "Missing", "Gain", "Cover")
|
|
td[
|
|
isLeaf == FALSE,
|
|
(branch_cols) := {
|
|
matches <- regmatches(t, regexec(branch_rx, t))
|
|
# skip some indices with spurious capture groups from anynumber_regex
|
|
xtr <- do.call(rbind, matches)[, c(2, 3, 5, 6, 7, 8, 10), drop = FALSE]
|
|
xtr[, 3:5] <- add.tree.id(xtr[, 3:5], Tree)
|
|
if (length(xtr) == 0) {
|
|
as.data.table(
|
|
list(Feature = "NA", Split = "NA", Yes = "NA", No = "NA", Missing = "NA", Gain = "NA", Cover = "NA")
|
|
)
|
|
} else {
|
|
as.data.table(xtr)
|
|
}
|
|
}
|
|
]
|
|
|
|
# assign feature_names when available
|
|
is_stump <- function() {
|
|
return(length(td$Feature) == 1 && is.na(td$Feature))
|
|
}
|
|
if (!text_has_feature_names) {
|
|
if (!is.null(feature_names) && !is_stump()) {
|
|
if (length(feature_names) <= max(as.numeric(td$Feature), na.rm = TRUE))
|
|
stop("feature_names has less elements than there are features used in the model")
|
|
td[isLeaf == FALSE, Feature := feature_names[as.numeric(Feature) + 1]]
|
|
}
|
|
}
|
|
|
|
# parse leaf lines
|
|
leaf_rx <- paste0("leaf=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
|
|
leaf_cols <- c("Feature", "Gain", "Cover")
|
|
td[
|
|
isLeaf == TRUE,
|
|
(leaf_cols) := {
|
|
matches <- regmatches(t, regexec(leaf_rx, t))
|
|
xtr <- do.call(rbind, matches)[, c(2, 4)]
|
|
if (length(xtr) == 2) {
|
|
c("Leaf", as.data.table(xtr[1]), as.data.table(xtr[2]))
|
|
} else {
|
|
c("Leaf", as.data.table(xtr))
|
|
}
|
|
}
|
|
]
|
|
|
|
# convert some columns to numeric
|
|
numeric_cols <- c("Split", "Gain", "Cover")
|
|
td[, (numeric_cols) := lapply(.SD, as.numeric), .SDcols = numeric_cols]
|
|
if (use_int_id) {
|
|
int_cols <- c("Yes", "No", "Missing")
|
|
td[, (int_cols) := lapply(.SD, as.integer), .SDcols = int_cols]
|
|
}
|
|
|
|
td[, t := NULL]
|
|
td[, isLeaf := NULL]
|
|
|
|
td[order(Tree, Node)]
|
|
}
|
|
|
|
# Avoid error messages during CRAN check.
|
|
# The reason is that these variables are never declared
|
|
# They are mainly column names inferred by Data.table...
|
|
globalVariables(c("Tree", "Node", "ID", "Feature", "t", "isLeaf", ".SD", ".SDcols"))
|