[R] Add class names to coefficients (#10745)

This commit is contained in:
david-cortes 2024-08-24 22:41:58 +02:00 committed by GitHub
parent fd0138c91c
commit 479ae8081b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 4 deletions

View File

@ -1109,17 +1109,25 @@ coef.xgb.Booster <- function(object, ...) {
if (n_cols == 1L) { if (n_cols == 1L) {
out <- c(intercepts, coefs) out <- c(intercepts, coefs)
if (add_names) { if (add_names) {
names(out) <- feature_names .Call(XGSetVectorNamesInplace_R, out, feature_names)
} }
} else { } else {
coefs <- matrix(coefs, nrow = num_feature, byrow = TRUE) coefs <- matrix(coefs, nrow = num_feature, byrow = TRUE)
dim(intercepts) <- c(1L, n_cols) dim(intercepts) <- c(1L, n_cols)
out <- rbind(intercepts, coefs) out <- rbind(intercepts, coefs)
out_names <- vector(mode = "list", length = 2)
if (add_names) { if (add_names) {
row.names(out) <- feature_names out_names[[1L]] <- feature_names
} }
# TODO: if a class names attributes is added, if (inherits(object, "xgboost")) {
# should use those names here. metadata <- attributes(object)$metadata
if (NROW(metadata$y_levels)) {
out_names[[2L]] <- metadata$y_levels
} else if (NROW(metadata$y_names)) {
out_names[[2L]] <- metadata$y_names
}
}
.Call(XGSetArrayDimNamesInplace_R, out, out_names)
} }
return(out) return(out)
} }

View File

@ -750,6 +750,19 @@ test_that("Coefficients from gblinear have the expected shape and names", {
pred_auto <- predict(model, x, outputmargin = TRUE) pred_auto <- predict(model, x, outputmargin = TRUE)
pred_manual <- unname(mm %*% coefs) pred_manual <- unname(mm %*% coefs)
expect_equal(pred_manual, pred_auto, tolerance = 1e-7) expect_equal(pred_manual, pred_auto, tolerance = 1e-7)
# xgboost() with additional metadata
model <- xgboost(
iris[, -5],
iris$Species,
booster = "gblinear",
objective = "multi:softprob",
nrounds = 3,
nthread = 1
)
coefs <- coef(model)
expect_equal(row.names(coefs), c("(Intercept)", colnames(x)))
expect_equal(colnames(coefs), levels(iris$Species))
}) })
test_that("Deep copies work as expected", { test_that("Deep copies work as expected", {