From 12949c6b3134817f2f38ca765f42d9fdbda6e600 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 16 Feb 2022 22:20:52 +0800 Subject: [PATCH] [R] Implement feature weights. (#7660) --- R-package/R/xgb.DMatrix.R | 7 +++++ .../tests/testthat/test_feature_weights.R | 27 +++++++++++++++++++ doc/parameter.rst | 7 +++-- 3 files changed, 37 insertions(+), 4 deletions(-) create mode 100644 R-package/tests/testthat/test_feature_weights.R diff --git a/R-package/R/xgb.DMatrix.R b/R-package/R/xgb.DMatrix.R index 970e317cd..d9335405c 100644 --- a/R-package/R/xgb.DMatrix.R +++ b/R-package/R/xgb.DMatrix.R @@ -287,6 +287,13 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) { .Call(XGDMatrixSetInfo_R, object, name, as.integer(info)) return(TRUE) } + if (name == "feature_weights") { + if (length(info) != ncol(object)) { + stop("The number of feature weights must equal to the number of columns in the input data") + } + .Call(XGDMatrixSetInfo_R, object, name, as.numeric(info)) + return(TRUE) + } stop("setinfo: unknown info name ", name) return(FALSE) } diff --git a/R-package/tests/testthat/test_feature_weights.R b/R-package/tests/testthat/test_feature_weights.R new file mode 100644 index 000000000..580f58456 --- /dev/null +++ b/R-package/tests/testthat/test_feature_weights.R @@ -0,0 +1,27 @@ +library(xgboost) + +context("feature weights") + +test_that("training with feature weights works", { + nrows <- 1000 + ncols <- 9 + set.seed(2022) + x <- matrix(rnorm(nrows * ncols), nrow = nrows) + y <- rowSums(x) + weights <- seq(from = 1, to = ncols) + + test <- function(tm) { + names <- paste0("f", 1:ncols) + xy <- xgb.DMatrix(data = x, label = y, feature_weights = weights) + params <- list(colsample_bynode = 0.4, tree_method = tm, nthread = 1) + model <- xgb.train(params = params, data = xy, nrounds = 32) + importance <- xgb.importance(model = model, feature_names = names) + expect_equal(dim(importance), c(ncols, 4)) + importance <- importance[order(importance$Feature)] + 
expect_lt(importance[1, Frequency], importance[9, Frequency]) + } + + for (tm in c("hist", "approx", "exact")) { + test(tm) + } +}) diff --git a/doc/parameter.rst b/doc/parameter.rst index d341baa8b..227263e7d 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -115,10 +115,9 @@ Parameters for Tree Booster 'colsample_bynode':0.5}`` with 64 features will leave 8 features to choose from at each split. - On Python interface, when using ``hist``, ``gpu_hist`` or ``exact`` tree method, one - can set the ``feature_weights`` for DMatrix to define the probability of each feature - being selected when using column sampling. There's a similar parameter for ``fit`` - method in sklearn interface. + Using the Python or the R package, one can set the ``feature_weights`` for DMatrix to + define the probability of each feature being selected when using column sampling. + There's a similar parameter for the ``fit`` method in the sklearn interface. * ``lambda`` [default=1, alias: ``reg_lambda``]