From e9f1abc1f0637c13923c8282de54c4298a43c533 Mon Sep 17 00:00:00 2001 From: david-cortes Date: Tue, 20 Aug 2024 23:49:02 +0200 Subject: [PATCH] [R] keep row names in predictions (#10727) --- R-package/R/xgb.Booster.R | 18 +++++++++ R-package/src/init.c | 2 + R-package/src/xgboost_R.cc | 5 +++ R-package/src/xgboost_R.h | 8 ++++ R-package/tests/testthat/test_basic.R | 50 +++++++++++++++++++++++-- R-package/tests/testthat/test_dmatrix.R | 6 ++- 6 files changed, 84 insertions(+), 5 deletions(-) diff --git a/R-package/R/xgb.Booster.R b/R-package/R/xgb.Booster.R index 0e6313d88..a15285091 100644 --- a/R-package/R/xgb.Booster.R +++ b/R-package/R/xgb.Booster.R @@ -354,6 +354,11 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA " Should be passed as argument to 'xgb.DMatrix' constructor." ) } + if (is_dmatrix) { + rnames <- NULL + } else { + rnames <- row.names(newdata) + } use_as_df <- FALSE use_as_dense_matrix <- FALSE @@ -501,6 +506,19 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA .Call(XGSetArrayDimNamesInplace_R, arr, dim_names) } + if (NROW(rnames)) { + if (is.null(dim(arr))) { + .Call(XGSetVectorNamesInplace_R, arr, rnames) + } else { + dim_names <- dimnames(arr) + if (is.null(dim_names)) { + dim_names <- vector(mode = "list", length = length(dim(arr))) + } + dim_names[[length(dim_names)]] <- rnames + .Call(XGSetArrayDimNamesInplace_R, arr, dim_names) + } + } + if (!avoid_transpose && is.array(arr)) { arr <- aperm(arr) } diff --git a/R-package/src/init.c b/R-package/src/init.c index 16c1d3b14..523e5118a 100644 --- a/R-package/src/init.c +++ b/R-package/src/init.c @@ -46,6 +46,7 @@ extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP); extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP); extern SEXP XGCheckNullPtr_R(SEXP); extern SEXP XGSetArrayDimNamesInplace_R(SEXP, SEXP); +extern SEXP XGSetVectorNamesInplace_R(SEXP, SEXP); extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP, 
SEXP); extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP); extern SEXP XGDMatrixCreateFromURI_R(SEXP, SEXP, SEXP); @@ -108,6 +109,7 @@ static const R_CallMethodDef CallEntries[] = { {"XGBoosterUpdateOneIter_R", (DL_FUNC) &XGBoosterUpdateOneIter_R, 3}, {"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1}, {"XGSetArrayDimNamesInplace_R", (DL_FUNC) &XGSetArrayDimNamesInplace_R, 2}, + {"XGSetVectorNamesInplace_R", (DL_FUNC) &XGSetVectorNamesInplace_R, 2}, {"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 6}, {"XGDMatrixCreateFromCSR_R", (DL_FUNC) &XGDMatrixCreateFromCSR_R, 6}, {"XGDMatrixCreateFromURI_R", (DL_FUNC) &XGDMatrixCreateFromURI_R, 3}, diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index 5faae8a9f..0e7234a18 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -335,6 +335,11 @@ XGB_DLL SEXP XGSetArrayDimNamesInplace_R(SEXP arr, SEXP dim_names) { return R_NilValue; } +XGB_DLL SEXP XGSetVectorNamesInplace_R(SEXP arr, SEXP names) { + Rf_setAttrib(arr, R_NamesSymbol, names); + return R_NilValue; +} + namespace { void _DMatrixFinalizer(SEXP ext) { R_API_BEGIN(); diff --git a/R-package/src/xgboost_R.h b/R-package/src/xgboost_R.h index 08f16bac1..bfccd9f15 100644 --- a/R-package/src/xgboost_R.h +++ b/R-package/src/xgboost_R.h @@ -34,6 +34,14 @@ XGB_DLL SEXP XGCheckNullPtr_R(SEXP handle); */ XGB_DLL SEXP XGSetArrayDimNamesInplace_R(SEXP arr, SEXP dim_names); +/*! + * \brief set the names of a vector in-place + * \param arr vector whose 'names' attribute is set (modified in-place) + * \param names names to assign to the elements of 'arr' + * \return NULL value + */ +XGB_DLL SEXP XGSetVectorNamesInplace_R(SEXP arr, SEXP names); + /*! 
* \brief Set global configuration * \param json_str a JSON string representing the list of key-value pairs diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index f0ebd7a1c..840ff2635 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -678,7 +678,7 @@ test_that("Can predict on data.frame objects", { pred_mat <- predict(model, xgb.DMatrix(x_mat)) pred_df <- predict(model, x_df) - expect_equal(pred_mat, pred_df) + expect_equal(pred_mat, unname(pred_df)) }) test_that("'base_margin' gives the same result in DMatrix as in inplace_predict", { @@ -702,7 +702,7 @@ test_that("'base_margin' gives the same result in DMatrix as in inplace_predict" pred_from_dm <- predict(model, dm_w_base) pred_from_mat <- predict(model, x, base_margin = base_margin) - expect_equal(pred_from_dm, pred_from_mat) + expect_equal(pred_from_dm, unname(pred_from_mat)) }) test_that("Coefficients from gblinear have the expected shape and names", { @@ -725,7 +725,7 @@ test_that("Coefficients from gblinear have the expected shape and names", { expect_equal(names(coefs), c("(Intercept)", colnames(x))) pred_auto <- predict(model, x) pred_manual <- as.numeric(mm %*% coefs) - expect_equal(pred_manual, pred_auto, tolerance = 1e-5) + expect_equal(pred_manual, unname(pred_auto), tolerance = 1e-5) # Multi-column coefficients data(iris) @@ -949,3 +949,47 @@ test_that("xgb.cv works for ranking", { ) expect_equal(length(res$folds), 2L) }) + +test_that("Row names are preserved in outputs", { + data(iris) + x <- iris[, -5] + y <- as.numeric(iris$Species) - 1 + dm <- xgb.DMatrix(x, label = y, nthread = 1) + model <- xgb.train( + data = dm, + params = list( + objective = "multi:softprob", + num_class = 3, + max_depth = 2, + nthread = 1 + ), + nrounds = 3 + ) + row.names(x) <- paste0("r", seq(1, nrow(x))) + pred <- predict(model, x) + expect_equal(row.names(pred), row.names(x)) + pred <- predict(model, x, avoid_transpose = TRUE) + 
expect_equal(colnames(pred), row.names(x)) + + data(mtcars) + y <- mtcars[, 1] + x <- as.matrix(mtcars[, -1]) + dm <- xgb.DMatrix(data = x, label = y) + model <- xgb.train( + data = dm, + params = list( + max_depth = 2, + nthread = 1 + ), + nrounds = 3 + ) + row.names(x) <- paste0("r", seq(1, nrow(x))) + pred <- predict(model, x) + expect_equal(names(pred), row.names(x)) + pred <- predict(model, x, avoid_transpose = TRUE) + expect_equal(names(pred), row.names(x)) + pred <- predict(model, x, predleaf = TRUE) + expect_equal(row.names(pred), row.names(x)) + pred <- predict(model, x, predleaf = TRUE, avoid_transpose = TRUE) + expect_equal(colnames(pred), row.names(x)) +}) diff --git a/R-package/tests/testthat/test_dmatrix.R b/R-package/tests/testthat/test_dmatrix.R index cca7b88da..887f602be 100644 --- a/R-package/tests/testthat/test_dmatrix.R +++ b/R-package/tests/testthat/test_dmatrix.R @@ -493,6 +493,7 @@ test_that("xgb.DMatrix: ExternalDMatrix produces the same results as regular DMa nrounds = 5 ) pred <- predict(model, x) + pred <- unname(pred) iterator_env <- as.environment( list( @@ -538,7 +539,7 @@ test_that("xgb.DMatrix: ExternalDMatrix produces the same results as regular DMa ) pred_model1_edm <- predict(model, edm) - pred_model2_mat <- predict(model_ext, x) + pred_model2_mat <- predict(model_ext, x) |> unname() pred_model2_edm <- predict(model_ext, edm) expect_equal(pred_model1_edm, pred) @@ -567,6 +568,7 @@ test_that("xgb.DMatrix: External QDM produces same results as regular QDM", { nrounds = 5 ) pred <- predict(model, x) + pred <- unname(pred) iterator_env <- as.environment( list( @@ -616,7 +618,7 @@ test_that("xgb.DMatrix: External QDM produces same results as regular QDM", { ) pred_model1_qdm <- predict(model, qdm) - pred_model2_mat <- predict(model_ext, x) + pred_model2_mat <- predict(model_ext, x) |> unname() pred_model2_qdm <- predict(model_ext, qdm) expect_equal(pred_model1_qdm, pred)