[R] Fix parsing decision stump. (#7689)
This commit is contained in:
parent
e78a38b837
commit
da351621a1
@ -87,7 +87,7 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (length(text) < 2 ||
|
if (length(text) < 2 ||
|
||||||
sum(grepl('yes=(\\d+),no=(\\d+)', text)) < 1) {
|
sum(grepl('leaf=(\\d+)', text)) < 1) {
|
||||||
stop("Non-tree model detected! This function can only be used with tree models.")
|
stop("Non-tree model detected! This function can only be used with tree models.")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -116,16 +116,28 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
|
|||||||
branch_rx <- paste0("f(\\d+)<(", anynumber_regex, ")\\] yes=(\\d+),no=(\\d+),missing=(\\d+),",
|
branch_rx <- paste0("f(\\d+)<(", anynumber_regex, ")\\] yes=(\\d+),no=(\\d+),missing=(\\d+),",
|
||||||
"gain=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
|
"gain=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
|
||||||
branch_cols <- c("Feature", "Split", "Yes", "No", "Missing", "Quality", "Cover")
|
branch_cols <- c("Feature", "Split", "Yes", "No", "Missing", "Quality", "Cover")
|
||||||
td[isLeaf == FALSE,
|
td[
|
||||||
|
isLeaf == FALSE,
|
||||||
(branch_cols) := {
|
(branch_cols) := {
|
||||||
matches <- regmatches(t, regexec(branch_rx, t))
|
matches <- regmatches(t, regexec(branch_rx, t))
|
||||||
# skip some indices with spurious capture groups from anynumber_regex
|
# skip some indices with spurious capture groups from anynumber_regex
|
||||||
xtr <- do.call(rbind, matches)[, c(2, 3, 5, 6, 7, 8, 10), drop = FALSE]
|
xtr <- do.call(rbind, matches)[, c(2, 3, 5, 6, 7, 8, 10), drop = FALSE]
|
||||||
xtr[, 3:5] <- add.tree.id(xtr[, 3:5], Tree)
|
xtr[, 3:5] <- add.tree.id(xtr[, 3:5], Tree)
|
||||||
|
if (length(xtr) == 0) {
|
||||||
|
as.data.table(
|
||||||
|
list(Feature = "NA", Split = "NA", Yes = "NA", No = "NA", Missing = "NA", Quality = "NA", Cover = "NA")
|
||||||
|
)
|
||||||
|
} else {
|
||||||
as.data.table(xtr)
|
as.data.table(xtr)
|
||||||
}]
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
# assign feature_names when available
|
# assign feature_names when available
|
||||||
if (!is.null(feature_names)) {
|
is_stump <- function() {
|
||||||
|
return(length(td$Feature) == 1 && is.na(td$Feature))
|
||||||
|
}
|
||||||
|
if (!is.null(feature_names) && !is_stump()) {
|
||||||
if (length(feature_names) <= max(as.numeric(td$Feature), na.rm = TRUE))
|
if (length(feature_names) <= max(as.numeric(td$Feature), na.rm = TRUE))
|
||||||
stop("feature_names has less elements than there are features used in the model")
|
stop("feature_names has less elements than there are features used in the model")
|
||||||
td[isLeaf == FALSE, Feature := feature_names[as.numeric(Feature) + 1]]
|
td[isLeaf == FALSE, Feature := feature_names[as.numeric(Feature) + 1]]
|
||||||
@ -134,12 +146,18 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
|
|||||||
# parse leaf lines
|
# parse leaf lines
|
||||||
leaf_rx <- paste0("leaf=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
|
leaf_rx <- paste0("leaf=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
|
||||||
leaf_cols <- c("Feature", "Quality", "Cover")
|
leaf_cols <- c("Feature", "Quality", "Cover")
|
||||||
td[isLeaf == TRUE,
|
td[
|
||||||
|
isLeaf == TRUE,
|
||||||
(leaf_cols) := {
|
(leaf_cols) := {
|
||||||
matches <- regmatches(t, regexec(leaf_rx, t))
|
matches <- regmatches(t, regexec(leaf_rx, t))
|
||||||
xtr <- do.call(rbind, matches)[, c(2, 4)]
|
xtr <- do.call(rbind, matches)[, c(2, 4)]
|
||||||
|
if (length(xtr) == 2) {
|
||||||
|
c("Leaf", as.data.table(xtr[1]), as.data.table(xtr[2]))
|
||||||
|
} else {
|
||||||
c("Leaf", as.data.table(xtr))
|
c("Leaf", as.data.table(xtr))
|
||||||
}]
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
# convert some columns to numeric
|
# convert some columns to numeric
|
||||||
numeric_cols <- c("Split", "Quality", "Cover")
|
numeric_cols <- c("Split", "Quality", "Cover")
|
||||||
|
|||||||
@ -98,6 +98,7 @@ xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot
|
|||||||
data = dt$Feature,
|
data = dt$Feature,
|
||||||
fontcolor = "black")
|
fontcolor = "black")
|
||||||
|
|
||||||
|
if (nrow(dt[Feature != "Leaf"]) != 0) {
|
||||||
edges <- DiagrammeR::create_edge_df(
|
edges <- DiagrammeR::create_edge_df(
|
||||||
from = match(rep(dt[Feature != "Leaf", c(ID)], 2), dt$ID),
|
from = match(rep(dt[Feature != "Leaf", c(ID)], 2), dt$ID),
|
||||||
to = match(dt[Feature != "Leaf", c(Yes, No)], dt$ID),
|
to = match(dt[Feature != "Leaf", c(Yes, No)], dt$ID),
|
||||||
@ -110,6 +111,9 @@ xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot
|
|||||||
dt[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")]
|
dt[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")]
|
||||||
),
|
),
|
||||||
rel = "leading_to")
|
rel = "leading_to")
|
||||||
|
} else {
|
||||||
|
edges <- NULL
|
||||||
|
}
|
||||||
|
|
||||||
graph <- DiagrammeR::create_graph(
|
graph <- DiagrammeR::create_graph(
|
||||||
nodes_df = nodes,
|
nodes_df = nodes,
|
||||||
|
|||||||
@ -340,6 +340,16 @@ test_that("xgb.importance works with and without feature names", {
|
|||||||
imp
|
imp
|
||||||
}
|
}
|
||||||
expect_equal(importance_from_dump(), importance, tolerance = 1e-6)
|
expect_equal(importance_from_dump(), importance, tolerance = 1e-6)
|
||||||
|
|
||||||
|
## decision stump
|
||||||
|
m <- xgboost::xgboost(
|
||||||
|
data = as.matrix(data.frame(x = c(0, 1))),
|
||||||
|
label = c(1, 2),
|
||||||
|
nrounds = 1
|
||||||
|
)
|
||||||
|
df <- xgb.model.dt.tree(model = m)
|
||||||
|
expect_equal(df$Feature, "Leaf")
|
||||||
|
expect_equal(df$Cover, 2)
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("xgb.importance works with GLM model", {
|
test_that("xgb.importance works with GLM model", {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user