From 96c43cf1978f721f40329c4edd4101a90a9d7d35 Mon Sep 17 00:00:00 2001 From: pommedeterresautee Date: Mon, 30 Nov 2015 15:04:17 +0100 Subject: [PATCH] Add new tests for new functions --- R-package/R/xgb.importance.R | 29 ++++++++++++------------- R-package/R/xgb.model.dt.tree.R | 14 ++++++------ R-package/R/xgb.plot.multi.trees.R | 2 +- R-package/man/xgb.plot.multi.trees.Rd | 2 +- R-package/tests/testthat/test_helpers.R | 12 ++++++++-- 5 files changed, 33 insertions(+), 26 deletions(-) diff --git a/R-package/R/xgb.importance.R b/R-package/R/xgb.importance.R index 74151b1c4..d3a5910b4 100644 --- a/R-package/R/xgb.importance.R +++ b/R-package/R/xgb.importance.R @@ -76,14 +76,23 @@ xgb.importance <- function(feature_names = NULL, model = NULL, data = NULL, labe if(class(label) == "numeric"){ if(sum(label == 0) / length(label) > 0.5) label <- as(label, "sparseVector") } - - text <- xgb.dump(model = model, with.stats = T) - if(text[2] == "bias:"){ - result <- readLines(filename_dump) %>% linearDump(feature_names, .) + treeDump <- function(feature_names, text, keepDetail){ + if(keepDetail) groupBy <- c("Feature", "Split", "MissingNo") else groupBy <- "Feature" + xgb.model.dt.tree(feature_names = feature_names, text = text)[,"MissingNo" := Missing == No ][Feature != "Leaf",.(Gain = sum(Quality), Cover = sum(Cover), Frequency = .N), by = groupBy, with = T][,`:=`(Gain = Gain / sum(Gain), Cover = Cover / sum(Cover), Frequency = Frequency / sum(Frequency))][order(Gain, decreasing = T)] + } + + linearDump <- function(feature_names, text){ + which(text == "weight:") %>% {a =. + 1; text[a:length(text)]} %>% as.numeric %>% data.table(Feature = feature_names, Weight = .) + } + + model.text.dump <- xgb.dump(model = model, with.stats = T) + + if(model.text.dump[2] == "bias:"){ + result <- model.text.dump %>% linearDump(feature_names, .) if(!is.null(data) | !is.null(label)) warning("data/label: these parameters should only be provided with decision tree based models.") } else { - result <- treeDump(feature_names, text = text, keepDetail = !is.null(data)) + result <- treeDump(feature_names, text = model.text.dump, keepDetail = !is.null(data)) # Co-occurence computation if(!is.null(data) & !is.null(label) & nrow(result) > 0) { @@ -102,17 +111,7 @@ xgb.importance <- function(feature_names = NULL, model = NULL, data = NULL, labe result } -treeDump <- function(feature_names, text, keepDetail){ - if(keepDetail) groupBy <- c("Feature", "Split", "MissingNo") else groupBy <- "Feature" - result <- xgb.model.dt.tree(feature_names = feature_names, text = text)[,"MissingNo" := Missing == No ][Feature != "Leaf",.(Gain = sum(Quality), Cover = sum(Cover), Frequency = .N), by = groupBy, with = T][,`:=`(Gain = Gain / sum(Gain), Cover = Cover / sum(Cover), Frequency = Frequency / sum(Frequency))][order(Gain, decreasing = T)] - - result -} - -linearDump <- function(feature_names, text){ - which(text == "weight:") %>% {a =. + 1; text[a:length(text)]} %>% as.numeric %>% data.table(Feature = feature_names, Weight = .) -} # Avoid error messages during CRAN check. # The reason is that these variables are never declared diff --git a/R-package/R/xgb.model.dt.tree.R b/R-package/R/xgb.model.dt.tree.R index a70c344cc..29ef2e1df 100644 --- a/R-package/R/xgb.model.dt.tree.R +++ b/R-package/R/xgb.model.dt.tree.R @@ -59,19 +59,19 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL, n stop("feature_names: Has to be a vector of character or NULL if the model dump already contains feature name. Look at this function documentation to see where to get feature names.") } - if (class(model) != "xgb.Booster") { - stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.") - } - - if (!class(text) %in% c("character", "NULL")) { - stop("text: Has to be a vector of character or NULL if a path to the model dump has already been provided.") + if (class(model) != "xgb.Booster" & class(text) != "character") { + "model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.\n" %>% + paste0("text: Has to be a vector of character or NULL if a path to the model dump has already been provided.") %>% + stop() } if (!class(n_first_tree) %in% c("numeric", "NULL") | length(n_first_tree) > 1) { stop("n_first_tree: Has to be a numeric vector of size 1.") } - text <- xgb.dump(model = model, with.stats = T) + if(is.null(text)){ + text <- xgb.dump(model = model, with.stats = T) + } position <- str_match(text, "booster") %>% is.na %>% not %>% which %>% c(length(text) + 1) diff --git a/R-package/R/xgb.plot.multi.trees.R b/R-package/R/xgb.plot.multi.trees.R index 13416b480..1efa375a4 100644 --- a/R-package/R/xgb.plot.multi.trees.R +++ b/R-package/R/xgb.plot.multi.trees.R @@ -45,7 +45,7 @@ #' eta = 1, nthread = 2, nround = 30, objective = "binary:logistic", #' min_child_weight = 50) #' -#' p <- xgb.plot.multi.trees(bst, agaricus.train$data@Dimnames[[2]], 3) +#' p <- xgb.plot.multi.trees(model = bst, names = agaricus.train$data@Dimnames[[2]], 3) #' print(p) #' #' @export diff --git a/R-package/man/xgb.plot.multi.trees.Rd b/R-package/man/xgb.plot.multi.trees.Rd index b3cacc122..6e59915e2 100644 --- a/R-package/man/xgb.plot.multi.trees.Rd +++ b/R-package/man/xgb.plot.multi.trees.Rd @@ -49,7 +49,7 @@ bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.dep eta = 1, nthread = 2, nround = 30, objective = "binary:logistic", min_child_weight = 50) -p <- xgb.plot.multi.trees(bst, agaricus.train$data@Dimnames[[2]], 3) +p <- xgb.plot.multi.trees(model = bst, names = agaricus.train$data@Dimnames[[2]], 3) print(p) } diff --git a/R-package/tests/testthat/test_helpers.R b/R-package/tests/testthat/test_helpers.R index 11368216b..490b6b867 100644 --- a/R-package/tests/testthat/test_helpers.R +++ b/R-package/tests/testthat/test_helpers.R @@ -19,15 +19,23 @@ bst <- xgboost(data = sparse_matrix, label = output_vector, max.depth = 9, test_that("xgb.dump works", { capture.output(print(xgb.dump(bst))) + expect_true(xgb.dump(bst, 'xgb.model.dump', with.stats = T)) }) test_that("xgb.importance works", { - expect_true(xgb.dump(bst, 'xgb.model.dump', with.stats = T)) importance <- xgb.importance(sparse_matrix@Dimnames[[2]], model = bst) expect_equal(dim(importance), c(7, 4)) expect_equal(colnames(importance), c("Feature", "Gain", "Cover", "Frequency")) }) test_that("xgb.plot.tree works", { - xgb.plot.tree(agaricus.train$data@Dimnames[[2]], model = bst) + xgb.plot.tree(names = agaricus.train$data@Dimnames[[2]], model = bst) +}) + +test_that("xgb.plot.deepness works", { + xgb.plot.deepness(model = bst) +}) + +test_that("xgb.plot.multi.trees works", { + xgb.plot.multi.trees(model = bst, names = agaricus.train$data@Dimnames[[2]], 3) }) \ No newline at end of file