[R] keep row names in predictions (#10727)

This commit is contained in:
david-cortes 2024-08-20 23:49:02 +02:00 committed by GitHub
parent adf87b27c5
commit e9f1abc1f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 84 additions and 5 deletions

View File

@ -354,6 +354,11 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
" Should be passed as argument to 'xgb.DMatrix' constructor." " Should be passed as argument to 'xgb.DMatrix' constructor."
) )
} }
if (is_dmatrix) {
rnames <- NULL
} else {
rnames <- row.names(newdata)
}
use_as_df <- FALSE use_as_df <- FALSE
use_as_dense_matrix <- FALSE use_as_dense_matrix <- FALSE
@ -501,6 +506,19 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
.Call(XGSetArrayDimNamesInplace_R, arr, dim_names) .Call(XGSetArrayDimNamesInplace_R, arr, dim_names)
} }
if (NROW(rnames)) {
if (is.null(dim(arr))) {
.Call(XGSetVectorNamesInplace_R, arr, rnames)
} else {
dim_names <- dimnames(arr)
if (is.null(dim_names)) {
dim_names <- vector(mode = "list", length = length(dim(arr)))
}
dim_names[[length(dim_names)]] <- rnames
.Call(XGSetArrayDimNamesInplace_R, arr, dim_names)
}
}
if (!avoid_transpose && is.array(arr)) { if (!avoid_transpose && is.array(arr)) {
arr <- aperm(arr) arr <- aperm(arr)
} }

View File

