[R] Avoid memory copies in predict (#9902)

This commit is contained in:
david-cortes 2024-01-20 17:53:18 +01:00 committed by GitHub
parent 2c8fa8b8b9
commit 60b9d2eeb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 58 additions and 20 deletions

View File

@ -343,24 +343,24 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
) )
names(predts) <- c("shape", "results") names(predts) <- c("shape", "results")
shape <- predts$shape shape <- predts$shape
ret <- predts$results arr <- predts$results
n_ret <- length(ret) n_ret <- length(arr)
n_row <- nrow(newdata) n_row <- nrow(newdata)
if (n_row != shape[1]) { if (n_row != shape[1]) {
stop("Incorrect predict shape.") stop("Incorrect predict shape.")
} }
arr <- array(data = ret, dim = rev(shape)) .Call(XGSetArrayDimInplace_R, arr, rev(shape))
cnames <- if (!is.null(colnames(newdata))) c(colnames(newdata), "BIAS") else NULL cnames <- if (!is.null(colnames(newdata))) c(colnames(newdata), "BIAS") else NULL
n_groups <- shape[2] n_groups <- shape[2]
## Needed regardless of whether strict shape is being used. ## Needed regardless of whether strict shape is being used.
if (predcontrib) { if (predcontrib) {
dimnames(arr) <- list(cnames, NULL, NULL) .Call(XGSetArrayDimNamesInplace_R, arr, list(cnames, NULL, NULL))
} else if (predinteraction) { } else if (predinteraction) {
dimnames(arr) <- list(cnames, cnames, NULL, NULL) .Call(XGSetArrayDimNamesInplace_R, arr, list(cnames, cnames, NULL, NULL))
} }
if (strict_shape) { if (strict_shape) {
return(arr) # strict shape is calculated by libxgboost uniformly. return(arr) # strict shape is calculated by libxgboost uniformly.
@ -368,43 +368,51 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
if (predleaf) { if (predleaf) {
## Predict leaf ## Predict leaf
arr <- if (n_ret == n_row) { if (n_ret == n_row) {
matrix(arr, ncol = 1) .Call(XGSetArrayDimInplace_R, arr, c(n_row, 1L))
} else { } else {
matrix(arr, nrow = n_row, byrow = TRUE) arr <- matrix(arr, nrow = n_row, byrow = TRUE)
} }
} else if (predcontrib) { } else if (predcontrib) {
## Predict contribution ## Predict contribution
arr <- aperm(a = arr, perm = c(2, 3, 1)) # [group, row, col] arr <- aperm(a = arr, perm = c(2, 3, 1)) # [group, row, col]
arr <- if (n_ret == n_row) { if (n_ret == n_row) {
matrix(arr, ncol = 1, dimnames = list(NULL, cnames)) .Call(XGSetArrayDimInplace_R, arr, c(n_row, 1L))
.Call(XGSetArrayDimNamesInplace_R, arr, list(NULL, cnames))
} else if (n_groups != 1) { } else if (n_groups != 1) {
## turns array into list of matrices ## turns array into list of matrices
lapply(seq_len(n_groups), function(g) arr[g, , ]) arr <- lapply(seq_len(n_groups), function(g) arr[g, , ])
} else { } else {
## remove the first axis (group) ## remove the first axis (group)
dn <- dimnames(arr) newdim <- dim(arr)[2:3]
matrix(arr[1, , ], nrow = dim(arr)[2], ncol = dim(arr)[3], dimnames = c(dn[2], dn[3])) newdn <- dimnames(arr)[2:3]
arr <- arr[1, , ]
.Call(XGSetArrayDimInplace_R, arr, newdim)
.Call(XGSetArrayDimNamesInplace_R, arr, newdn)
} }
} else if (predinteraction) { } else if (predinteraction) {
## Predict interaction ## Predict interaction
arr <- aperm(a = arr, perm = c(3, 4, 1, 2)) # [group, row, col, col] arr <- aperm(a = arr, perm = c(3, 4, 1, 2)) # [group, row, col, col]
arr <- if (n_ret == n_row) { if (n_ret == n_row) {
matrix(arr, ncol = 1, dimnames = list(NULL, cnames)) .Call(XGSetArrayDimInplace_R, arr, c(n_row, 1L))
.Call(XGSetArrayDimNamesInplace_R, arr, list(NULL, cnames))
} else if (n_groups != 1) { } else if (n_groups != 1) {
## turns array into list of matrices ## turns array into list of matrices
lapply(seq_len(n_groups), function(g) arr[g, , , ]) arr <- lapply(seq_len(n_groups), function(g) arr[g, , , ])
} else { } else {
## remove the first axis (group) ## remove the first axis (group)
arr <- arr[1, , , , drop = FALSE] arr <- arr[1, , , , drop = FALSE]
array(arr, dim = dim(arr)[2:4], dimnames(arr)[2:4]) newdim <- dim(arr)[2:4]
newdn <- dimnames(arr)[2:4]
.Call(XGSetArrayDimInplace_R, arr, newdim)
.Call(XGSetArrayDimNamesInplace_R, arr, newdn)
} }
} else { } else {
## Normal prediction ## Normal prediction
arr <- if (reshape && n_groups != 1) { if (reshape && n_groups != 1) {
matrix(arr, ncol = n_groups, byrow = TRUE) arr <- matrix(arr, ncol = n_groups, byrow = TRUE)
} else { } else {
as.vector(ret) .Call(XGSetArrayDimInplace_R, arr, NULL)
} }
} }
return(arr) return(arr)

View File

@ -42,6 +42,8 @@ extern SEXP XGBoosterSetAttr_R(SEXP, SEXP, SEXP);
extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP); extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP);
extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP); extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP);
extern SEXP XGCheckNullPtr_R(SEXP); extern SEXP XGCheckNullPtr_R(SEXP);
extern SEXP XGSetArrayDimInplace_R(SEXP, SEXP);
extern SEXP XGSetArrayDimNamesInplace_R(SEXP, SEXP);
extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP); extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP); extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP); extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
@ -90,6 +92,8 @@ static const R_CallMethodDef CallEntries[] = {
{"XGBoosterSetParam_R", (DL_FUNC) &XGBoosterSetParam_R, 3}, {"XGBoosterSetParam_R", (DL_FUNC) &XGBoosterSetParam_R, 3},
{"XGBoosterUpdateOneIter_R", (DL_FUNC) &XGBoosterUpdateOneIter_R, 3}, {"XGBoosterUpdateOneIter_R", (DL_FUNC) &XGBoosterUpdateOneIter_R, 3},
{"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1}, {"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1},
{"XGSetArrayDimInplace_R", (DL_FUNC) &XGSetArrayDimInplace_R, 2},
{"XGSetArrayDimNamesInplace_R", (DL_FUNC) &XGSetArrayDimNamesInplace_R, 2},
{"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 6}, {"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 6},
{"XGDMatrixCreateFromCSR_R", (DL_FUNC) &XGDMatrixCreateFromCSR_R, 6}, {"XGDMatrixCreateFromCSR_R", (DL_FUNC) &XGDMatrixCreateFromCSR_R, 6},
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2}, {"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},

View File

@ -263,6 +263,16 @@ XGB_DLL SEXP XGCheckNullPtr_R(SEXP handle) {
return Rf_ScalarLogical(R_ExternalPtrAddr(handle) == nullptr); return Rf_ScalarLogical(R_ExternalPtrAddr(handle) == nullptr);
} }
XGB_DLL SEXP XGSetArrayDimInplace_R(SEXP arr, SEXP dims) {
Rf_setAttrib(arr, R_DimSymbol, dims);
return R_NilValue;
}
XGB_DLL SEXP XGSetArrayDimNamesInplace_R(SEXP arr, SEXP dim_names) {
Rf_setAttrib(arr, R_DimNamesSymbol, dim_names);
return R_NilValue;
}
namespace { namespace {
void _DMatrixFinalizer(SEXP ext) { void _DMatrixFinalizer(SEXP ext) {
R_API_BEGIN(); R_API_BEGIN();

View File

@ -23,6 +23,22 @@
*/ */
XGB_DLL SEXP XGCheckNullPtr_R(SEXP handle); XGB_DLL SEXP XGCheckNullPtr_R(SEXP handle);
/*!
* \brief set the dimensions of an array in-place
* \param arr
* \param dims dimensions to set to the array
* \return NULL value
*/
XGB_DLL SEXP XGSetArrayDimInplace_R(SEXP arr, SEXP dims);
/*!
* \brief set the names of the dimensions of an array in-place
* \param arr
* \param dim_names names for the dimensions to set
* \return NULL value
*/
XGB_DLL SEXP XGSetArrayDimNamesInplace_R(SEXP arr, SEXP dim_names);
/*! /*!
* \brief Set global configuration * \brief Set global configuration
* \param json_str a JSON string representing the list of key-value pairs * \param json_str a JSON string representing the list of key-value pairs