[R] keep row names in predictions (#10727)
This commit is contained in:
parent
adf87b27c5
commit
e9f1abc1f0
@ -354,6 +354,11 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
||||
" Should be passed as argument to 'xgb.DMatrix' constructor."
|
||||
)
|
||||
}
|
||||
if (is_dmatrix) {
|
||||
rnames <- NULL
|
||||
} else {
|
||||
rnames <- row.names(newdata)
|
||||
}
|
||||
|
||||
use_as_df <- FALSE
|
||||
use_as_dense_matrix <- FALSE
|
||||
@ -501,6 +506,19 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
||||
.Call(XGSetArrayDimNamesInplace_R, arr, dim_names)
|
||||
}
|
||||
|
||||
if (NROW(rnames)) {
|
||||
if (is.null(dim(arr))) {
|
||||
.Call(XGSetVectorNamesInplace_R, arr, rnames)
|
||||
} else {
|
||||
dim_names <- dimnames(arr)
|
||||
if (is.null(dim_names)) {
|
||||
dim_names <- vector(mode = "list", length = length(dim(arr)))
|
||||
}
|
||||
dim_names[[length(dim_names)]] <- rnames
|
||||
.Call(XGSetArrayDimNamesInplace_R, arr, dim_names)
|
||||
}
|
||||
}
|
||||
|
||||
if (!avoid_transpose && is.array(arr)) {
|
||||
arr <- aperm(arr)
|
||||
}
|
||||
|
||||
@ -46,6 +46,7 @@ extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP);
|
||||
extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP);
|
||||
extern SEXP XGCheckNullPtr_R(SEXP);
|
||||
extern SEXP XGSetArrayDimNamesInplace_R(SEXP, SEXP);
|
||||
extern SEXP XGSetVectorNamesInplace_R(SEXP, SEXP);
|
||||
extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixCreateFromURI_R(SEXP, SEXP, SEXP);
|
||||
@ -108,6 +109,7 @@ static const R_CallMethodDef CallEntries[] = {
|
||||
{"XGBoosterUpdateOneIter_R", (DL_FUNC) &XGBoosterUpdateOneIter_R, 3},
|
||||
{"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1},
|
||||
{"XGSetArrayDimNamesInplace_R", (DL_FUNC) &XGSetArrayDimNamesInplace_R, 2},
|
||||
{"XGSetVectorNamesInplace_R", (DL_FUNC) &XGSetVectorNamesInplace_R, 2},
|
||||
{"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 6},
|
||||
{"XGDMatrixCreateFromCSR_R", (DL_FUNC) &XGDMatrixCreateFromCSR_R, 6},
|
||||
{"XGDMatrixCreateFromURI_R", (DL_FUNC) &XGDMatrixCreateFromURI_R, 3},
|
||||
|
||||
@ -335,6 +335,11 @@ XGB_DLL SEXP XGSetArrayDimNamesInplace_R(SEXP arr, SEXP dim_names) {
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
XGB_DLL SEXP XGSetVectorNamesInplace_R(SEXP arr, SEXP names) {
|
||||
Rf_setAttrib(arr, R_NamesSymbol, names);
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
namespace {
|
||||
void _DMatrixFinalizer(SEXP ext) {
|
||||
R_API_BEGIN();
|
||||
|
||||
@ -34,6 +34,14 @@ XGB_DLL SEXP XGCheckNullPtr_R(SEXP handle);
|
||||
*/
|
||||
XGB_DLL SEXP XGSetArrayDimNamesInplace_R(SEXP arr, SEXP dim_names);
|
||||
|
||||
/*!
|
||||
* \brief set the names of a vector in-place
|
||||
* \param arr
|
||||
* \param names names for the dimensions to set
|
||||
* \return NULL value
|
||||
*/
|
||||
XGB_DLL SEXP XGSetVectorNamesInplace_R(SEXP arr, SEXP names);
|
||||
|
||||
/*!
|
||||
* \brief Set global configuration
|
||||
* \param json_str a JSON string representing the list of key-value pairs
|
||||
|
||||
@ -678,7 +678,7 @@ test_that("Can predict on data.frame objects", {
|
||||
|
||||
pred_mat <- predict(model, xgb.DMatrix(x_mat))
|
||||
pred_df <- predict(model, x_df)
|
||||
expect_equal(pred_mat, pred_df)
|
||||
expect_equal(pred_mat, unname(pred_df))
|
||||
})
|
||||
|
||||
test_that("'base_margin' gives the same result in DMatrix as in inplace_predict", {
|
||||
@ -702,7 +702,7 @@ test_that("'base_margin' gives the same result in DMatrix as in inplace_predict"
|
||||
pred_from_dm <- predict(model, dm_w_base)
|
||||
pred_from_mat <- predict(model, x, base_margin = base_margin)
|
||||
|
||||
expect_equal(pred_from_dm, pred_from_mat)
|
||||
expect_equal(pred_from_dm, unname(pred_from_mat))
|
||||
})
|
||||
|
||||
test_that("Coefficients from gblinear have the expected shape and names", {
|
||||
@ -725,7 +725,7 @@ test_that("Coefficients from gblinear have the expected shape and names", {
|
||||
expect_equal(names(coefs), c("(Intercept)", colnames(x)))
|
||||
pred_auto <- predict(model, x)
|
||||
pred_manual <- as.numeric(mm %*% coefs)
|
||||
expect_equal(pred_manual, pred_auto, tolerance = 1e-5)
|
||||
expect_equal(pred_manual, unname(pred_auto), tolerance = 1e-5)
|
||||
|
||||
# Multi-column coefficients
|
||||
data(iris)
|
||||
@ -949,3 +949,47 @@ test_that("xgb.cv works for ranking", {
|
||||
)
|
||||
expect_equal(length(res$folds), 2L)
|
||||
})
|
||||
|
||||
test_that("Row names are preserved in outputs", {
|
||||
data(iris)
|
||||
x <- iris[, -5]
|
||||
y <- as.numeric(iris$Species) - 1
|
||||
dm <- xgb.DMatrix(x, label = y, nthread = 1)
|
||||
model <- xgb.train(
|
||||
data = dm,
|
||||
params = list(
|
||||
objective = "multi:softprob",
|
||||
num_class = 3,
|
||||
max_depth = 2,
|
||||
nthread = 1
|
||||
),
|
||||
nrounds = 3
|
||||
)
|
||||
row.names(x) <- paste0("r", seq(1, nrow(x)))
|
||||
pred <- predict(model, x)
|
||||
expect_equal(row.names(pred), row.names(x))
|
||||
pred <- predict(model, x, avoid_transpose = TRUE)
|
||||
expect_equal(colnames(pred), row.names(x))
|
||||
|
||||
data(mtcars)
|
||||
y <- mtcars[, 1]
|
||||
x <- as.matrix(mtcars[, -1])
|
||||
dm <- xgb.DMatrix(data = x, label = y)
|
||||
model <- xgb.train(
|
||||
data = dm,
|
||||
params = list(
|
||||
max_depth = 2,
|
||||
nthread = 1
|
||||
),
|
||||
nrounds = 3
|
||||
)
|
||||
row.names(x) <- paste0("r", seq(1, nrow(x)))
|
||||
pred <- predict(model, x)
|
||||
expect_equal(names(pred), row.names(x))
|
||||
pred <- predict(model, x, avoid_transpose = TRUE)
|
||||
expect_equal(names(pred), row.names(x))
|
||||
pred <- predict(model, x, predleaf = TRUE)
|
||||
expect_equal(row.names(pred), row.names(x))
|
||||
pred <- predict(model, x, predleaf = TRUE, avoid_transpose = TRUE)
|
||||
expect_equal(colnames(pred), row.names(x))
|
||||
})
|
||||
|
||||
@ -493,6 +493,7 @@ test_that("xgb.DMatrix: ExternalDMatrix produces the same results as regular DMa
|
||||
nrounds = 5
|
||||
)
|
||||
pred <- predict(model, x)
|
||||
pred <- unname(pred)
|
||||
|
||||
iterator_env <- as.environment(
|
||||
list(
|
||||
@ -538,7 +539,7 @@ test_that("xgb.DMatrix: ExternalDMatrix produces the same results as regular DMa
|
||||
)
|
||||
|
||||
pred_model1_edm <- predict(model, edm)
|
||||
pred_model2_mat <- predict(model_ext, x)
|
||||
pred_model2_mat <- predict(model_ext, x) |> unname()
|
||||
pred_model2_edm <- predict(model_ext, edm)
|
||||
|
||||
expect_equal(pred_model1_edm, pred)
|
||||
@ -567,6 +568,7 @@ test_that("xgb.DMatrix: External QDM produces same results as regular QDM", {
|
||||
nrounds = 5
|
||||
)
|
||||
pred <- predict(model, x)
|
||||
pred <- unname(pred)
|
||||
|
||||
iterator_env <- as.environment(
|
||||
list(
|
||||
@ -616,7 +618,7 @@ test_that("xgb.DMatrix: External QDM produces same results as regular QDM", {
|
||||
)
|
||||
|
||||
pred_model1_qdm <- predict(model, qdm)
|
||||
pred_model2_mat <- predict(model_ext, x)
|
||||
pred_model2_mat <- predict(model_ext, x) |> unname()
|
||||
pred_model2_qdm <- predict(model_ext, qdm)
|
||||
|
||||
expect_equal(pred_model1_qdm, pred)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user