Add new tests for new functions

This commit is contained in:
pommedeterresautee 2015-11-30 15:04:17 +01:00
parent ad8766dfa4
commit 96c43cf197
5 changed files with 33 additions and 26 deletions

View File

@ -76,14 +76,23 @@ xgb.importance <- function(feature_names = NULL, model = NULL, data = NULL, labe
if(class(label) == "numeric"){ if(class(label) == "numeric"){
if(sum(label == 0) / length(label) > 0.5) label <- as(label, "sparseVector") if(sum(label == 0) / length(label) > 0.5) label <- as(label, "sparseVector")
} }
text <- xgb.dump(model = model, with.stats = T)
if(text[2] == "bias:"){ treeDump <- function(feature_names, text, keepDetail){
result <- readLines(filename_dump) %>% linearDump(feature_names, .) if(keepDetail) groupBy <- c("Feature", "Split", "MissingNo") else groupBy <- "Feature"
xgb.model.dt.tree(feature_names = feature_names, text = text)[,"MissingNo" := Missing == No ][Feature != "Leaf",.(Gain = sum(Quality), Cover = sum(Cover), Frequency = .N), by = groupBy, with = T][,`:=`(Gain = Gain / sum(Gain), Cover = Cover / sum(Cover), Frequency = Frequency / sum(Frequency))][order(Gain, decreasing = T)]
}
linearDump <- function(feature_names, text){
which(text == "weight:") %>% {a =. + 1; text[a:length(text)]} %>% as.numeric %>% data.table(Feature = feature_names, Weight = .)
}
model.text.dump <- xgb.dump(model = model, with.stats = T)
if(model.text.dump[2] == "bias:"){
result <- model.text.dump %>% linearDump(feature_names, .)
if(!is.null(data) | !is.null(label)) warning("data/label: these parameters should only be provided with decision tree based models.") if(!is.null(data) | !is.null(label)) warning("data/label: these parameters should only be provided with decision tree based models.")
} else { } else {
result <- treeDump(feature_names, text = text, keepDetail = !is.null(data)) result <- treeDump(feature_names, text = model.text.dump, keepDetail = !is.null(data))
# Co-occurence computation # Co-occurence computation
if(!is.null(data) & !is.null(label) & nrow(result) > 0) { if(!is.null(data) & !is.null(label) & nrow(result) > 0) {
@ -102,17 +111,7 @@ xgb.importance <- function(feature_names = NULL, model = NULL, data = NULL, labe
result result
} }
treeDump <- function(feature_names, text, keepDetail){
if(keepDetail) groupBy <- c("Feature", "Split", "MissingNo") else groupBy <- "Feature"
result <- xgb.model.dt.tree(feature_names = feature_names, text = text)[,"MissingNo" := Missing == No ][Feature != "Leaf",.(Gain = sum(Quality), Cover = sum(Cover), Frequency = .N), by = groupBy, with = T][,`:=`(Gain = Gain / sum(Gain), Cover = Cover / sum(Cover), Frequency = Frequency / sum(Frequency))][order(Gain, decreasing = T)]
result
}
linearDump <- function(feature_names, text){
which(text == "weight:") %>% {a =. + 1; text[a:length(text)]} %>% as.numeric %>% data.table(Feature = feature_names, Weight = .)
}
# Avoid error messages during CRAN check. # Avoid error messages during CRAN check.
# The reason is that these variables are never declared # The reason is that these variables are never declared

View File

@ -59,19 +59,19 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL, n
stop("feature_names: Has to be a vector of character or NULL if the model dump already contains feature name. Look at this function documentation to see where to get feature names.") stop("feature_names: Has to be a vector of character or NULL if the model dump already contains feature name. Look at this function documentation to see where to get feature names.")
} }
if (class(model) != "xgb.Booster") { if (class(model) != "xgb.Booster" & class(text) != "character") {
stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.") "model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.\n" %>%
} paste0("text: Has to be a vector of character or NULL if a path to the model dump has already been provided.") %>%
stop()
if (!class(text) %in% c("character", "NULL")) {
stop("text: Has to be a vector of character or NULL if a path to the model dump has already been provided.")
} }
if (!class(n_first_tree) %in% c("numeric", "NULL") | length(n_first_tree) > 1) { if (!class(n_first_tree) %in% c("numeric", "NULL") | length(n_first_tree) > 1) {
stop("n_first_tree: Has to be a numeric vector of size 1.") stop("n_first_tree: Has to be a numeric vector of size 1.")
} }
text <- xgb.dump(model = model, with.stats = T) if(is.null(text)){
text <- xgb.dump(model = model, with.stats = T)
}
position <- str_match(text, "booster") %>% is.na %>% not %>% which %>% c(length(text) + 1) position <- str_match(text, "booster") %>% is.na %>% not %>% which %>% c(length(text) + 1)

View File

@ -45,7 +45,7 @@
#' eta = 1, nthread = 2, nround = 30, objective = "binary:logistic", #' eta = 1, nthread = 2, nround = 30, objective = "binary:logistic",
#' min_child_weight = 50) #' min_child_weight = 50)
#' #'
#' p <- xgb.plot.multi.trees(bst, agaricus.train$data@Dimnames[[2]], 3) #' p <- xgb.plot.multi.trees(model = bst, names = agaricus.train$data@Dimnames[[2]], 3)
#' print(p) #' print(p)
#' #'
#' @export #' @export

View File

@ -49,7 +49,7 @@ bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.dep
eta = 1, nthread = 2, nround = 30, objective = "binary:logistic", eta = 1, nthread = 2, nround = 30, objective = "binary:logistic",
min_child_weight = 50) min_child_weight = 50)
p <- xgb.plot.multi.trees(bst, agaricus.train$data@Dimnames[[2]], 3) p <- xgb.plot.multi.trees(model = bst, names = agaricus.train$data@Dimnames[[2]], 3)
print(p) print(p)
} }

View File

@ -19,15 +19,23 @@ bst <- xgboost(data = sparse_matrix, label = output_vector, max.depth = 9,
test_that("xgb.dump works", { test_that("xgb.dump works", {
capture.output(print(xgb.dump(bst))) capture.output(print(xgb.dump(bst)))
expect_true(xgb.dump(bst, 'xgb.model.dump', with.stats = T))
}) })
test_that("xgb.importance works", { test_that("xgb.importance works", {
expect_true(xgb.dump(bst, 'xgb.model.dump', with.stats = T))
importance <- xgb.importance(sparse_matrix@Dimnames[[2]], model = bst) importance <- xgb.importance(sparse_matrix@Dimnames[[2]], model = bst)
expect_equal(dim(importance), c(7, 4)) expect_equal(dim(importance), c(7, 4))
expect_equal(colnames(importance), c("Feature", "Gain", "Cover", "Frequency")) expect_equal(colnames(importance), c("Feature", "Gain", "Cover", "Frequency"))
}) })
test_that("xgb.plot.tree works", { test_that("xgb.plot.tree works", {
xgb.plot.tree(agaricus.train$data@Dimnames[[2]], model = bst) xgb.plot.tree(names = agaricus.train$data@Dimnames[[2]], model = bst)
})
test_that("xgb.plot.deepness works", {
xgb.plot.deepness(model = bst)
})
test_that("xgb.plot.multi.trees works", {
xgb.plot.multi.trees(model = bst, names = agaricus.train$data@Dimnames[[2]], 3)
}) })