diff --git a/R-package/R/utils.R b/R-package/R/utils.R index 6f1a1b4ec..bf08c481d 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -93,6 +93,14 @@ check.booster.params <- function(params, ...) { interaction_constraints <- sapply(params[['interaction_constraints']], function(x) paste0('[', paste(x, collapse = ','), ']')) params[['interaction_constraints']] <- paste0('[', paste(interaction_constraints, collapse = ','), ']') } + + # for evaluation metrics, should generate multiple entries per metric + if (NROW(params[['eval_metric']]) > 1) { + eval_metrics <- as.list(params[["eval_metric"]]) + names(eval_metrics) <- rep("eval_metric", length(eval_metrics)) + params_without_ev_metrics <- within(params, rm("eval_metric")) + params <- c(params_without_ev_metrics, eval_metrics) + } return(params) } diff --git a/R-package/R/xgb.Booster.R b/R-package/R/xgb.Booster.R index 37cfc199e..a5c9e5088 100644 --- a/R-package/R/xgb.Booster.R +++ b/R-package/R/xgb.Booster.R @@ -697,7 +697,13 @@ xgb.config <- function(object) { stop("parameter names cannot be empty strings") } names(p) <- gsub(".", "_", names(p), fixed = TRUE) - p <- lapply(p, function(x) as.character(x)[1]) + p <- lapply(p, function(x) { + if (is.vector(x) && length(x) == 1) { + return(as.character(x)[1]) + } else { + return(jsonlite::toJSON(x, auto_unbox = TRUE)) + } + }) handle <- xgb.get.handle(object) for (i in seq_along(p)) { .Call(XGBoosterSetParam_R, handle, names(p[i]), p[[i]]) diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index 97c1353dc..c96871e11 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -566,6 +566,33 @@ test_that("'predict' accepts CSR data", { expect_equal(p_csc, p_spv) }) +test_that("Quantile regression accepts multiple quantiles", { + data(mtcars) + y <- mtcars[, 1] + x <- as.matrix(mtcars[, -1]) + dm <- xgb.DMatrix(data = x, label = y) + model <- xgb.train( + data = dm, + params = list( + objective = "reg:quantileerror", + tree_method = "exact", + quantile_alpha = c(0.05, 0.5, 0.95), + nthread = n_threads + ), + nrounds = 15 + ) + pred <- predict(model, x, reshape = TRUE) + + expect_equal(dim(pred)[1], nrow(x)) + expect_equal(dim(pred)[2], 3) + expect_true(all(pred[, 1] <= pred[, 3])) + + cors <- cor(y, pred) + expect_true(cors[2] > cors[1]) + expect_true(cors[2] > cors[3]) + expect_true(cors[2] > 0.85) +}) + test_that("Can use multi-output labels with built-in objectives", { data("mtcars") y <- mtcars$mpg