[R] Avoid memory copies in predict (#9902)
This commit is contained in:
@@ -343,24 +343,24 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
||||
)
|
||||
names(predts) <- c("shape", "results")
|
||||
shape <- predts$shape
|
||||
ret <- predts$results
|
||||
arr <- predts$results
|
||||
|
||||
n_ret <- length(ret)
|
||||
n_ret <- length(arr)
|
||||
n_row <- nrow(newdata)
|
||||
if (n_row != shape[1]) {
|
||||
stop("Incorrect predict shape.")
|
||||
}
|
||||
|
||||
arr <- array(data = ret, dim = rev(shape))
|
||||
.Call(XGSetArrayDimInplace_R, arr, rev(shape))
|
||||
|
||||
cnames <- if (!is.null(colnames(newdata))) c(colnames(newdata), "BIAS") else NULL
|
||||
n_groups <- shape[2]
|
||||
|
||||
## Needed regardless of whether strict shape is being used.
|
||||
if (predcontrib) {
|
||||
dimnames(arr) <- list(cnames, NULL, NULL)
|
||||
.Call(XGSetArrayDimNamesInplace_R, arr, list(cnames, NULL, NULL))
|
||||
} else if (predinteraction) {
|
||||
dimnames(arr) <- list(cnames, cnames, NULL, NULL)
|
||||
.Call(XGSetArrayDimNamesInplace_R, arr, list(cnames, cnames, NULL, NULL))
|
||||
}
|
||||
if (strict_shape) {
|
||||
return(arr) # strict shape is calculated by libxgboost uniformly.
|
||||
@@ -368,43 +368,51 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
||||
|
||||
if (predleaf) {
|
||||
## Predict leaf
|
||||
arr <- if (n_ret == n_row) {
|
||||
matrix(arr, ncol = 1)
|
||||
if (n_ret == n_row) {
|
||||
.Call(XGSetArrayDimInplace_R, arr, c(n_row, 1L))
|
||||
} else {
|
||||
matrix(arr, nrow = n_row, byrow = TRUE)
|
||||
arr <- matrix(arr, nrow = n_row, byrow = TRUE)
|
||||
}
|
||||
} else if (predcontrib) {
|
||||
## Predict contribution
|
||||
arr <- aperm(a = arr, perm = c(2, 3, 1)) # [group, row, col]
|
||||
arr <- if (n_ret == n_row) {
|
||||
matrix(arr, ncol = 1, dimnames = list(NULL, cnames))
|
||||
if (n_ret == n_row) {
|
||||
.Call(XGSetArrayDimInplace_R, arr, c(n_row, 1L))
|
||||
.Call(XGSetArrayDimNamesInplace_R, arr, list(NULL, cnames))
|
||||
} else if (n_groups != 1) {
|
||||
## turns array into list of matrices
|
||||
lapply(seq_len(n_groups), function(g) arr[g, , ])
|
||||
arr <- lapply(seq_len(n_groups), function(g) arr[g, , ])
|
||||
} else {
|
||||
## remove the first axis (group)
|
||||
dn <- dimnames(arr)
|
||||
matrix(arr[1, , ], nrow = dim(arr)[2], ncol = dim(arr)[3], dimnames = c(dn[2], dn[3]))
|
||||
newdim <- dim(arr)[2:3]
|
||||
newdn <- dimnames(arr)[2:3]
|
||||
arr <- arr[1, , ]
|
||||
.Call(XGSetArrayDimInplace_R, arr, newdim)
|
||||
.Call(XGSetArrayDimNamesInplace_R, arr, newdn)
|
||||
}
|
||||
} else if (predinteraction) {
|
||||
## Predict interaction
|
||||
arr <- aperm(a = arr, perm = c(3, 4, 1, 2)) # [group, row, col, col]
|
||||
arr <- if (n_ret == n_row) {
|
||||
matrix(arr, ncol = 1, dimnames = list(NULL, cnames))
|
||||
if (n_ret == n_row) {
|
||||
.Call(XGSetArrayDimInplace_R, arr, c(n_row, 1L))
|
||||
.Call(XGSetArrayDimNamesInplace_R, arr, list(NULL, cnames))
|
||||
} else if (n_groups != 1) {
|
||||
## turns array into list of matrices
|
||||
lapply(seq_len(n_groups), function(g) arr[g, , , ])
|
||||
arr <- lapply(seq_len(n_groups), function(g) arr[g, , , ])
|
||||
} else {
|
||||
## remove the first axis (group)
|
||||
arr <- arr[1, , , , drop = FALSE]
|
||||
array(arr, dim = dim(arr)[2:4], dimnames(arr)[2:4])
|
||||
newdim <- dim(arr)[2:4]
|
||||
newdn <- dimnames(arr)[2:4]
|
||||
.Call(XGSetArrayDimInplace_R, arr, newdim)
|
||||
.Call(XGSetArrayDimNamesInplace_R, arr, newdn)
|
||||
}
|
||||
} else {
|
||||
## Normal prediction
|
||||
arr <- if (reshape && n_groups != 1) {
|
||||
matrix(arr, ncol = n_groups, byrow = TRUE)
|
||||
if (reshape && n_groups != 1) {
|
||||
arr <- matrix(arr, ncol = n_groups, byrow = TRUE)
|
||||
} else {
|
||||
as.vector(ret)
|
||||
.Call(XGSetArrayDimInplace_R, arr, NULL)
|
||||
}
|
||||
}
|
||||
return(arr)
|
||||
|
||||
Reference in New Issue
Block a user