@@ -125,12 +125,12 @@ xgb.plot.shap <- function(data, shap_contrib = NULL, features = NULL, top_n = 1,
|
||||
|
||||
nsample <- if (is.null(subsample)) min(100000, nrow(data)) else as.integer(subsample * nrow(data))
|
||||
idx <- sample(1:nrow(data), nsample)
|
||||
data <- data[idx,]
|
||||
data <- data[idx, ]
|
||||
|
||||
if (is.null(shap_contrib)) {
|
||||
shap_contrib <- predict(model, data, predcontrib = TRUE, approxcontrib = approxcontrib)
|
||||
} else {
|
||||
shap_contrib <- shap_contrib[idx,]
|
||||
shap_contrib <- shap_contrib[idx, ]
|
||||
}
|
||||
|
||||
which <- match.arg(which)
|
||||
@@ -168,8 +168,8 @@ xgb.plot.shap <- function(data, shap_contrib = NULL, features = NULL, top_n = 1,
|
||||
|
||||
if (plot && which == "1d") {
|
||||
op <- par(mfrow = c(ceiling(length(features) / n_col), n_col),
|
||||
oma = c(0,0,0,0) + 0.2,
|
||||
mar = c(3.5,3.5,0,0) + 0.1,
|
||||
oma = c(0, 0, 0, 0) + 0.2,
|
||||
mar = c(3.5, 3.5, 0, 0) + 0.1,
|
||||
mgp = c(1.7, 0.6, 0))
|
||||
for (f in cols) {
|
||||
ord <- order(data[, f])
|
||||
@@ -192,7 +192,7 @@ xgb.plot.shap <- function(data, shap_contrib = NULL, features = NULL, top_n = 1,
|
||||
grid()
|
||||
if (plot_loess) {
|
||||
# compress x to 3 digits, and mean-aggredate y
|
||||
zz <- data.table(x = signif(x, 3), y)[, .(.N, y=mean(y)), x]
|
||||
zz <- data.table(x = signif(x, 3), y)[, .(.N, y = mean(y)), x]
|
||||
if (nrow(zz) <= 5) {
|
||||
lines(zz$x, zz$y, col = col_loess)
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user