Add new tests for new functions
This commit is contained in:
parent
ad8766dfa4
commit
96c43cf197
@ -77,13 +77,22 @@ xgb.importance <- function(feature_names = NULL, model = NULL, data = NULL, labe
|
|||||||
if(sum(label == 0) / length(label) > 0.5) label <- as(label, "sparseVector")
|
if(sum(label == 0) / length(label) > 0.5) label <- as(label, "sparseVector")
|
||||||
}
|
}
|
||||||
|
|
||||||
text <- xgb.dump(model = model, with.stats = T)
|
treeDump <- function(feature_names, text, keepDetail){
|
||||||
|
if(keepDetail) groupBy <- c("Feature", "Split", "MissingNo") else groupBy <- "Feature"
|
||||||
|
xgb.model.dt.tree(feature_names = feature_names, text = text)[,"MissingNo" := Missing == No ][Feature != "Leaf",.(Gain = sum(Quality), Cover = sum(Cover), Frequency = .N), by = groupBy, with = T][,`:=`(Gain = Gain / sum(Gain), Cover = Cover / sum(Cover), Frequency = Frequency / sum(Frequency))][order(Gain, decreasing = T)]
|
||||||
|
}
|
||||||
|
|
||||||
if(text[2] == "bias:"){
|
linearDump <- function(feature_names, text){
|
||||||
result <- readLines(filename_dump) %>% linearDump(feature_names, .)
|
which(text == "weight:") %>% {a =. + 1; text[a:length(text)]} %>% as.numeric %>% data.table(Feature = feature_names, Weight = .)
|
||||||
|
}
|
||||||
|
|
||||||
|
model.text.dump <- xgb.dump(model = model, with.stats = T)
|
||||||
|
|
||||||
|
if(model.text.dump[2] == "bias:"){
|
||||||
|
result <- model.text.dump %>% linearDump(feature_names, .)
|
||||||
if(!is.null(data) | !is.null(label)) warning("data/label: these parameters should only be provided with decision tree based models.")
|
if(!is.null(data) | !is.null(label)) warning("data/label: these parameters should only be provided with decision tree based models.")
|
||||||
} else {
|
} else {
|
||||||
result <- treeDump(feature_names, text = text, keepDetail = !is.null(data))
|
result <- treeDump(feature_names, text = model.text.dump, keepDetail = !is.null(data))
|
||||||
|
|
||||||
# Co-occurence computation
|
# Co-occurence computation
|
||||||
if(!is.null(data) & !is.null(label) & nrow(result) > 0) {
|
if(!is.null(data) & !is.null(label) & nrow(result) > 0) {
|
||||||
@ -102,17 +111,7 @@ xgb.importance <- function(feature_names = NULL, model = NULL, data = NULL, labe
|
|||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
treeDump <- function(feature_names, text, keepDetail){
|
|
||||||
if(keepDetail) groupBy <- c("Feature", "Split", "MissingNo") else groupBy <- "Feature"
|
|
||||||
|
|
||||||
result <- xgb.model.dt.tree(feature_names = feature_names, text = text)[,"MissingNo" := Missing == No ][Feature != "Leaf",.(Gain = sum(Quality), Cover = sum(Cover), Frequency = .N), by = groupBy, with = T][,`:=`(Gain = Gain / sum(Gain), Cover = Cover / sum(Cover), Frequency = Frequency / sum(Frequency))][order(Gain, decreasing = T)]
|
|
||||||
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
linearDump <- function(feature_names, text){
|
|
||||||
which(text == "weight:") %>% {a =. + 1; text[a:length(text)]} %>% as.numeric %>% data.table(Feature = feature_names, Weight = .)
|
|
||||||
}
|
|
||||||
|
|
||||||
# Avoid error messages during CRAN check.
|
# Avoid error messages during CRAN check.
|
||||||
# The reason is that these variables are never declared
|
# The reason is that these variables are never declared
|
||||||
|
|||||||
@ -59,19 +59,19 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL, n
|
|||||||
stop("feature_names: Has to be a vector of character or NULL if the model dump already contains feature name. Look at this function documentation to see where to get feature names.")
|
stop("feature_names: Has to be a vector of character or NULL if the model dump already contains feature name. Look at this function documentation to see where to get feature names.")
|
||||||
}
|
}
|
||||||
|
|
||||||
if (class(model) != "xgb.Booster") {
|
if (class(model) != "xgb.Booster" & class(text) != "character") {
|
||||||
stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.")
|
"model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.\n" %>%
|
||||||
}
|
paste0("text: Has to be a vector of character or NULL if a path to the model dump has already been provided.") %>%
|
||||||
|
stop()
|
||||||
if (!class(text) %in% c("character", "NULL")) {
|
|
||||||
stop("text: Has to be a vector of character or NULL if a path to the model dump has already been provided.")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!class(n_first_tree) %in% c("numeric", "NULL") | length(n_first_tree) > 1) {
|
if (!class(n_first_tree) %in% c("numeric", "NULL") | length(n_first_tree) > 1) {
|
||||||
stop("n_first_tree: Has to be a numeric vector of size 1.")
|
stop("n_first_tree: Has to be a numeric vector of size 1.")
|
||||||
}
|
}
|
||||||
|
|
||||||
text <- xgb.dump(model = model, with.stats = T)
|
if(is.null(text)){
|
||||||
|
text <- xgb.dump(model = model, with.stats = T)
|
||||||
|
}
|
||||||
|
|
||||||
position <- str_match(text, "booster") %>% is.na %>% not %>% which %>% c(length(text) + 1)
|
position <- str_match(text, "booster") %>% is.na %>% not %>% which %>% c(length(text) + 1)
|
||||||
|
|
||||||
|
|||||||
@ -45,7 +45,7 @@
|
|||||||
#' eta = 1, nthread = 2, nround = 30, objective = "binary:logistic",
|
#' eta = 1, nthread = 2, nround = 30, objective = "binary:logistic",
|
||||||
#' min_child_weight = 50)
|
#' min_child_weight = 50)
|
||||||
#'
|
#'
|
||||||
#' p <- xgb.plot.multi.trees(bst, agaricus.train$data@Dimnames[[2]], 3)
|
#' p <- xgb.plot.multi.trees(model = bst, names = agaricus.train$data@Dimnames[[2]], 3)
|
||||||
#' print(p)
|
#' print(p)
|
||||||
#'
|
#'
|
||||||
#' @export
|
#' @export
|
||||||
|
|||||||
@ -49,7 +49,7 @@ bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.dep
|
|||||||
eta = 1, nthread = 2, nround = 30, objective = "binary:logistic",
|
eta = 1, nthread = 2, nround = 30, objective = "binary:logistic",
|
||||||
min_child_weight = 50)
|
min_child_weight = 50)
|
||||||
|
|
||||||
p <- xgb.plot.multi.trees(bst, agaricus.train$data@Dimnames[[2]], 3)
|
p <- xgb.plot.multi.trees(model = bst, names = agaricus.train$data@Dimnames[[2]], 3)
|
||||||
print(p)
|
print(p)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -19,15 +19,23 @@ bst <- xgboost(data = sparse_matrix, label = output_vector, max.depth = 9,
|
|||||||
|
|
||||||
test_that("xgb.dump works", {
|
test_that("xgb.dump works", {
|
||||||
capture.output(print(xgb.dump(bst)))
|
capture.output(print(xgb.dump(bst)))
|
||||||
|
expect_true(xgb.dump(bst, 'xgb.model.dump', with.stats = T))
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("xgb.importance works", {
|
test_that("xgb.importance works", {
|
||||||
expect_true(xgb.dump(bst, 'xgb.model.dump', with.stats = T))
|
|
||||||
importance <- xgb.importance(sparse_matrix@Dimnames[[2]], model = bst)
|
importance <- xgb.importance(sparse_matrix@Dimnames[[2]], model = bst)
|
||||||
expect_equal(dim(importance), c(7, 4))
|
expect_equal(dim(importance), c(7, 4))
|
||||||
expect_equal(colnames(importance), c("Feature", "Gain", "Cover", "Frequency"))
|
expect_equal(colnames(importance), c("Feature", "Gain", "Cover", "Frequency"))
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("xgb.plot.tree works", {
|
test_that("xgb.plot.tree works", {
|
||||||
xgb.plot.tree(agaricus.train$data@Dimnames[[2]], model = bst)
|
xgb.plot.tree(names = agaricus.train$data@Dimnames[[2]], model = bst)
|
||||||
|
})
|
||||||
|
|
||||||
|
test_that("xgb.plot.deepness works", {
|
||||||
|
xgb.plot.deepness(model = bst)
|
||||||
|
})
|
||||||
|
|
||||||
|
test_that("xgb.plot.multi.trees works", {
|
||||||
|
xgb.plot.multi.trees(model = bst, names = agaricus.train$data@Dimnames[[2]], 3)
|
||||||
})
|
})
|
||||||
Loading…
x
Reference in New Issue
Block a user