[R] Fix global feature importance and predict with 1 sample. (#7394)
* [R] Fix global feature importance. * Add implementation for tree index. The parameter is not documented in C API since we should work on porting the model slicing to R instead of supporting more use of tree index. * Fix the difference between "gain" and "total_gain". * debug. * Fix prediction.
This commit is contained in:
@@ -397,6 +397,7 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
||||
shape <- predts$shape
|
||||
ret <- predts$results
|
||||
|
||||
n_ret <- length(ret)
|
||||
n_row <- nrow(newdata)
|
||||
if (n_row != shape[1]) {
|
||||
stop("Incorrect predict shape.")
|
||||
@@ -405,36 +406,55 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
||||
arr <- array(data = ret, dim = rev(shape))
|
||||
|
||||
cnames <- if (!is.null(colnames(newdata))) c(colnames(newdata), "BIAS") else NULL
|
||||
n_groups <- shape[2]
|
||||
|
||||
## Needed regardless of whether strict shape is being used.
|
||||
if (predcontrib) {
|
||||
dimnames(arr) <- list(cnames, NULL, NULL)
|
||||
if (!strict_shape) {
|
||||
arr <- aperm(a = arr, perm = c(2, 3, 1)) # [group, row, col]
|
||||
}
|
||||
} else if (predinteraction) {
|
||||
dimnames(arr) <- list(cnames, cnames, NULL, NULL)
|
||||
if (!strict_shape) {
|
||||
arr <- aperm(a = arr, perm = c(3, 4, 1, 2)) # [group, row, col, col]
|
||||
}
|
||||
}
|
||||
if (strict_shape) {
|
||||
return(arr) # strict shape is calculated by libxgboost uniformly.
|
||||
}
|
||||
|
||||
if (!strict_shape) {
|
||||
n_groups <- shape[2]
|
||||
if (predleaf) {
|
||||
arr <- matrix(arr, nrow = n_row, byrow = TRUE)
|
||||
} else if (predcontrib && n_groups != 1) {
|
||||
arr <- lapply(seq_len(n_groups), function(g) arr[g, , ])
|
||||
} else if (predinteraction && n_groups != 1) {
|
||||
arr <- lapply(seq_len(n_groups), function(g) arr[g, , , ])
|
||||
} else if (!reshape && n_groups != 1) {
|
||||
arr <- ret
|
||||
} else if (reshape && n_groups != 1) {
|
||||
arr <- matrix(arr, ncol = n_groups, byrow = TRUE)
|
||||
if (predleaf) {
|
||||
## Predict leaf
|
||||
arr <- if (n_ret == n_row) {
|
||||
matrix(arr, ncol = 1)
|
||||
} else {
|
||||
matrix(arr, nrow = n_row, byrow = TRUE)
|
||||
}
|
||||
arr <- drop(arr)
|
||||
if (length(dim(arr)) == 1) {
|
||||
arr <- as.vector(arr)
|
||||
} else if (length(dim(arr)) == 2) {
|
||||
arr <- as.matrix(arr)
|
||||
} else if (predcontrib) {
|
||||
## Predict contribution
|
||||
arr <- aperm(a = arr, perm = c(2, 3, 1)) # [group, row, col]
|
||||
arr <- if (n_ret == n_row) {
|
||||
matrix(arr, ncol = 1, dimnames = list(NULL, cnames))
|
||||
} else if (n_groups != 1) {
|
||||
## turns array into list of matrices
|
||||
lapply(seq_len(n_groups), function(g) arr[g, , ])
|
||||
} else {
|
||||
## remove the first axis (group)
|
||||
as.matrix(arr[1, , ])
|
||||
}
|
||||
} else if (predinteraction) {
|
||||
## Predict interaction
|
||||
arr <- aperm(a = arr, perm = c(3, 4, 1, 2)) # [group, row, col, col]
|
||||
arr <- if (n_ret == n_row) {
|
||||
matrix(arr, ncol = 1, dimnames = list(NULL, cnames))
|
||||
} else if (n_groups != 1) {
|
||||
## turns array into list of matrices
|
||||
lapply(seq_len(n_groups), function(g) arr[g, , , ])
|
||||
} else {
|
||||
## remove the first axis (group)
|
||||
arr[1, , , ]
|
||||
}
|
||||
} else {
|
||||
## Normal prediction
|
||||
arr <- if (reshape && n_groups != 1) {
|
||||
matrix(arr, ncol = n_groups, byrow = TRUE)
|
||||
} else {
|
||||
as.vector(ret)
|
||||
}
|
||||
}
|
||||
return(arr)
|
||||
|
||||
Reference in New Issue
Block a user