parent
3e2d7519a6
commit
328d1e18db
@ -435,7 +435,8 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
|||||||
lapply(seq_len(n_groups), function(g) arr[g, , ])
|
lapply(seq_len(n_groups), function(g) arr[g, , ])
|
||||||
} else {
|
} else {
|
||||||
## remove the first axis (group)
|
## remove the first axis (group)
|
||||||
as.matrix(arr[1, , ])
|
dn <- dimnames(arr)
|
||||||
|
matrix(arr[1, , ], nrow = dim(arr)[2], ncol = dim(arr)[3], dimnames = c(dn[2], dn[3]))
|
||||||
}
|
}
|
||||||
} else if (predinteraction) {
|
} else if (predinteraction) {
|
||||||
## Predict interaction
|
## Predict interaction
|
||||||
@ -447,7 +448,8 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
|||||||
lapply(seq_len(n_groups), function(g) arr[g, , , ])
|
lapply(seq_len(n_groups), function(g) arr[g, , , ])
|
||||||
} else {
|
} else {
|
||||||
## remove the first axis (group)
|
## remove the first axis (group)
|
||||||
arr[1, , , ]
|
arr <- arr[1, , , , drop = FALSE]
|
||||||
|
array(arr, dim = dim(arr)[2:4], dimnames(arr)[2:4])
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
## Normal prediction
|
## Normal prediction
|
||||||
|
|||||||
@ -157,3 +157,28 @@ test_that("multiclass feature interactions work", {
|
|||||||
# sums WRT columns must be close to feature contributions
|
# sums WRT columns must be close to feature contributions
|
||||||
expect_lt(max(abs(apply(intr, c(1, 2, 3), sum) - aperm(cont, c(3, 1, 2)))), 0.00001)
|
expect_lt(max(abs(apply(intr, c(1, 2, 3), sum) - aperm(cont, c(3, 1, 2)))), 0.00001)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
test_that("SHAP single sample works", {
|
||||||
|
train <- agaricus.train
|
||||||
|
test <- agaricus.test
|
||||||
|
booster <- xgboost(
|
||||||
|
data = train$data,
|
||||||
|
label = train$label,
|
||||||
|
max_depth = 2,
|
||||||
|
nrounds = 4,
|
||||||
|
objective = "binary:logistic",
|
||||||
|
)
|
||||||
|
|
||||||
|
predt <- predict(
|
||||||
|
booster,
|
||||||
|
newdata = train$data[1, , drop = FALSE], predcontrib = TRUE
|
||||||
|
)
|
||||||
|
expect_equal(dim(predt), c(1, dim(train$data)[2] + 1))
|
||||||
|
|
||||||
|
predt <- predict(
|
||||||
|
booster,
|
||||||
|
newdata = train$data[1, , drop = FALSE], predinteraction = TRUE
|
||||||
|
)
|
||||||
|
expect_equal(dim(predt), c(1, dim(train$data)[2] + 1, dim(train$data)[2] + 1))
|
||||||
|
})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user