[R] Avoid memory copies in predict (#9902)
This commit is contained in:
parent
2c8fa8b8b9
commit
60b9d2eeb9
@ -343,24 +343,24 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
|||||||
)
|
)
|
||||||
names(predts) <- c("shape", "results")
|
names(predts) <- c("shape", "results")
|
||||||
shape <- predts$shape
|
shape <- predts$shape
|
||||||
ret <- predts$results
|
arr <- predts$results
|
||||||
|
|
||||||
n_ret <- length(ret)
|
n_ret <- length(arr)
|
||||||
n_row <- nrow(newdata)
|
n_row <- nrow(newdata)
|
||||||
if (n_row != shape[1]) {
|
if (n_row != shape[1]) {
|
||||||
stop("Incorrect predict shape.")
|
stop("Incorrect predict shape.")
|
||||||
}
|
}
|
||||||
|
|
||||||
arr <- array(data = ret, dim = rev(shape))
|
.Call(XGSetArrayDimInplace_R, arr, rev(shape))
|
||||||
|
|
||||||
cnames <- if (!is.null(colnames(newdata))) c(colnames(newdata), "BIAS") else NULL
|
cnames <- if (!is.null(colnames(newdata))) c(colnames(newdata), "BIAS") else NULL
|
||||||
n_groups <- shape[2]
|
n_groups <- shape[2]
|
||||||
|
|
||||||
## Needed regardless of whether strict shape is being used.
|
## Needed regardless of whether strict shape is being used.
|
||||||
if (predcontrib) {
|
if (predcontrib) {
|
||||||
dimnames(arr) <- list(cnames, NULL, NULL)
|
.Call(XGSetArrayDimNamesInplace_R, arr, list(cnames, NULL, NULL))
|
||||||
} else if (predinteraction) {
|
} else if (predinteraction) {
|
||||||
dimnames(arr) <- list(cnames, cnames, NULL, NULL)
|
.Call(XGSetArrayDimNamesInplace_R, arr, list(cnames, cnames, NULL, NULL))
|
||||||
}
|
}
|
||||||
if (strict_shape) {
|
if (strict_shape) {
|
||||||
return(arr) # strict shape is calculated by libxgboost uniformly.
|
return(arr) # strict shape is calculated by libxgboost uniformly.
|
||||||
@ -368,43 +368,51 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
|||||||
|
|
||||||
if (predleaf) {
|
if (predleaf) {
|
||||||
## Predict leaf
|
## Predict leaf
|
||||||
arr <- if (n_ret == n_row) {
|
if (n_ret == n_row) {
|
||||||
matrix(arr, ncol = 1)
|
.Call(XGSetArrayDimInplace_R, arr, c(n_row, 1L))
|
||||||
} else {
|
} else {
|
||||||
matrix(arr, nrow = n_row, byrow = TRUE)
|
arr <- matrix(arr, nrow = n_row, byrow = TRUE)
|
||||||
}
|
}
|
||||||
} else if (predcontrib) {
|
} else if (predcontrib) {
|
||||||
## Predict contribution
|
## Predict contribution
|
||||||
arr <- aperm(a = arr, perm = c(2, 3, 1)) # [group, row, col]
|
arr <- aperm(a = arr, perm = c(2, 3, 1)) # [group, row, col]
|
||||||
arr <- if (n_ret == n_row) {
|
if (n_ret == n_row) {
|
||||||
matrix(arr, ncol = 1, dimnames = list(NULL, cnames))
|
.Call(XGSetArrayDimInplace_R, arr, c(n_row, 1L))
|
||||||
|
.Call(XGSetArrayDimNamesInplace_R, arr, list(NULL, cnames))
|
||||||
} else if (n_groups != 1) {
|
} else if (n_groups != 1) {
|
||||||
## turns array into list of matrices
|
## turns array into list of matrices
|
||||||
lapply(seq_len(n_groups), function(g) arr[g, , ])
|
arr <- lapply(seq_len(n_groups), function(g) arr[g, , ])
|
||||||
} else {
|
} else {
|
||||||
## remove the first axis (group)
|
## remove the first axis (group)
|
||||||
dn <- dimnames(arr)
|
newdim <- dim(arr)[2:3]
|
||||||
matrix(arr[1, , ], nrow = dim(arr)[2], ncol = dim(arr)[3], dimnames = c(dn[2], dn[3]))
|
newdn <- dimnames(arr)[2:3]
|
||||||
|
arr <- arr[1, , ]
|
||||||
|
.Call(XGSetArrayDimInplace_R, arr, newdim)
|
||||||
|
.Call(XGSetArrayDimNamesInplace_R, arr, newdn)
|
||||||
}
|
}
|
||||||
} else if (predinteraction) {
|
} else if (predinteraction) {
|
||||||
## Predict interaction
|
## Predict interaction
|
||||||
arr <- aperm(a = arr, perm = c(3, 4, 1, 2)) # [group, row, col, col]
|
arr <- aperm(a = arr, perm = c(3, 4, 1, 2)) # [group, row, col, col]
|
||||||
arr <- if (n_ret == n_row) {
|
if (n_ret == n_row) {
|
||||||
matrix(arr, ncol = 1, dimnames = list(NULL, cnames))
|
.Call(XGSetArrayDimInplace_R, arr, c(n_row, 1L))
|
||||||
|
.Call(XGSetArrayDimNamesInplace_R, arr, list(NULL, cnames))
|
||||||
} else if (n_groups != 1) {
|
} else if (n_groups != 1) {
|
||||||
## turns array into list of matrices
|
## turns array into list of matrices
|
||||||
lapply(seq_len(n_groups), function(g) arr[g, , , ])
|
arr <- lapply(seq_len(n_groups), function(g) arr[g, , , ])
|
||||||
} else {
|
} else {
|
||||||
## remove the first axis (group)
|
## remove the first axis (group)
|
||||||
arr <- arr[1, , , , drop = FALSE]
|
arr <- arr[1, , , , drop = FALSE]
|
||||||
array(arr, dim = dim(arr)[2:4], dimnames(arr)[2:4])
|
newdim <- dim(arr)[2:4]
|
||||||
|
newdn <- dimnames(arr)[2:4]
|
||||||
|
.Call(XGSetArrayDimInplace_R, arr, newdim)
|
||||||
|
.Call(XGSetArrayDimNamesInplace_R, arr, newdn)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
## Normal prediction
|
## Normal prediction
|
||||||
arr <- if (reshape && n_groups != 1) {
|
if (reshape && n_groups != 1) {
|
||||||
matrix(arr, ncol = n_groups, byrow = TRUE)
|
arr <- matrix(arr, ncol = n_groups, byrow = TRUE)
|
||||||
} else {
|
} else {
|
||||||
as.vector(ret)
|
.Call(XGSetArrayDimInplace_R, arr, NULL)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return(arr)
|
return(arr)
|
||||||
|
|||||||
@ -42,6 +42,8 @@ extern SEXP XGBoosterSetAttr_R(SEXP, SEXP, SEXP);
|
|||||||
extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP);
|
extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP);
|
extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGCheckNullPtr_R(SEXP);
|
extern SEXP XGCheckNullPtr_R(SEXP);
|
||||||
|
extern SEXP XGSetArrayDimInplace_R(SEXP, SEXP);
|
||||||
|
extern SEXP XGSetArrayDimNamesInplace_R(SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
|
extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
|
extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
|
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
|
||||||
@ -90,6 +92,8 @@ static const R_CallMethodDef CallEntries[] = {
|
|||||||
{"XGBoosterSetParam_R", (DL_FUNC) &XGBoosterSetParam_R, 3},
|
{"XGBoosterSetParam_R", (DL_FUNC) &XGBoosterSetParam_R, 3},
|
||||||
{"XGBoosterUpdateOneIter_R", (DL_FUNC) &XGBoosterUpdateOneIter_R, 3},
|
{"XGBoosterUpdateOneIter_R", (DL_FUNC) &XGBoosterUpdateOneIter_R, 3},
|
||||||
{"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1},
|
{"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1},
|
||||||
|
{"XGSetArrayDimInplace_R", (DL_FUNC) &XGSetArrayDimInplace_R, 2},
|
||||||
|
{"XGSetArrayDimNamesInplace_R", (DL_FUNC) &XGSetArrayDimNamesInplace_R, 2},
|
||||||
{"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 6},
|
{"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 6},
|
||||||
{"XGDMatrixCreateFromCSR_R", (DL_FUNC) &XGDMatrixCreateFromCSR_R, 6},
|
{"XGDMatrixCreateFromCSR_R", (DL_FUNC) &XGDMatrixCreateFromCSR_R, 6},
|
||||||
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
|
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
|
||||||
|
|||||||
@ -263,6 +263,16 @@ XGB_DLL SEXP XGCheckNullPtr_R(SEXP handle) {
|
|||||||
return Rf_ScalarLogical(R_ExternalPtrAddr(handle) == nullptr);
|
return Rf_ScalarLogical(R_ExternalPtrAddr(handle) == nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XGB_DLL SEXP XGSetArrayDimInplace_R(SEXP arr, SEXP dims) {
|
||||||
|
Rf_setAttrib(arr, R_DimSymbol, dims);
|
||||||
|
return R_NilValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
XGB_DLL SEXP XGSetArrayDimNamesInplace_R(SEXP arr, SEXP dim_names) {
|
||||||
|
Rf_setAttrib(arr, R_DimNamesSymbol, dim_names);
|
||||||
|
return R_NilValue;
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
void _DMatrixFinalizer(SEXP ext) {
|
void _DMatrixFinalizer(SEXP ext) {
|
||||||
R_API_BEGIN();
|
R_API_BEGIN();
|
||||||
|
|||||||
@ -23,6 +23,22 @@
|
|||||||
*/
|
*/
|
||||||
XGB_DLL SEXP XGCheckNullPtr_R(SEXP handle);
|
XGB_DLL SEXP XGCheckNullPtr_R(SEXP handle);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief set the dimensions of an array in-place
|
||||||
|
* \param arr
|
||||||
|
* \param dims dimensions to set to the array
|
||||||
|
* \return NULL value
|
||||||
|
*/
|
||||||
|
XGB_DLL SEXP XGSetArrayDimInplace_R(SEXP arr, SEXP dims);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief set the names of the dimensions of an array in-place
|
||||||
|
* \param arr
|
||||||
|
* \param dim_names names for the dimensions to set
|
||||||
|
* \return NULL value
|
||||||
|
*/
|
||||||
|
XGB_DLL SEXP XGSetArrayDimNamesInplace_R(SEXP arr, SEXP dim_names);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Set global configuration
|
* \brief Set global configuration
|
||||||
* \param json_str a JSON string representing the list of key-value pairs
|
* \param json_str a JSON string representing the list of key-value pairs
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user