Add new tests for new functions

This commit is contained in:
pommedeterresautee 2015-11-30 15:04:17 +01:00
parent ad8766dfa4
commit 96c43cf197
5 changed files with 33 additions and 26 deletions

View File

@ -76,14 +76,23 @@ xgb.importance <- function(feature_names = NULL, model = NULL, data = NULL, labe
if(class(label) == "numeric"){
if(sum(label == 0) / length(label) > 0.5) label <- as(label, "sparseVector")
}
text <- xgb.dump(model = model, with.stats = T)
if(text[2] == "bias:"){
result <- readLines(filename_dump) %>% linearDump(feature_names, .)
# Summarise feature importance from the text dump of a tree-based model.
#
# feature_names: forwarded unchanged to xgb.model.dt.tree().
# text:          model dump lines, forwarded unchanged to xgb.model.dt.tree().
# keepDetail:    when TRUE, aggregate per (Feature, Split, MissingNo);
#                otherwise aggregate per Feature only.
#
# Returns the aggregated table in which Gain, Cover and Frequency have each
# been rescaled to sum to 1, ordered by decreasing Gain.
treeDump <- function(feature_names, text, keepDetail) {
  groupBy <- if (keepDetail) c("Feature", "Split", "MissingNo") else "Feature"
  nodes <- xgb.model.dt.tree(feature_names = feature_names, text = text)
  # MissingNo flags nodes whose Missing and No columns hold the same value;
  # rows whose Feature is "Leaf" are excluded before aggregating.
  aggregated <- nodes[, "MissingNo" := Missing == No][
    Feature != "Leaf",
    .(Gain = sum(Quality), Cover = sum(Cover), Frequency = .N),
    by = groupBy, with = TRUE
  ]
  # Normalise each measure by its column total, then sort by Gain.
  aggregated[, `:=`(
    Gain = Gain / sum(Gain),
    Cover = Cover / sum(Cover),
    Frequency = Frequency / sum(Frequency)
  )][order(Gain, decreasing = TRUE)]
}
# Build the Feature/Weight table of a linear model from its text dump.
# Every dump line after the "weight:" marker is converted to numeric and
# paired with feature_names.
linearDump <- function(feature_names, text) {
  first_weight <- which(text == "weight:") + 1
  weights <- as.numeric(text[first_weight:length(text)])
  data.table(Feature = feature_names, Weight = weights)
}
model.text.dump <- xgb.dump(model = model, with.stats = T)
if(model.text.dump[2] == "bias:"){
result <- model.text.dump %>% linearDump(feature_names, .)
if(!is.null(data) | !is.null(label)) warning("data/label: these parameters should only be provided with decision tree based models.")
} else {
result <- treeDump(feature_names, text = text, keepDetail = !is.null(data))
result <- treeDump(feature_names, text = model.text.dump, keepDetail = !is.null(data))
# Co-occurrence computation
if(!is.null(data) & !is.null(label) & nrow(result) > 0) {
@ -102,17 +111,7 @@ xgb.importance <- function(feature_names = NULL, model = NULL, data = NULL, labe
result
}
# Compute feature importance from the text dump of a tree-based model.
#
# feature_names: passed through to xgb.model.dt.tree().
# text:          model dump lines, passed through to xgb.model.dt.tree().
# keepDetail:    TRUE groups by Feature, Split and MissingNo; FALSE groups by
#                Feature alone.
#
# Returns one row per group with Gain, Cover and Frequency expressed as
# shares of their column totals, sorted by decreasing Gain.
treeDump <- function(feature_names, text, keepDetail) {
  if (keepDetail) {
    groupBy <- c("Feature", "Split", "MissingNo")
  } else {
    groupBy <- "Feature"
  }
  tree_dt <- xgb.model.dt.tree(feature_names = feature_names, text = text)
  # MissingNo is TRUE where the Missing and No columns are equal.
  tree_dt[, "MissingNo" := Missing == No]
  # Rows whose Feature is "Leaf" do not take part in the aggregation.
  per_group <- tree_dt[
    Feature != "Leaf",
    .(Gain = sum(Quality), Cover = sum(Cover), Frequency = .N),
    by = groupBy, with = TRUE
  ]
  per_group[, `:=`(
    Gain = Gain / sum(Gain),
    Cover = Cover / sum(Cover),
    Frequency = Frequency / sum(Frequency)
  )]
  per_group[order(Gain, decreasing = TRUE)]
}
# Extract the weights of a linear model from its text dump: the lines that
# follow the "weight:" marker are parsed as numbers and returned next to
# feature_names in a two-column table.
linearDump <- function(feature_names, text) {
  start <- which(text == "weight:") + 1
  weight_lines <- text[start:length(text)]
  data.table(Feature = feature_names, Weight = as.numeric(weight_lines))
}
# Avoid error messages during CRAN check.
# The reason is that these variables are never declared

View File

@ -59,19 +59,19 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL, n
stop("feature_names: Has to be a vector of character or NULL if the model dump already contains feature name. Look at this function documentation to see where to get feature names.")
}
if (class(model) != "xgb.Booster") {
stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.")
}
if (!class(text) %in% c("character", "NULL")) {
stop("text: Has to be a vector of character or NULL if a path to the model dump has already been provided.")
if (class(model) != "xgb.Booster" & class(text) != "character") {
"model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.\n" %>%
paste0("text: Has to be a vector of character or NULL if a path to the model dump has already been provided.") %>%
stop()
}
if (!class(n_first_tree) %in% c("numeric", "NULL") | length(n_first_tree) > 1) {
stop("n_first_tree: Has to be a numeric vector of size 1.")
}
text <- xgb.dump(model = model, with.stats = T)
if(is.null(text)){
text <- xgb.dump(model = model, with.stats = T)
}
position <- str_match(text, "booster") %>% is.na %>% not %>% which %>% c(length(text) + 1)

View File

@ -45,7 +45,7 @@
#' eta = 1, nthread = 2, nround = 30, objective = "binary:logistic",
#' min_child_weight = 50)
#'
#' p <- xgb.plot.multi.trees(bst, agaricus.train$data@Dimnames[[2]], 3)
#' p <- xgb.plot.multi.trees(model = bst, names = agaricus.train$data@Dimnames[[2]], 3)
#' print(p)
#'
#' @export

View File

@ -49,7 +49,7 @@ bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.dep
eta = 1, nthread = 2, nround = 30, objective = "binary:logistic",
min_child_weight = 50)
p <- xgb.plot.multi.trees(bst, agaricus.train$data@Dimnames[[2]], 3)
p <- xgb.plot.multi.trees(model = bst, names = agaricus.train$data@Dimnames[[2]], 3)
print(p)
}

View File

@ -19,15 +19,23 @@ bst <- xgboost(data = sparse_matrix, label = output_vector, max.depth = 9,
# xgb.dump must be printable from memory and writable to a file.
test_that("xgb.dump works", {
  # Printing the dump only has to run without error; the text is discarded.
  capture.output(print(xgb.dump(bst)))
  # With a file name and node statistics requested, TRUE is expected back.
  expect_true(xgb.dump(bst, 'xgb.model.dump', with.stats = TRUE))
})
# xgb.importance must return the expected table shape for the fitted model.
test_that("xgb.importance works", {
  expect_true(xgb.dump(bst, 'xgb.model.dump', with.stats = TRUE))
  importance <- xgb.importance(sparse_matrix@Dimnames[[2]], model = bst)
  # No `data` is supplied, so 7 rows and 4 columns are expected.
  expect_equal(dim(importance), c(7, 4))
  expect_equal(colnames(importance), c("Feature", "Gain", "Cover", "Frequency"))
})
# Smoke test: both call forms only have to run without raising an error.
test_that("xgb.plot.tree works", {
  feature_names <- agaricus.train$data@Dimnames[[2]]
  xgb.plot.tree(feature_names, model = bst)
  xgb.plot.tree(names = feature_names, model = bst)
})
# Smoke test: the call only has to complete without raising an error;
# no expectation is made about the returned value.
test_that("xgb.plot.deepness works", {
xgb.plot.deepness(model = bst)
})
# Smoke test: the call only has to run without raising an error.
test_that("xgb.plot.multi.trees works", {
  feature_names <- agaricus.train$data@Dimnames[[2]]
  xgb.plot.multi.trees(model = bst, names = feature_names, 3)
})