diff --git a/R-package/R/xgb.model.dt.tree.R b/R-package/R/xgb.model.dt.tree.R index df0ce54dc..5411c35d2 100644 --- a/R-package/R/xgb.model.dt.tree.R +++ b/R-package/R/xgb.model.dt.tree.R @@ -87,7 +87,7 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL, } if (length(text) < 2 || - sum(grepl('yes=(\\d+),no=(\\d+)', text)) < 1) { + sum(grepl('leaf=(\\d+)', text)) < 1) { stop("Non-tree model detected! This function can only be used with tree models.") } @@ -116,16 +116,28 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL, branch_rx <- paste0("f(\\d+)<(", anynumber_regex, ")\\] yes=(\\d+),no=(\\d+),missing=(\\d+),", "gain=(", anynumber_regex, "),cover=(", anynumber_regex, ")") branch_cols <- c("Feature", "Split", "Yes", "No", "Missing", "Quality", "Cover") - td[isLeaf == FALSE, - (branch_cols) := { - matches <- regmatches(t, regexec(branch_rx, t)) - # skip some indices with spurious capture groups from anynumber_regex - xtr <- do.call(rbind, matches)[, c(2, 3, 5, 6, 7, 8, 10), drop = FALSE] - xtr[, 3:5] <- add.tree.id(xtr[, 3:5], Tree) - as.data.table(xtr) - }] + td[ + isLeaf == FALSE, + (branch_cols) := { + matches <- regmatches(t, regexec(branch_rx, t)) + # skip some indices with spurious capture groups from anynumber_regex + xtr <- do.call(rbind, matches)[, c(2, 3, 5, 6, 7, 8, 10), drop = FALSE] + xtr[, 3:5] <- add.tree.id(xtr[, 3:5], Tree) + if (length(xtr) == 0) { + as.data.table( + list(Feature = "NA", Split = "NA", Yes = "NA", No = "NA", Missing = "NA", Quality = "NA", Cover = "NA") + ) + } else { + as.data.table(xtr) + } + } + ] + # assign feature_names when available - if (!is.null(feature_names)) { + is_stump <- function() { + return(length(td$Feature) == 1 && is.na(td$Feature)) + } + if (!is.null(feature_names) && !is_stump()) { if (length(feature_names) <= max(as.numeric(td$Feature), na.rm = TRUE)) stop("feature_names has less elements than there are features used in the model") td[isLeaf == FALSE, Feature := feature_names[as.numeric(Feature) + 1]] @@ -134,12 +146,18 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL, # parse leaf lines leaf_rx <- paste0("leaf=(", anynumber_regex, "),cover=(", anynumber_regex, ")") leaf_cols <- c("Feature", "Quality", "Cover") - td[isLeaf == TRUE, - (leaf_cols) := { - matches <- regmatches(t, regexec(leaf_rx, t)) - xtr <- do.call(rbind, matches)[, c(2, 4)] - c("Leaf", as.data.table(xtr)) - }] + td[ + isLeaf == TRUE, + (leaf_cols) := { + matches <- regmatches(t, regexec(leaf_rx, t)) + xtr <- do.call(rbind, matches)[, c(2, 4)] + if (length(xtr) == 2) { + c("Leaf", as.data.table(xtr[1]), as.data.table(xtr[2])) + } else { + c("Leaf", as.data.table(xtr)) + } + } + ] # convert some columns to numeric numeric_cols <- c("Split", "Quality", "Cover") diff --git a/R-package/R/xgb.plot.tree.R b/R-package/R/xgb.plot.tree.R index 71b9f08a5..dc2656170 100644 --- a/R-package/R/xgb.plot.tree.R +++ b/R-package/R/xgb.plot.tree.R @@ -98,18 +98,22 @@ xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot data = dt$Feature, fontcolor = "black") - edges <- DiagrammeR::create_edge_df( - from = match(rep(dt[Feature != "Leaf", c(ID)], 2), dt$ID), - to = match(dt[Feature != "Leaf", c(Yes, No)], dt$ID), - label = c( - dt[Feature != "Leaf", paste("<", Split)], - rep("", nrow(dt[Feature != "Leaf"])) - ), - style = c( - dt[Feature != "Leaf", ifelse(Missing == Yes, "bold", "solid")], - dt[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")] - ), - rel = "leading_to") + if (nrow(dt[Feature != "Leaf"]) != 0) { + edges <- DiagrammeR::create_edge_df( + from = match(rep(dt[Feature != "Leaf", c(ID)], 2), dt$ID), + to = match(dt[Feature != "Leaf", c(Yes, No)], dt$ID), + label = c( + dt[Feature != "Leaf", paste("<", Split)], + rep("", nrow(dt[Feature != "Leaf"])) + ), + style = c( + dt[Feature != "Leaf", ifelse(Missing == Yes, "bold", "solid")], + dt[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")] + ), + rel = "leading_to") + } else { + edges <- NULL + } graph <- DiagrammeR::create_graph( nodes_df = nodes, diff --git a/R-package/tests/testthat/test_helpers.R b/R-package/tests/testthat/test_helpers.R index c426c8377..fdd0ce02b 100644 --- a/R-package/tests/testthat/test_helpers.R +++ b/R-package/tests/testthat/test_helpers.R @@ -340,6 +340,16 @@ test_that("xgb.importance works with and without feature names", { imp } expect_equal(importance_from_dump(), importance, tolerance = 1e-6) + + ## decision stump + m <- xgboost::xgboost( + data = as.matrix(data.frame(x = c(0, 1))), + label = c(1, 2), + nrounds = 1 + ) + df <- xgb.model.dt.tree(model = m) + expect_equal(df$Feature, "Leaf") + expect_equal(df$Cover, 2) }) test_that("xgb.importance works with GLM model", {