[R-package] remove dependency on {magrittr} (#6928)

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
James Lamb
2021-05-12 15:34:59 -05:00
committed by GitHub
parent 44cc9c04ea
commit 894e9bc5d4
14 changed files with 131 additions and 66 deletions

View File

@@ -642,8 +642,13 @@ cb.gblinear.history <- function(sparse=FALSE) {
coefs <<- list2mat(coefs)
} else { # xgb.cv:
# first lapply transposes the list
coefs <<- lapply(seq_along(coefs[[1]]), function(i) lapply(coefs, "[[", i)) %>%
lapply(function(x) list2mat(x))
coefs <<- lapply(
X = lapply(
X = seq_along(coefs[[1]]),
FUN = function(i) lapply(coefs, "[[", i)
),
FUN = function(x) list2mat(x)
)
}
}

View File

@@ -372,8 +372,14 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
} else if (n_group == 1) {
matrix(ret, nrow = n_row, byrow = TRUE, dimnames = list(NULL, cnames))
} else {
arr <- array(ret, c(n_col1, n_group, n_row),
dimnames = list(cnames, NULL, NULL)) %>% aperm(c(2, 3, 1)) # [group, row, col]
arr <- aperm(
a = array(
data = ret,
dim = c(n_col1, n_group, n_row),
dimnames = list(cnames, NULL, NULL)
),
perm = c(2, 3, 1) # [group, row, col]
)
lapply(seq_len(n_group), function(g) arr[g, , ])
}
} else if (predinteraction) {
@@ -383,10 +389,23 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
ret <- if (n_ret == n_row) {
matrix(ret, ncol = 1, dimnames = list(NULL, cnames))
} else if (n_group == 1) {
array(ret, c(n_col1, n_col1, n_row), dimnames = list(cnames, cnames, NULL)) %>% aperm(c(3, 1, 2))
aperm(
a = array(
data = ret,
dim = c(n_col1, n_col1, n_row),
dimnames = list(cnames, cnames, NULL)
),
perm = c(3, 1, 2)
)
} else {
arr <- array(ret, c(n_col1, n_col1, n_group, n_row),
dimnames = list(cnames, cnames, NULL, NULL)) %>% aperm(c(3, 4, 1, 2)) # [group, row, col1, col2]
arr <- aperm(
a = array(
data = ret,
dim = c(n_col1, n_col1, n_group, n_row),
dimnames = list(cnames, cnames, NULL, NULL)
),
perm = c(3, 4, 1, 2) # [group, row, col1, col2]
)
lapply(seq_len(n_group), function(g) arr[g, , , ])
}
} else if (reshape && npred_per_case > 1) {

View File

@@ -100,9 +100,10 @@ xgb.importance <- function(feature_names = NULL, model = NULL, trees = NULL,
# linear model
if (model_text_dump[2] == "bias:"){
weights <- which(model_text_dump == "weight:") %>%
{model_text_dump[(. + 1):length(model_text_dump)]} %>%
as.numeric
weight_index <- which(model_text_dump == "weight:") + 1
weights <- as.numeric(
model_text_dump[weight_index:length(model_text_dump)]
)
num_class <- NVL(model$params$num_class, 1)
if (is.null(feature_names))

View File

@@ -75,8 +75,8 @@ xgb.plot.multi.trees <- function(model, feature_names = NULL, features_keep = 5,
while (tree.matrix[, sum(is.na(abs.node.position))] > 0) {
yes.row.nodes <- tree.matrix[abs.node.position %in% precedent.nodes & !is.na(Yes)]
no.row.nodes <- tree.matrix[abs.node.position %in% precedent.nodes & !is.na(No)]
yes.nodes.abs.pos <- yes.row.nodes[, abs.node.position] %>% paste0("_0")
no.nodes.abs.pos <- no.row.nodes[, abs.node.position] %>% paste0("_1")
yes.nodes.abs.pos <- paste0(yes.row.nodes[, abs.node.position], "_0")
no.nodes.abs.pos <- paste0(no.row.nodes[, abs.node.position], "_1")
tree.matrix[ID %in% yes.row.nodes[, Yes], abs.node.position := yes.nodes.abs.pos]
tree.matrix[ID %in% no.row.nodes[, No], abs.node.position := no.nodes.abs.pos]
@@ -92,19 +92,28 @@ xgb.plot.multi.trees <- function(model, feature_names = NULL, features_keep = 5,
nodes.dt <- tree.matrix[
, .(Quality = sum(Quality))
, by = .(abs.node.position, Feature)
][, .(Text = paste0(Feature[1:min(length(Feature), features_keep)],
" (",
format(Quality[1:min(length(Quality), features_keep)], digits = 5),
")") %>%
paste0(collapse = "\n"))
, by = abs.node.position]
][, .(Text = paste0(
paste0(
Feature[1:min(length(Feature), features_keep)],
" (",
format(Quality[1:min(length(Quality), features_keep)], digits = 5),
")"
),
collapse = "\n"
)
)
, by = abs.node.position
]
edges.dt <- tree.matrix[Feature != "Leaf", .(abs.node.position, Yes)] %>%
list(tree.matrix[Feature != "Leaf", .(abs.node.position, No)]) %>%
rbindlist() %>%
setnames(c("From", "To")) %>%
.[, .N, .(From, To)] %>%
.[, N := NULL]
edges.dt <- data.table::rbindlist(
l = list(
tree.matrix[Feature != "Leaf", .(abs.node.position, Yes)],
tree.matrix[Feature != "Leaf", .(abs.node.position, No)]
)
)
data.table::setnames(edges.dt, c("From", "To"))
edges.dt <- edges.dt[, .N, .(From, To)]
edges.dt[, N := NULL]
nodes <- DiagrammeR::create_node_df(
n = nrow(nodes.dt),
@@ -120,21 +129,25 @@ xgb.plot.multi.trees <- function(model, feature_names = NULL, features_keep = 5,
nodes_df = nodes,
edges_df = edges,
attr_theme = NULL
) %>%
DiagrammeR::add_global_graph_attrs(
)
graph <- DiagrammeR::add_global_graph_attrs(
graph = graph,
attr_type = "graph",
attr = c("layout", "rankdir"),
value = c("dot", "LR")
) %>%
DiagrammeR::add_global_graph_attrs(
)
graph <- DiagrammeR::add_global_graph_attrs(
graph = graph,
attr_type = "node",
attr = c("color", "fillcolor", "style", "shape", "fontname"),
value = c("DimGray", "beige", "filled", "rectangle", "Helvetica")
) %>%
DiagrammeR::add_global_graph_attrs(
)
graph <- DiagrammeR::add_global_graph_attrs(
graph = graph,
attr_type = "edge",
attr = c("color", "arrowsize", "arrowhead", "fontname"),
value = c("DimGray", "1.5", "vee", "Helvetica"))
value = c("DimGray", "1.5", "vee", "Helvetica")
)
if (!render) return(invisible(graph))

View File

@@ -99,33 +99,41 @@ xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot
fontcolor = "black")
edges <- DiagrammeR::create_edge_df(
from = match(dt[Feature != "Leaf", c(ID)] %>% rep(2), dt$ID),
from = match(rep(dt[Feature != "Leaf", c(ID)], 2), dt$ID),
to = match(dt[Feature != "Leaf", c(Yes, No)], dt$ID),
label = dt[Feature != "Leaf", paste("<", Split)] %>%
c(rep("", nrow(dt[Feature != "Leaf"]))),
style = dt[Feature != "Leaf", ifelse(Missing == Yes, "bold", "solid")] %>%
c(dt[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")]),
label = c(
dt[Feature != "Leaf", paste("<", Split)],
rep("", nrow(dt[Feature != "Leaf"]))
),
style = c(
dt[Feature != "Leaf", ifelse(Missing == Yes, "bold", "solid")],
dt[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")]
),
rel = "leading_to")
graph <- DiagrammeR::create_graph(
nodes_df = nodes,
edges_df = edges,
attr_theme = NULL
) %>%
DiagrammeR::add_global_graph_attrs(
)
graph <- DiagrammeR::add_global_graph_attrs(
graph = graph,
attr_type = "graph",
attr = c("layout", "rankdir"),
value = c("dot", "LR")
) %>%
DiagrammeR::add_global_graph_attrs(
)
graph <- DiagrammeR::add_global_graph_attrs(
graph = graph,
attr_type = "node",
attr = c("color", "style", "fontname"),
value = c("DimGray", "filled", "Helvetica")
) %>%
DiagrammeR::add_global_graph_attrs(
)
graph <- DiagrammeR::add_global_graph_attrs(
graph = graph,
attr_type = "edge",
attr = c("color", "arrowsize", "arrowhead", "fontname"),
value = c("DimGray", "1.5", "vee", "Helvetica"))
value = c("DimGray", "1.5", "vee", "Helvetica")
)
if (!render) return(invisible(graph))

View File

@@ -90,7 +90,6 @@ NULL
#' @importFrom data.table setkey
#' @importFrom data.table setkeyv
#' @importFrom data.table setnames
#' @importFrom magrittr %>%
#' @importFrom jsonlite fromJSON
#' @importFrom jsonlite toJSON
#' @importFrom utils object.size str tail