[R] Implement feature weights. (#7660)
This commit is contained in:
parent
0149f81a5a
commit
12949c6b31
@ -287,6 +287,13 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
|
|||||||
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info))
|
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info))
|
||||||
return(TRUE)
|
return(TRUE)
|
||||||
}
|
}
|
||||||
|
if (name == "feature_weights") {
|
||||||
|
if (length(info) != ncol(object)) {
|
||||||
|
stop("The number of feature weights must equal to the number of columns in the input data")
|
||||||
|
}
|
||||||
|
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
|
||||||
|
return(TRUE)
|
||||||
|
}
|
||||||
stop("setinfo: unknown info name ", name)
|
stop("setinfo: unknown info name ", name)
|
||||||
return(FALSE)
|
return(FALSE)
|
||||||
}
|
}
|
||||||
|
|||||||
27
R-package/tests/testthat/test_feature_weights.R
Normal file
27
R-package/tests/testthat/test_feature_weights.R
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# Tests that per-feature sampling weights supplied to xgb.DMatrix via
# `feature_weights` bias column sampling during training.
library(xgboost)

context("feature weights")

test_that("training with feature weights works", {
  nrows <- 1000
  ncols <- 9
  set.seed(2022)
  x <- matrix(rnorm(nrows * ncols), nrow = nrows)
  y <- rowSums(x)
  # Monotonically increasing weights: later columns should be sampled
  # (and hence split on) more often than earlier ones.
  weights <- seq(from = 1, to = ncols)

  test <- function(tm) {
    # seq_len() is safe for any non-negative length (unlike 1:ncols).
    names <- paste0("f", seq_len(ncols))
    xy <- xgb.DMatrix(data = x, label = y, feature_weights = weights)
    # colsample_bynode < 1 forces column sampling so the weights matter.
    params <- list(colsample_bynode = 0.4, tree_method = tm, nthread = 1)
    model <- xgb.train(params = params, data = xy, nrounds = 32)
    importance <- xgb.importance(model = model, feature_names = names)
    # All ncols features should appear in the importance table.
    expect_equal(dim(importance), c(ncols, 4))
    importance <- importance[order(importance$Feature)]
    # The heaviest-weighted feature (f9) must be used more frequently
    # than the lightest (f1).
    expect_lt(importance[1, Frequency], importance[9, Frequency])
  }

  for (tm in c("hist", "approx", "exact")) {
    test(tm)
  }
})
||||||
@ -115,10 +115,9 @@ Parameters for Tree Booster
|
|||||||
'colsample_bynode':0.5}`` with 64 features will leave 8 features to choose from at
each split.
|
|
||||||
Using the Python or the R package, one can set the ``feature_weights`` for DMatrix to
define the probability of each feature being selected when using column sampling.
There's a similar parameter for the ``fit`` method in the sklearn interface.
|
|
||||||
|
|
||||||
* ``lambda`` [default=1, alias: ``reg_lambda``]
|
* ``lambda`` [default=1, alias: ``reg_lambda``]
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user