[R] Avoid modifying importance dt in-place, fix aggregation (#9740)
This commit is contained in:
parent
bc995a4865
commit
d3f0646779
@ -87,7 +87,13 @@ xgb.plot.importance <- function(importance_matrix = NULL, top_n = NULL, measure
|
|||||||
}
|
}
|
||||||
|
|
||||||
# also aggregate, just in case when the values were not yet summed up by feature
|
# also aggregate, just in case when the values were not yet summed up by feature
|
||||||
importance_matrix <- importance_matrix[, Importance := sum(get(measure)), by = Feature]
|
importance_matrix <- importance_matrix[
|
||||||
|
, lapply(.SD, sum)
|
||||||
|
, .SDcols = setdiff(names(importance_matrix), "Feature")
|
||||||
|
, by = Feature
|
||||||
|
][
|
||||||
|
, Importance := get(measure)
|
||||||
|
]
|
||||||
|
|
||||||
# make sure it's ordered
|
# make sure it's ordered
|
||||||
importance_matrix <- importance_matrix[order(-abs(Importance))]
|
importance_matrix <- importance_matrix[order(-abs(Importance))]
|
||||||
|
|||||||
@ -382,6 +382,9 @@ test_that("xgb.importance works with GLM model", {
|
|||||||
expect_equal(colnames(imp2plot), c("Feature", "Weight", "Importance"))
|
expect_equal(colnames(imp2plot), c("Feature", "Weight", "Importance"))
|
||||||
xgb.ggplot.importance(importance.GLM)
|
xgb.ggplot.importance(importance.GLM)
|
||||||
|
|
||||||
|
# check that the input is not modified in-place
|
||||||
|
expect_false("Importance" %in% names(importance.GLM))
|
||||||
|
|
||||||
# for multiclass
|
# for multiclass
|
||||||
imp.GLM <- xgb.importance(model = mbst.GLM)
|
imp.GLM <- xgb.importance(model = mbst.GLM)
|
||||||
expect_equal(dim(imp.GLM), c(12, 3))
|
expect_equal(dim(imp.GLM), c(12, 3))
|
||||||
@ -400,6 +403,16 @@ test_that("xgb.model.dt.tree and xgb.importance work with a single split model",
|
|||||||
expect_equal(imp$Gain, 1)
|
expect_equal(imp$Gain, 1)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("xgb.plot.importance de-duplicates features", {
|
||||||
|
importances <- data.table(
|
||||||
|
Feature = c("col1", "col2", "col2"),
|
||||||
|
Gain = c(0.4, 0.3, 0.3)
|
||||||
|
)
|
||||||
|
imp2plot <- xgb.plot.importance(importances)
|
||||||
|
expect_equal(nrow(imp2plot), 2L)
|
||||||
|
expect_equal(imp2plot$Feature, c("col2", "col1"))
|
||||||
|
})
|
||||||
|
|
||||||
test_that("xgb.plot.tree works with and without feature names", {
|
test_that("xgb.plot.tree works with and without feature names", {
|
||||||
.skip_if_vcd_not_available()
|
.skip_if_vcd_not_available()
|
||||||
expect_silent(xgb.plot.tree(feature_names = feature.names, model = bst.Tree))
|
expect_silent(xgb.plot.tree(feature_names = feature.names, model = bst.Tree))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user