[R] Enable vector-valued parameters (#9849)

This commit is contained in:
david-cortes 2023-12-06 13:32:20 +01:00 committed by GitHub
parent 0716c64ef7
commit 1de3f4135c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 1 deletions

View File

@ -93,6 +93,14 @@ check.booster.params <- function(params, ...) {
interaction_constraints <- sapply(params[['interaction_constraints']], function(x) paste0('[', paste(x, collapse = ','), ']'))
params[['interaction_constraints']] <- paste0('[', paste(interaction_constraints, collapse = ','), ']')
}
# for evaluation metrics, should generate multiple entries per metric
if (NROW(params[['eval_metric']]) > 1) {
eval_metrics <- as.list(params[["eval_metric"]])
names(eval_metrics) <- rep("eval_metric", length(eval_metrics))
params_without_ev_metrics <- within(params, rm("eval_metric"))
params <- c(params_without_ev_metrics, eval_metrics)
}
return(params)
}

View File

@ -697,7 +697,13 @@ xgb.config <- function(object) {
stop("parameter names cannot be empty strings")
}
names(p) <- gsub(".", "_", names(p), fixed = TRUE)
p <- lapply(p, function(x) as.character(x)[1])
p <- lapply(p, function(x) {
if (is.vector(x) && length(x) == 1) {
return(as.character(x)[1])
} else {
return(jsonlite::toJSON(x, auto_unbox = TRUE))
}
})
handle <- xgb.get.handle(object)
for (i in seq_along(p)) {
.Call(XGBoosterSetParam_R, handle, names(p[i]), p[[i]])

View File

@ -566,6 +566,33 @@ test_that("'predict' accepts CSR data", {
expect_equal(p_csc, p_spv)
})
test_that("Quantile regression accepts multiple quantiles", {
data(mtcars)
y <- mtcars[, 1]
x <- as.matrix(mtcars[, -1])
dm <- xgb.DMatrix(data = x, label = y)
model <- xgb.train(
data = dm,
params = list(
objective = "reg:quantileerror",
tree_method = "exact",
quantile_alpha = c(0.05, 0.5, 0.95),
nthread = n_threads
),
nrounds = 15
)
pred <- predict(model, x, reshape = TRUE)
expect_equal(dim(pred)[1], nrow(x))
expect_equal(dim(pred)[2], 3)
expect_true(all(pred[, 1] <= pred[, 3]))
cors <- cor(y, pred)
expect_true(cors[2] > cors[1])
expect_true(cors[2] > cors[3])
expect_true(cors[2] > 0.85)
})
test_that("Can use multi-output labels with built-in objectives", {
data("mtcars")
y <- mtcars$mpg