[R] Allow passing data.frame to SHAP (#10744)
This commit is contained in:
@@ -449,6 +449,26 @@ test_that("xgb.shap.data works with subsampling", {
|
||||
expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib))
|
||||
})
|
||||
|
||||
test_that("xgb.shap.data works with data frames", {
|
||||
data(mtcars)
|
||||
df <- mtcars
|
||||
df$cyl <- factor(df$cyl)
|
||||
x <- df[, -1]
|
||||
y <- df$mpg
|
||||
dm <- xgb.DMatrix(x, label = y, nthread = 1L)
|
||||
model <- xgb.train(
|
||||
data = dm,
|
||||
params = list(
|
||||
max_depth = 2,
|
||||
nthread = 1
|
||||
),
|
||||
nrounds = 2
|
||||
)
|
||||
data_list <- xgb.shap.data(data = df[, -1], model = model, top_n = 2, subsample = 0.8)
|
||||
expect_equal(NROW(data_list$data), as.integer(0.8 * nrow(df)))
|
||||
expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib))
|
||||
})
|
||||
|
||||
test_that("prepare.ggplot.shap.data works", {
|
||||
.skip_if_vcd_not_available()
|
||||
data_list <- xgb.shap.data(data = sparse_matrix, model = bst.Tree, top_n = 2)
|
||||
@@ -472,6 +492,44 @@ test_that("xgb.plot.shap.summary works", {
|
||||
expect_silent(xgb.ggplot.shap.summary(data = sparse_matrix, model = bst.Tree, top_n = 2))
|
||||
})
|
||||
|
||||
test_that("xgb.plot.shap.summary ignores categorical features", {
|
||||
.skip_if_vcd_not_available()
|
||||
data(mtcars)
|
||||
df <- mtcars
|
||||
df$cyl <- factor(df$cyl)
|
||||
levels(df$cyl) <- c("a", "b", "c")
|
||||
x <- df[, -1]
|
||||
y <- df$mpg
|
||||
dm <- xgb.DMatrix(x, label = y, nthread = 1L)
|
||||
model <- xgb.train(
|
||||
data = dm,
|
||||
params = list(
|
||||
max_depth = 2,
|
||||
nthread = 1
|
||||
),
|
||||
nrounds = 2
|
||||
)
|
||||
expect_warning({
|
||||
xgb.ggplot.shap.summary(data = x, model = model, top_n = 2)
|
||||
})
|
||||
|
||||
x_num <- mtcars[, -1]
|
||||
x_num$gear <- as.numeric(x_num$gear) - 1
|
||||
x_num <- as.matrix(x_num)
|
||||
dm <- xgb.DMatrix(x_num, label = y, feature_types = c(rep("q", 8), "c", "q"), nthread = 1L)
|
||||
model <- xgb.train(
|
||||
data = dm,
|
||||
params = list(
|
||||
max_depth = 2,
|
||||
nthread = 1
|
||||
),
|
||||
nrounds = 2
|
||||
)
|
||||
expect_warning({
|
||||
xgb.ggplot.shap.summary(data = x_num, model = model, top_n = 2)
|
||||
})
|
||||
})
|
||||
|
||||
test_that("check.deprecation works", {
|
||||
ttt <- function(a = NNULL, DUMMY = NULL, ...) {
|
||||
check.deprecation(...)
|
||||
|
||||
Reference in New Issue
Block a user