[R] xgb.plot.tree fixes (#1939)

* [R] a few fixes and improvements to xgb.plot.tree

* [R] deprecate n_first_tree replace with trees; fix types in xgb.model.dt.tree
This commit is contained in:
Vadim Khotilovich
2017-01-06 13:09:51 -06:00
committed by Tianqi Chen
parent d23ea5ca7d
commit d7406e07f3
7 changed files with 225 additions and 116 deletions

View File

@@ -7,8 +7,12 @@
#' @param model object of class \code{xgb.Booster}
#' @param text \code{character} vector previously generated by the \code{xgb.dump}
#' function (where parameter \code{with_stats = TRUE} should have been set).
#' @param n_first_tree limit the parsing to the \code{n} first trees.
#' @param trees an integer vector of tree indices that should be parsed.
#' If set to \code{NULL}, all trees of the model are parsed.
#' This can be useful, e.g., in multiclass classification to get only
#' the trees of one particular class. IMPORTANT: tree indices in an xgboost model
#' are zero-based (e.g., use \code{trees = 0:4} for the first 5 trees).
#' @param ... currently not used.
#'
#' @return
#' A \code{data.table} with detailed information about model trees' nodes.
@@ -16,9 +20,9 @@
#' The columns of the \code{data.table} are:
#'
#' \itemize{
#' \item \code{Tree}: ID of a tree in a model
#' \item \code{Node}: ID of a node in a tree
#' \item \code{ID}: unique identifier of a node in a model
#' \item \code{Tree}: ID of a tree in a model (integer)
#' \item \code{Node}: ID of a node in a tree (integer)
#' \item \code{ID}: identifier of a node in a model (character)
#' \item \code{Feature}: for a branch node, it's a feature id or name (when available);
#' for a leaf node, it simply labels it as \code{'Leaf'}
#' \item \code{Split}: location of the split for a branch node (split condition is always "less than")
@@ -47,8 +51,8 @@
#'
#' @export
xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
n_first_tree = NULL){
trees = NULL, ...){
check.deprecation(...)
if (!class(feature_names) %in% c("character", "NULL")) {
stop("feature_names: Has to be a vector of character\n",
" or NULL if the model dump already contains feature names.\n",
@@ -61,8 +65,8 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
" (or NULL if the model was provided).")
}
if (!class(n_first_tree) %in% c("numeric", "NULL") | length(n_first_tree) > 1) {
stop("n_first_tree: Has to be a numeric vector of size 1.")
if (!class(trees) %in% c("integer", "numeric", "NULL")) {
stop("trees: Has to be a vector of integers.")
}
if (is.null(text)){
@@ -84,10 +88,14 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
td[position, Tree := 1L]
td[, Tree := cumsum(ifelse(is.na(Tree), 0L, Tree)) - 1L]
n_first_tree <- min(max(td$Tree), n_first_tree)
td <- td[Tree <= n_first_tree & !grepl('^booster', t)]
if (is.null(trees)) {
trees <- 0:max(td$Tree)
} else {
trees <- trees[trees >= 0 & trees <= max(td$Tree)]
}
td <- td[Tree %in% trees & !grepl('^booster', t)]
td[, Node := stri_match_first_regex(t, "(\\d+):")[,2] %>% as.numeric ]
td[, Node := stri_match_first_regex(t, "(\\d+):")[,2] %>% as.integer ]
td[, ID := add.tree.id(Node, Tree)]
td[, isLeaf := !is.na(stri_match_first_regex(t, "leaf"))]
@@ -112,7 +120,7 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
}]
# convert some columns to numeric
numeric_cols <- c("Quality", "Cover")
numeric_cols <- c("Split", "Quality", "Cover")
td[, (numeric_cols) := lapply(.SD, as.numeric), .SDcols=numeric_cols]
td[, t := NULL]