@ -46,6 +46,7 @@ extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP);
extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP); extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP);
extern SEXP XGCheckNullPtr_R(SEXP); extern SEXP XGCheckNullPtr_R(SEXP);
extern SEXP XGSetArrayDimNamesInplace_R(SEXP, SEXP); extern SEXP XGSetArrayDimNamesInplace_R(SEXP, SEXP);
extern SEXP XGSetVectorNamesInplace_R(SEXP, SEXP);
extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP); extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP); extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGDMatrixCreateFromURI_R(SEXP, SEXP, SEXP); extern SEXP XGDMatrixCreateFromURI_R(SEXP, SEXP, SEXP);
@ -108,6 +109,7 @@ static const R_CallMethodDef CallEntries[] = {
{"XGBoosterUpdateOneIter_R", (DL_FUNC) &XGBoosterUpdateOneIter_R, 3}, {"XGBoosterUpdateOneIter_R", (DL_FUNC) &XGBoosterUpdateOneIter_R, 3},
{"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1}, {"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1},
{"XGSetArrayDimNamesInplace_R", (DL_FUNC) &XGSetArrayDimNamesInplace_R, 2}, {"XGSetArrayDimNamesInplace_R", (DL_FUNC) &XGSetArrayDimNamesInplace_R, 2},
{"XGSetVectorNamesInplace_R", (DL_FUNC) &XGSetVectorNamesInplace_R, 2},
{"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 6}, {"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 6},
{"XGDMatrixCreateFromCSR_R", (DL_FUNC) &XGDMatrixCreateFromCSR_R, 6}, {"XGDMatrixCreateFromCSR_R", (DL_FUNC) &XGDMatrixCreateFromCSR_R, 6},
{"XGDMatrixCreateFromURI_R", (DL_FUNC) &XGDMatrixCreateFromURI_R, 3}, {"XGDMatrixCreateFromURI_R", (DL_FUNC) &XGDMatrixCreateFromURI_R, 3},

View File

@ -335,6 +335,11 @@ XGB_DLL SEXP XGSetArrayDimNamesInplace_R(SEXP arr, SEXP dim_names) {
return R_NilValue; return R_NilValue;
} }
XGB_DLL SEXP XGSetVectorNamesInplace_R(SEXP arr, SEXP names) {
Rf_setAttrib(arr, R_NamesSymbol, names);
return R_NilValue;
}
namespace { namespace {
void _DMatrixFinalizer(SEXP ext) { void _DMatrixFinalizer(SEXP ext) {
R_API_BEGIN(); R_API_BEGIN();

View File

@ -34,6 +34,14 @@ XGB_DLL SEXP XGCheckNullPtr_R(SEXP handle);
*/ */
XGB_DLL SEXP XGSetArrayDimNamesInplace_R(SEXP arr, SEXP dim_names); XGB_DLL SEXP XGSetArrayDimNamesInplace_R(SEXP arr, SEXP dim_names);
/*!
* \brief set the names of a vector in-place
* \param arr
* \param names names for the dimensions to set
* \return NULL value
*/
XGB_DLL SEXP XGSetVectorNamesInplace_R(SEXP arr, SEXP names);
/*! /*!
* \brief Set global configuration * \brief Set global configuration
* \param json_str a JSON string representing the list of key-value pairs * \param json_str a JSON string representing the list of key-value pairs

View File

@ -678,7 +678,7 @@ test_that("Can predict on data.frame objects", {
pred_mat <- predict(model, xgb.DMatrix(x_mat)) pred_mat <- predict(model, xgb.DMatrix(x_mat))
pred_df <- predict(model, x_df) pred_df <- predict(model, x_df)
expect_equal(pred_mat, pred_df) expect_equal(pred_mat, unname(pred_df))
}) })
test_that("'base_margin' gives the same result in DMatrix as in inplace_predict", { test_that("'base_margin' gives the same result in DMatrix as in inplace_predict", {
@ -702,7 +702,7 @@ test_that("'base_margin' gives the same result in DMatrix as in inplace_predict"
pred_from_dm <- predict(model, dm_w_base) pred_from_dm <- predict(model, dm_w_base)
pred_from_mat <- predict(model, x, base_margin = base_margin) pred_from_mat <- predict(model, x, base_margin = base_margin)
expect_equal(pred_from_dm, pred_from_mat) expect_equal(pred_from_dm, unname(pred_from_mat))
}) })
test_that("Coefficients from gblinear have the expected shape and names", { test_that("Coefficients from gblinear have the expected shape and names", {
@ -725,7 +725,7 @@ test_that("Coefficients from gblinear have the expected shape and names", {
expect_equal(names(coefs), c("(Intercept)", colnames(x))) expect_equal(names(coefs), c("(Intercept)", colnames(x)))
pred_auto <- predict(model, x) pred_auto <- predict(model, x)
pred_manual <- as.numeric(mm %*% coefs) pred_manual <- as.numeric(mm %*% coefs)
expect_equal(pred_manual, pred_auto, tolerance = 1e-5) expect_equal(pred_manual, unname(pred_auto), tolerance = 1e-5)
# Multi-column coefficients # Multi-column coefficients
data(iris) data(iris)
@ -949,3 +949,47 @@ test_that("xgb.cv works for ranking", {
) )
expect_equal(length(res$folds), 2L) expect_equal(length(res$folds), 2L)
}) })
test_that("Row names are preserved in outputs", {
data(iris)
x <- iris[, -5]
y <- as.numeric(iris$Species) - 1
dm <- xgb.DMatrix(x, label = y, nthread = 1)
model <- xgb.train(
data = dm,
params = list(
objective = "multi:softprob",
num_class = 3,
max_depth = 2,
nthread = 1
),
nrounds = 3
)
row.names(x) <- paste0("r", seq(1, nrow(x)))
pred <- predict(model, x)
expect_equal(row.names(pred), row.names(x))
pred <- predict(model, x, avoid_transpose = TRUE)
expect_equal(colnames(pred), row.names(x))
data(mtcars)
y <- mtcars[, 1]
x <- as.matrix(mtcars[, -1])
dm <- xgb.DMatrix(data = x, label = y)
model <- xgb.train(
data = dm,
params = list(
max_depth = 2,
nthread = 1
),
nrounds = 3
)
row.names(x) <- paste0("r", seq(1, nrow(x)))
pred <- predict(model, x)
expect_equal(names(pred), row.names(x))
pred <- predict(model, x, avoid_transpose = TRUE)
expect_equal(names(pred), row.names(x))
pred <- predict(model, x, predleaf = TRUE)
expect_equal(row.names(pred), row.names(x))
pred <- predict(model, x, predleaf = TRUE, avoid_transpose = TRUE)
expect_equal(colnames(pred), row.names(x))
})

View File

@ -493,6 +493,7 @@ test_that("xgb.DMatrix: ExternalDMatrix produces the same results as regular DMa
nrounds = 5 nrounds = 5
) )
pred <- predict(model, x) pred <- predict(model, x)
pred <- unname(pred)
iterator_env <- as.environment( iterator_env <- as.environment(
list( list(
@ -538,7 +539,7 @@ test_that("xgb.DMatrix: ExternalDMatrix produces the same results as regular DMa
) )
pred_model1_edm <- predict(model, edm) pred_model1_edm <- predict(model, edm)
pred_model2_mat <- predict(model_ext, x) pred_model2_mat <- predict(model_ext, x) |> unname()
pred_model2_edm <- predict(model_ext, edm) pred_model2_edm <- predict(model_ext, edm)
expect_equal(pred_model1_edm, pred) expect_equal(pred_model1_edm, pred)
@ -567,6 +568,7 @@ test_that("xgb.DMatrix: External QDM produces same results as regular QDM", {
nrounds = 5 nrounds = 5
) )
pred <- predict(model, x) pred <- predict(model, x)
pred <- unname(pred)
iterator_env <- as.environment( iterator_env <- as.environment(
list( list(
@ -616,7 +618,7 @@ test_that("xgb.DMatrix: External QDM produces same results as regular QDM", {
) )
pred_model1_qdm <- predict(model, qdm) pred_model1_qdm <- predict(model, qdm)
pred_model2_mat <- predict(model_ext, x) pred_model2_mat <- predict(model_ext, x) |> unname()
pred_model2_qdm <- predict(model_ext, qdm) pred_model2_qdm <- predict(model_ext, qdm)
expect_equal(pred_model1_qdm, pred) expect_equal(pred_model1_qdm, pred)