From 60b9d2eeb9f88f8d028a1f2fba37abbb3d0f47cb Mon Sep 17 00:00:00 2001 From: david-cortes Date: Sat, 20 Jan 2024 17:53:18 +0100 Subject: [PATCH] [R] Avoid memory copies in `predict` (#9902) --- R-package/R/xgb.Booster.R | 48 ++++++++++++++++++++++---------------- R-package/src/init.c | 4 ++++ R-package/src/xgboost_R.cc | 10 ++++++++ R-package/src/xgboost_R.h | 16 +++++++++++++ 4 files changed, 58 insertions(+), 20 deletions(-) diff --git a/R-package/R/xgb.Booster.R b/R-package/R/xgb.Booster.R index cee7e9fc5..5562c22f3 100644 --- a/R-package/R/xgb.Booster.R +++ b/R-package/R/xgb.Booster.R @@ -343,24 +343,24 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA ) names(predts) <- c("shape", "results") shape <- predts$shape - ret <- predts$results + arr <- predts$results - n_ret <- length(ret) + n_ret <- length(arr) n_row <- nrow(newdata) if (n_row != shape[1]) { stop("Incorrect predict shape.") } - arr <- array(data = ret, dim = rev(shape)) + .Call(XGSetArrayDimInplace_R, arr, rev(shape)) cnames <- if (!is.null(colnames(newdata))) c(colnames(newdata), "BIAS") else NULL n_groups <- shape[2] ## Needed regardless of whether strict shape is being used. if (predcontrib) { - dimnames(arr) <- list(cnames, NULL, NULL) + .Call(XGSetArrayDimNamesInplace_R, arr, list(cnames, NULL, NULL)) } else if (predinteraction) { - dimnames(arr) <- list(cnames, cnames, NULL, NULL) + .Call(XGSetArrayDimNamesInplace_R, arr, list(cnames, cnames, NULL, NULL)) } if (strict_shape) { return(arr) # strict shape is calculated by libxgboost uniformly. @@ -368,43 +368,51 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA if (predleaf) { ## Predict leaf - arr <- if (n_ret == n_row) { - matrix(arr, ncol = 1) + if (n_ret == n_row) { + .Call(XGSetArrayDimInplace_R, arr, c(n_row, 1L)) } else { - matrix(arr, nrow = n_row, byrow = TRUE) + arr <- matrix(arr, nrow = n_row, byrow = TRUE) } } else if (predcontrib) { ## Predict contribution arr <- aperm(a = arr, perm = c(2, 3, 1)) # [group, row, col] - arr <- if (n_ret == n_row) { - matrix(arr, ncol = 1, dimnames = list(NULL, cnames)) + if (n_ret == n_row) { + .Call(XGSetArrayDimInplace_R, arr, c(n_row, 1L)) + .Call(XGSetArrayDimNamesInplace_R, arr, list(NULL, cnames)) } else if (n_groups != 1) { ## turns array into list of matrices - lapply(seq_len(n_groups), function(g) arr[g, , ]) + arr <- lapply(seq_len(n_groups), function(g) arr[g, , ]) } else { ## remove the first axis (group) - dn <- dimnames(arr) - matrix(arr[1, , ], nrow = dim(arr)[2], ncol = dim(arr)[3], dimnames = c(dn[2], dn[3])) + newdim <- dim(arr)[2:3] + newdn <- dimnames(arr)[2:3] + arr <- arr[1, , ] + .Call(XGSetArrayDimInplace_R, arr, newdim) + .Call(XGSetArrayDimNamesInplace_R, arr, newdn) } } else if (predinteraction) { ## Predict interaction arr <- aperm(a = arr, perm = c(3, 4, 1, 2)) # [group, row, col, col] - arr <- if (n_ret == n_row) { - matrix(arr, ncol = 1, dimnames = list(NULL, cnames)) + if (n_ret == n_row) { + .Call(XGSetArrayDimInplace_R, arr, c(n_row, 1L)) + .Call(XGSetArrayDimNamesInplace_R, arr, list(NULL, cnames)) } else if (n_groups != 1) { ## turns array into list of matrices - lapply(seq_len(n_groups), function(g) arr[g, , , ]) + arr <- lapply(seq_len(n_groups), function(g) arr[g, , , ]) } else { ## remove the first axis (group) arr <- arr[1, , , , drop = FALSE] - array(arr, dim = dim(arr)[2:4], dimnames(arr)[2:4]) + newdim <- dim(arr)[2:4] + newdn <- dimnames(arr)[2:4] + .Call(XGSetArrayDimInplace_R, arr, newdim) + .Call(XGSetArrayDimNamesInplace_R, arr, newdn) } } else { ## Normal prediction - arr <- if (reshape && n_groups != 1) { - matrix(arr, ncol = n_groups, byrow = TRUE) + if (reshape && n_groups != 1) { + arr <- matrix(arr, ncol = n_groups, byrow = TRUE) } else { - as.vector(ret) + .Call(XGSetArrayDimInplace_R, arr, NULL) } } return(arr) diff --git a/R-package/src/init.c b/R-package/src/init.c index 81c28c401..dd3a1aa2f 100644 --- a/R-package/src/init.c +++ b/R-package/src/init.c @@ -42,6 +42,8 @@ extern SEXP XGBoosterSetAttr_R(SEXP, SEXP, SEXP); extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP); extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP); extern SEXP XGCheckNullPtr_R(SEXP); +extern SEXP XGSetArrayDimInplace_R(SEXP, SEXP); +extern SEXP XGSetArrayDimNamesInplace_R(SEXP, SEXP); extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP); extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP); extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP); @@ -90,6 +92,8 @@ static const R_CallMethodDef CallEntries[] = { {"XGBoosterSetParam_R", (DL_FUNC) &XGBoosterSetParam_R, 3}, {"XGBoosterUpdateOneIter_R", (DL_FUNC) &XGBoosterUpdateOneIter_R, 3}, {"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1}, + {"XGSetArrayDimInplace_R", (DL_FUNC) &XGSetArrayDimInplace_R, 2}, + {"XGSetArrayDimNamesInplace_R", (DL_FUNC) &XGSetArrayDimNamesInplace_R, 2}, {"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 6}, {"XGDMatrixCreateFromCSR_R", (DL_FUNC) &XGDMatrixCreateFromCSR_R, 6}, {"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2}, diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index 63f36ad6a..4a8710124 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -263,6 +263,16 @@ XGB_DLL SEXP XGCheckNullPtr_R(SEXP handle) { return Rf_ScalarLogical(R_ExternalPtrAddr(handle) == nullptr); } +XGB_DLL SEXP XGSetArrayDimInplace_R(SEXP arr, SEXP dims) { + Rf_setAttrib(arr, R_DimSymbol, dims); + return R_NilValue; +} + +XGB_DLL SEXP XGSetArrayDimNamesInplace_R(SEXP arr, SEXP dim_names) { + Rf_setAttrib(arr, R_DimNamesSymbol, dim_names); + return R_NilValue; +} + namespace { void _DMatrixFinalizer(SEXP ext) { R_API_BEGIN(); diff --git a/R-package/src/xgboost_R.h b/R-package/src/xgboost_R.h index 79d441792..e2688bf34 100644 --- a/R-package/src/xgboost_R.h +++ b/R-package/src/xgboost_R.h @@ -23,6 +23,22 @@ */ XGB_DLL SEXP XGCheckNullPtr_R(SEXP handle); +/*! + * \brief set the dimensions of an array in-place + * \param arr + * \param dims dimensions to set to the array + * \return NULL value + */ +XGB_DLL SEXP XGSetArrayDimInplace_R(SEXP arr, SEXP dims); + +/*! + * \brief set the names of the dimensions of an array in-place + * \param arr + * \param dim_names names for the dimensions to set + * \return NULL value + */ +XGB_DLL SEXP XGSetArrayDimNamesInplace_R(SEXP arr, SEXP dim_names); + /*! * \brief Set global configuration * \param json_str a JSON string representing the list of key-value pairs