[R] Add class names to coefficients (#10745)
This commit is contained in:
parent
fd0138c91c
commit
479ae8081b
@ -1109,17 +1109,25 @@ coef.xgb.Booster <- function(object, ...) {
|
|||||||
if (n_cols == 1L) {
|
if (n_cols == 1L) {
|
||||||
out <- c(intercepts, coefs)
|
out <- c(intercepts, coefs)
|
||||||
if (add_names) {
|
if (add_names) {
|
||||||
names(out) <- feature_names
|
.Call(XGSetVectorNamesInplace_R, out, feature_names)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
coefs <- matrix(coefs, nrow = num_feature, byrow = TRUE)
|
coefs <- matrix(coefs, nrow = num_feature, byrow = TRUE)
|
||||||
dim(intercepts) <- c(1L, n_cols)
|
dim(intercepts) <- c(1L, n_cols)
|
||||||
out <- rbind(intercepts, coefs)
|
out <- rbind(intercepts, coefs)
|
||||||
|
out_names <- vector(mode = "list", length = 2)
|
||||||
if (add_names) {
|
if (add_names) {
|
||||||
row.names(out) <- feature_names
|
out_names[[1L]] <- feature_names
|
||||||
}
|
}
|
||||||
# TODO: if a class names attributes is added,
|
if (inherits(object, "xgboost")) {
|
||||||
# should use those names here.
|
metadata <- attributes(object)$metadata
|
||||||
|
if (NROW(metadata$y_levels)) {
|
||||||
|
out_names[[2L]] <- metadata$y_levels
|
||||||
|
} else if (NROW(metadata$y_names)) {
|
||||||
|
out_names[[2L]] <- metadata$y_names
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.Call(XGSetArrayDimNamesInplace_R, out, out_names)
|
||||||
}
|
}
|
||||||
return(out)
|
return(out)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -750,6 +750,19 @@ test_that("Coefficients from gblinear have the expected shape and names", {
|
|||||||
pred_auto <- predict(model, x, outputmargin = TRUE)
|
pred_auto <- predict(model, x, outputmargin = TRUE)
|
||||||
pred_manual <- unname(mm %*% coefs)
|
pred_manual <- unname(mm %*% coefs)
|
||||||
expect_equal(pred_manual, pred_auto, tolerance = 1e-7)
|
expect_equal(pred_manual, pred_auto, tolerance = 1e-7)
|
||||||
|
|
||||||
|
# xgboost() with additional metadata
|
||||||
|
model <- xgboost(
|
||||||
|
iris[, -5],
|
||||||
|
iris$Species,
|
||||||
|
booster = "gblinear",
|
||||||
|
objective = "multi:softprob",
|
||||||
|
nrounds = 3,
|
||||||
|
nthread = 1
|
||||||
|
)
|
||||||
|
coefs <- coef(model)
|
||||||
|
expect_equal(row.names(coefs), c("(Intercept)", colnames(x)))
|
||||||
|
expect_equal(colnames(coefs), levels(iris$Species))
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("Deep copies work as expected", {
|
test_that("Deep copies work as expected", {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user