From 479ae8081b9c3009138e1afbfe1300c0d22a3dab Mon Sep 17 00:00:00 2001 From: david-cortes Date: Sat, 24 Aug 2024 22:41:58 +0200 Subject: [PATCH] [R] Add class names to coefficients (#10745) --- R-package/R/xgb.Booster.R | 16 ++++++++++++---- R-package/tests/testthat/test_basic.R | 13 +++++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/R-package/R/xgb.Booster.R b/R-package/R/xgb.Booster.R index a15285091..7b5484fd2 100644 --- a/R-package/R/xgb.Booster.R +++ b/R-package/R/xgb.Booster.R @@ -1109,17 +1109,25 @@ coef.xgb.Booster <- function(object, ...) { if (n_cols == 1L) { out <- c(intercepts, coefs) if (add_names) { - names(out) <- feature_names + .Call(XGSetVectorNamesInplace_R, out, feature_names) } } else { coefs <- matrix(coefs, nrow = num_feature, byrow = TRUE) dim(intercepts) <- c(1L, n_cols) out <- rbind(intercepts, coefs) + out_names <- vector(mode = "list", length = 2) if (add_names) { - row.names(out) <- feature_names + out_names[[1L]] <- feature_names } - # TODO: if a class names attributes is added, - # should use those names here. + if (inherits(object, "xgboost")) { + metadata <- attributes(object)$metadata + if (NROW(metadata$y_levels)) { + out_names[[2L]] <- metadata$y_levels + } else if (NROW(metadata$y_names)) { + out_names[[2L]] <- metadata$y_names + } + } + .Call(XGSetArrayDimNamesInplace_R, out, out_names) } return(out) } diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index 840ff2635..03a346d02 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -750,6 +750,19 @@ test_that("Coefficients from gblinear have the expected shape and names", { pred_auto <- predict(model, x, outputmargin = TRUE) pred_manual <- unname(mm %*% coefs) expect_equal(pred_manual, pred_auto, tolerance = 1e-7) + + # xgboost() with additional metadata + model <- xgboost( + iris[, -5], + iris$Species, + booster = "gblinear", + objective = "multi:softprob", + nrounds = 3, + nthread = 1 + ) + coefs <- coef(model) + expect_equal(row.names(coefs), c("(Intercept)", colnames(x))) + expect_equal(colnames(coefs), levels(iris$Species)) }) test_that("Deep copies work as expected", {