[R] Fix parsing decision stump. (#7689)

This commit is contained in:
Jiaming Yuan 2022-03-17 01:08:22 +08:00 committed by GitHub
parent e78a38b837
commit da351621a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 28 deletions

View File

@ -87,7 +87,7 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
}
if (length(text) < 2 ||
sum(grepl('yes=(\\d+),no=(\\d+)', text)) < 1) {
sum(grepl('leaf=(\\d+)', text)) < 1) {
stop("Non-tree model detected! This function can only be used with tree models.")
}
@ -116,16 +116,28 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
branch_rx <- paste0("f(\\d+)<(", anynumber_regex, ")\\] yes=(\\d+),no=(\\d+),missing=(\\d+),",
"gain=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
branch_cols <- c("Feature", "Split", "Yes", "No", "Missing", "Quality", "Cover")
td[isLeaf == FALSE,
td[
isLeaf == FALSE,
(branch_cols) := {
matches <- regmatches(t, regexec(branch_rx, t))
# skip some indices with spurious capture groups from anynumber_regex
xtr <- do.call(rbind, matches)[, c(2, 3, 5, 6, 7, 8, 10), drop = FALSE]
xtr[, 3:5] <- add.tree.id(xtr[, 3:5], Tree)
if (length(xtr) == 0) {
as.data.table(
list(Feature = "NA", Split = "NA", Yes = "NA", No = "NA", Missing = "NA", Quality = "NA", Cover = "NA")
)
} else {
as.data.table(xtr)
}]
}
}
]
# assign feature_names when available
if (!is.null(feature_names)) {
is_stump <- function() {
return(length(td$Feature) == 1 && is.na(td$Feature))
}
if (!is.null(feature_names) && !is_stump()) {
if (length(feature_names) <= max(as.numeric(td$Feature), na.rm = TRUE))
stop("feature_names has less elements than there are features used in the model")
td[isLeaf == FALSE, Feature := feature_names[as.numeric(Feature) + 1]]
@ -134,12 +146,18 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
# parse leaf lines
leaf_rx <- paste0("leaf=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
leaf_cols <- c("Feature", "Quality", "Cover")
td[isLeaf == TRUE,
td[
isLeaf == TRUE,
(leaf_cols) := {
matches <- regmatches(t, regexec(leaf_rx, t))
xtr <- do.call(rbind, matches)[, c(2, 4)]
if (length(xtr) == 2) {
c("Leaf", as.data.table(xtr[1]), as.data.table(xtr[2]))
} else {
c("Leaf", as.data.table(xtr))
}]
}
}
]
# convert some columns to numeric
numeric_cols <- c("Split", "Quality", "Cover")

View File

@ -98,6 +98,7 @@ xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot
data = dt$Feature,
fontcolor = "black")
if (nrow(dt[Feature != "Leaf"]) != 0) {
edges <- DiagrammeR::create_edge_df(
from = match(rep(dt[Feature != "Leaf", c(ID)], 2), dt$ID),
to = match(dt[Feature != "Leaf", c(Yes, No)], dt$ID),
@ -110,6 +111,9 @@ xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot
dt[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")]
),
rel = "leading_to")
} else {
edges <- NULL
}
graph <- DiagrammeR::create_graph(
nodes_df = nodes,

View File

@ -340,6 +340,16 @@ test_that("xgb.importance works with and without feature names", {
imp
}
expect_equal(importance_from_dump(), importance, tolerance = 1e-6)
## decision stump
m <- xgboost::xgboost(
data = as.matrix(data.frame(x = c(0, 1))),
label = c(1, 2),
nrounds = 1
)
df <- xgb.model.dt.tree(model = m)
expect_equal(df$Feature, "Leaf")
expect_equal(df$Cover, 2)
})
test_that("xgb.importance works with GLM model", {