[R] Enable vector-valued parameters (#9849)
This commit is contained in:
parent
0716c64ef7
commit
1de3f4135c
@ -93,6 +93,14 @@ check.booster.params <- function(params, ...) {
|
||||
interaction_constraints <- sapply(params[['interaction_constraints']], function(x) paste0('[', paste(x, collapse = ','), ']'))
|
||||
params[['interaction_constraints']] <- paste0('[', paste(interaction_constraints, collapse = ','), ']')
|
||||
}
|
||||
|
||||
# for evaluation metrics, should generate multiple entries per metric
|
||||
if (NROW(params[['eval_metric']]) > 1) {
|
||||
eval_metrics <- as.list(params[["eval_metric"]])
|
||||
names(eval_metrics) <- rep("eval_metric", length(eval_metrics))
|
||||
params_without_ev_metrics <- within(params, rm("eval_metric"))
|
||||
params <- c(params_without_ev_metrics, eval_metrics)
|
||||
}
|
||||
return(params)
|
||||
}
|
||||
|
||||
|
||||
@ -697,7 +697,13 @@ xgb.config <- function(object) {
|
||||
stop("parameter names cannot be empty strings")
|
||||
}
|
||||
names(p) <- gsub(".", "_", names(p), fixed = TRUE)
|
||||
p <- lapply(p, function(x) as.character(x)[1])
|
||||
p <- lapply(p, function(x) {
|
||||
if (is.vector(x) && length(x) == 1) {
|
||||
return(as.character(x)[1])
|
||||
} else {
|
||||
return(jsonlite::toJSON(x, auto_unbox = TRUE))
|
||||
}
|
||||
})
|
||||
handle <- xgb.get.handle(object)
|
||||
for (i in seq_along(p)) {
|
||||
.Call(XGBoosterSetParam_R, handle, names(p[i]), p[[i]])
|
||||
|
||||
@ -566,6 +566,33 @@ test_that("'predict' accepts CSR data", {
|
||||
expect_equal(p_csc, p_spv)
|
||||
})
|
||||
|
||||
test_that("Quantile regression accepts multiple quantiles", {
|
||||
data(mtcars)
|
||||
y <- mtcars[, 1]
|
||||
x <- as.matrix(mtcars[, -1])
|
||||
dm <- xgb.DMatrix(data = x, label = y)
|
||||
model <- xgb.train(
|
||||
data = dm,
|
||||
params = list(
|
||||
objective = "reg:quantileerror",
|
||||
tree_method = "exact",
|
||||
quantile_alpha = c(0.05, 0.5, 0.95),
|
||||
nthread = n_threads
|
||||
),
|
||||
nrounds = 15
|
||||
)
|
||||
pred <- predict(model, x, reshape = TRUE)
|
||||
|
||||
expect_equal(dim(pred)[1], nrow(x))
|
||||
expect_equal(dim(pred)[2], 3)
|
||||
expect_true(all(pred[, 1] <= pred[, 3]))
|
||||
|
||||
cors <- cor(y, pred)
|
||||
expect_true(cors[2] > cors[1])
|
||||
expect_true(cors[2] > cors[3])
|
||||
expect_true(cors[2] > 0.85)
|
||||
})
|
||||
|
||||
test_that("Can use multi-output labels with built-in objectives", {
|
||||
data("mtcars")
|
||||
y <- mtcars$mpg
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user