[backport] [R] Fix single sample prediction. (#7524) (#7558)

This commit is contained in:
Jiaming Yuan 2022-01-14 00:20:17 +08:00 committed by GitHub
parent 3e2d7519a6
commit 328d1e18db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 2 deletions

View File

@ -435,7 +435,8 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
lapply(seq_len(n_groups), function(g) arr[g, , ])
} else {
## remove the first axis (group)
as.matrix(arr[1, , ])
dn <- dimnames(arr)
matrix(arr[1, , ], nrow = dim(arr)[2], ncol = dim(arr)[3], dimnames = c(dn[2], dn[3]))
}
} else if (predinteraction) {
## Predict interaction
@ -447,7 +448,8 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
lapply(seq_len(n_groups), function(g) arr[g, , , ])
} else {
## remove the first axis (group)
arr[1, , , ]
arr <- arr[1, , , , drop = FALSE]
array(arr, dim = dim(arr)[2:4], dimnames(arr)[2:4])
}
} else {
## Normal prediction

View File

@ -157,3 +157,28 @@ test_that("multiclass feature interactions work", {
# sums WRT columns must be close to feature contributions
expect_lt(max(abs(apply(intr, c(1, 2, 3), sum) - aperm(cont, c(3, 1, 2)))), 0.00001)
})
test_that("SHAP single sample works", {
train <- agaricus.train
test <- agaricus.test
booster <- xgboost(
data = train$data,
label = train$label,
max_depth = 2,
nrounds = 4,
objective = "binary:logistic",
)
predt <- predict(
booster,
newdata = train$data[1, , drop = FALSE], predcontrib = TRUE
)
expect_equal(dim(predt), c(1, dim(train$data)[2] + 1))
predt <- predict(
booster,
newdata = train$data[1, , drop = FALSE], predinteraction = TRUE
)
expect_equal(dim(predt), c(1, dim(train$data)[2] + 1, dim(train$data)[2] + 1))
})