Add SHAP summary plot using ggplot2 (#5882)
* add SHAP summary plot using ggplot2 * Update xgb.plot.shap * Update example in xgb.plot.shap documentation * update logic, add tests * whitespace fixes * whitespace fixes for test_helpers * namespace for sd function * explicitly declare variables that are automatically evaluated by data.table * Fix R lint Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -351,11 +351,47 @@ test_that("xgb.plot.deepness works", {
|
||||
xgb.ggplot.deepness(model = bst.Tree)
|
||||
})
|
||||
|
||||
test_that("xgb.shap.data works when top_n is provided", {
|
||||
data_list <- xgb.shap.data(data = sparse_matrix, model = bst.Tree, top_n = 2)
|
||||
expect_equal(names(data_list), c("data", "shap_contrib"))
|
||||
expect_equal(NCOL(data_list$data), 2)
|
||||
expect_equal(NCOL(data_list$shap_contrib), 2)
|
||||
expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib))
|
||||
expect_gt(length(colnames(data_list$data)), 0)
|
||||
expect_gt(length(colnames(data_list$shap_contrib)), 0)
|
||||
|
||||
# for multiclass without target class provided
|
||||
data_list <- xgb.shap.data(data = as.matrix(iris[, -5]), model = mbst.Tree, top_n = 2)
|
||||
expect_equal(dim(data_list$shap_contrib), c(nrow(iris), 2))
|
||||
# for multiclass with target class provided
|
||||
data_list <- xgb.shap.data(data = as.matrix(iris[, -5]), model = mbst.Tree, top_n = 2, target_class = 0)
|
||||
expect_equal(dim(data_list$shap_contrib), c(nrow(iris), 2))
|
||||
})
|
||||
|
||||
test_that("xgb.shap.data works with subsampling", {
|
||||
data_list <- xgb.shap.data(data = sparse_matrix, model = bst.Tree, top_n = 2, subsample = 0.8)
|
||||
expect_equal(NROW(data_list$data), as.integer(0.8 * nrow(sparse_matrix)))
|
||||
expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib))
|
||||
})
|
||||
|
||||
test_that("prepare.ggplot.shap.data works", {
|
||||
data_list <- xgb.shap.data(data = sparse_matrix, model = bst.Tree, top_n = 2)
|
||||
plot_data <- prepare.ggplot.shap.data(data_list, normalize = TRUE)
|
||||
expect_s3_class(plot_data, "data.frame")
|
||||
expect_equal(names(plot_data), c("id", "feature", "feature_value", "shap_value"))
|
||||
expect_s3_class(plot_data$feature, "factor")
|
||||
# Each observation should have 1 row for each feature
|
||||
expect_equal(nrow(plot_data), nrow(sparse_matrix) * 2)
|
||||
})
|
||||
|
||||
test_that("xgb.plot.shap works", {
|
||||
sh <- xgb.plot.shap(data = sparse_matrix, model = bst.Tree, top_n = 2, col = 4)
|
||||
expect_equal(names(sh), c("data", "shap_contrib"))
|
||||
expect_equal(NCOL(sh$data), 2)
|
||||
expect_equal(NCOL(sh$shap_contrib), 2)
|
||||
})
|
||||
|
||||
test_that("xgb.plot.shap.summary works", {
|
||||
xgb.plot.shap.summary(data = sparse_matrix, model = bst.Tree, top_n = 2)
|
||||
xgb.ggplot.shap.summary(data = sparse_matrix, model = bst.Tree, top_n = 2)
|
||||
})
|
||||
|
||||
test_that("check.deprecation works", {
|
||||
|
||||
Reference in New Issue
Block a user