[R] keep row names in predictions (#10727)
This commit is contained in:
@@ -678,7 +678,7 @@ test_that("Can predict on data.frame objects", {
|
||||
|
||||
pred_mat <- predict(model, xgb.DMatrix(x_mat))
|
||||
pred_df <- predict(model, x_df)
|
||||
expect_equal(pred_mat, pred_df)
|
||||
expect_equal(pred_mat, unname(pred_df))
|
||||
})
|
||||
|
||||
test_that("'base_margin' gives the same result in DMatrix as in inplace_predict", {
|
||||
@@ -702,7 +702,7 @@ test_that("'base_margin' gives the same result in DMatrix as in inplace_predict"
|
||||
pred_from_dm <- predict(model, dm_w_base)
|
||||
pred_from_mat <- predict(model, x, base_margin = base_margin)
|
||||
|
||||
expect_equal(pred_from_dm, pred_from_mat)
|
||||
expect_equal(pred_from_dm, unname(pred_from_mat))
|
||||
})
|
||||
|
||||
test_that("Coefficients from gblinear have the expected shape and names", {
|
||||
@@ -725,7 +725,7 @@ test_that("Coefficients from gblinear have the expected shape and names", {
|
||||
expect_equal(names(coefs), c("(Intercept)", colnames(x)))
|
||||
pred_auto <- predict(model, x)
|
||||
pred_manual <- as.numeric(mm %*% coefs)
|
||||
expect_equal(pred_manual, pred_auto, tolerance = 1e-5)
|
||||
expect_equal(pred_manual, unname(pred_auto), tolerance = 1e-5)
|
||||
|
||||
# Multi-column coefficients
|
||||
data(iris)
|
||||
@@ -949,3 +949,47 @@ test_that("xgb.cv works for ranking", {
|
||||
)
|
||||
expect_equal(length(res$folds), 2L)
|
||||
})
|
||||
|
||||
test_that("Row names are preserved in outputs", {
|
||||
data(iris)
|
||||
x <- iris[, -5]
|
||||
y <- as.numeric(iris$Species) - 1
|
||||
dm <- xgb.DMatrix(x, label = y, nthread = 1)
|
||||
model <- xgb.train(
|
||||
data = dm,
|
||||
params = list(
|
||||
objective = "multi:softprob",
|
||||
num_class = 3,
|
||||
max_depth = 2,
|
||||
nthread = 1
|
||||
),
|
||||
nrounds = 3
|
||||
)
|
||||
row.names(x) <- paste0("r", seq(1, nrow(x)))
|
||||
pred <- predict(model, x)
|
||||
expect_equal(row.names(pred), row.names(x))
|
||||
pred <- predict(model, x, avoid_transpose = TRUE)
|
||||
expect_equal(colnames(pred), row.names(x))
|
||||
|
||||
data(mtcars)
|
||||
y <- mtcars[, 1]
|
||||
x <- as.matrix(mtcars[, -1])
|
||||
dm <- xgb.DMatrix(data = x, label = y)
|
||||
model <- xgb.train(
|
||||
data = dm,
|
||||
params = list(
|
||||
max_depth = 2,
|
||||
nthread = 1
|
||||
),
|
||||
nrounds = 3
|
||||
)
|
||||
row.names(x) <- paste0("r", seq(1, nrow(x)))
|
||||
pred <- predict(model, x)
|
||||
expect_equal(names(pred), row.names(x))
|
||||
pred <- predict(model, x, avoid_transpose = TRUE)
|
||||
expect_equal(names(pred), row.names(x))
|
||||
pred <- predict(model, x, predleaf = TRUE)
|
||||
expect_equal(row.names(pred), row.names(x))
|
||||
pred <- predict(model, x, predleaf = TRUE, avoid_transpose = TRUE)
|
||||
expect_equal(colnames(pred), row.names(x))
|
||||
})
|
||||
|
||||
@@ -493,6 +493,7 @@ test_that("xgb.DMatrix: ExternalDMatrix produces the same results as regular DMa
|
||||
nrounds = 5
|
||||
)
|
||||
pred <- predict(model, x)
|
||||
pred <- unname(pred)
|
||||
|
||||
iterator_env <- as.environment(
|
||||
list(
|
||||
@@ -538,7 +539,7 @@ test_that("xgb.DMatrix: ExternalDMatrix produces the same results as regular DMa
|
||||
)
|
||||
|
||||
pred_model1_edm <- predict(model, edm)
|
||||
pred_model2_mat <- predict(model_ext, x)
|
||||
pred_model2_mat <- predict(model_ext, x) |> unname()
|
||||
pred_model2_edm <- predict(model_ext, edm)
|
||||
|
||||
expect_equal(pred_model1_edm, pred)
|
||||
@@ -567,6 +568,7 @@ test_that("xgb.DMatrix: External QDM produces same results as regular QDM", {
|
||||
nrounds = 5
|
||||
)
|
||||
pred <- predict(model, x)
|
||||
pred <- unname(pred)
|
||||
|
||||
iterator_env <- as.environment(
|
||||
list(
|
||||
@@ -616,7 +618,7 @@ test_that("xgb.DMatrix: External QDM produces same results as regular QDM", {
|
||||
)
|
||||
|
||||
pred_model1_qdm <- predict(model, qdm)
|
||||
pred_model2_mat <- predict(model_ext, x)
|
||||
pred_model2_mat <- predict(model_ext, x) |> unname()
|
||||
pred_model2_qdm <- predict(model_ext, qdm)
|
||||
|
||||
expect_equal(pred_model1_qdm, pred)
|
||||
|
||||
Reference in New Issue
Block a user