[R] Avoid modifying importance dt in-place, fix aggregation (#9740)

This commit is contained in:
david-cortes 2023-10-31 22:10:59 +01:00 committed by GitHub
parent bc995a4865
commit d3f0646779
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 1 deletions

View File

@ -87,7 +87,13 @@ xgb.plot.importance <- function(importance_matrix = NULL, top_n = NULL, measure
} }
# also aggregate, just in case when the values were not yet summed up by feature # also aggregate, just in case when the values were not yet summed up by feature
importance_matrix <- importance_matrix[, Importance := sum(get(measure)), by = Feature] importance_matrix <- importance_matrix[
, lapply(.SD, sum)
, .SDcols = setdiff(names(importance_matrix), "Feature")
, by = Feature
][
, Importance := get(measure)
]
# make sure it's ordered # make sure it's ordered
importance_matrix <- importance_matrix[order(-abs(Importance))] importance_matrix <- importance_matrix[order(-abs(Importance))]

View File

@ -382,6 +382,9 @@ test_that("xgb.importance works with GLM model", {
expect_equal(colnames(imp2plot), c("Feature", "Weight", "Importance")) expect_equal(colnames(imp2plot), c("Feature", "Weight", "Importance"))
xgb.ggplot.importance(importance.GLM) xgb.ggplot.importance(importance.GLM)
# check that the input is not modified in-place
expect_false("Importance" %in% names(importance.GLM))
# for multiclass # for multiclass
imp.GLM <- xgb.importance(model = mbst.GLM) imp.GLM <- xgb.importance(model = mbst.GLM)
expect_equal(dim(imp.GLM), c(12, 3)) expect_equal(dim(imp.GLM), c(12, 3))
@ -400,6 +403,16 @@ test_that("xgb.model.dt.tree and xgb.importance work with a single split model",
expect_equal(imp$Gain, 1) expect_equal(imp$Gain, 1)
}) })
test_that("xgb.plot.importance de-duplicates features", {
  # "col2" is listed twice on purpose: the function is expected to collapse
  # repeated features into a single row before ordering the result.
  feature_names <- c("col1", "col2", "col2")
  gain_values <- c(0.4, 0.3, 0.3)
  raw_importance <- data.table(Feature = feature_names, Gain = gain_values)

  plotted <- xgb.plot.importance(raw_importance)

  # Two distinct features remain, with the merged "col2" row listed first.
  expect_equal(nrow(plotted), 2L)
  expect_equal(plotted$Feature, c("col2", "col1"))
})
test_that("xgb.plot.tree works with and without feature names", { test_that("xgb.plot.tree works with and without feature names", {
.skip_if_vcd_not_available() .skip_if_vcd_not_available()
expect_silent(xgb.plot.tree(feature_names = feature.names, model = bst.Tree)) expect_silent(xgb.plot.tree(feature_names = feature.names, model = bst.Tree))