[R] Enable vector-valued parameters (#9849)
This commit is contained in:
parent
0716c64ef7
commit
1de3f4135c
@ -93,6 +93,14 @@ check.booster.params <- function(params, ...) {
|
|||||||
interaction_constraints <- sapply(params[['interaction_constraints']], function(x) paste0('[', paste(x, collapse = ','), ']'))
|
interaction_constraints <- sapply(params[['interaction_constraints']], function(x) paste0('[', paste(x, collapse = ','), ']'))
|
||||||
params[['interaction_constraints']] <- paste0('[', paste(interaction_constraints, collapse = ','), ']')
|
params[['interaction_constraints']] <- paste0('[', paste(interaction_constraints, collapse = ','), ']')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# for evaluation metrics, should generate multiple entries per metric
|
||||||
|
if (NROW(params[['eval_metric']]) > 1) {
|
||||||
|
eval_metrics <- as.list(params[["eval_metric"]])
|
||||||
|
names(eval_metrics) <- rep("eval_metric", length(eval_metrics))
|
||||||
|
params_without_ev_metrics <- within(params, rm("eval_metric"))
|
||||||
|
params <- c(params_without_ev_metrics, eval_metrics)
|
||||||
|
}
|
||||||
return(params)
|
return(params)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -697,7 +697,13 @@ xgb.config <- function(object) {
|
|||||||
stop("parameter names cannot be empty strings")
|
stop("parameter names cannot be empty strings")
|
||||||
}
|
}
|
||||||
names(p) <- gsub(".", "_", names(p), fixed = TRUE)
|
names(p) <- gsub(".", "_", names(p), fixed = TRUE)
|
||||||
p <- lapply(p, function(x) as.character(x)[1])
|
p <- lapply(p, function(x) {
|
||||||
|
if (is.vector(x) && length(x) == 1) {
|
||||||
|
return(as.character(x)[1])
|
||||||
|
} else {
|
||||||
|
return(jsonlite::toJSON(x, auto_unbox = TRUE))
|
||||||
|
}
|
||||||
|
})
|
||||||
handle <- xgb.get.handle(object)
|
handle <- xgb.get.handle(object)
|
||||||
for (i in seq_along(p)) {
|
for (i in seq_along(p)) {
|
||||||
.Call(XGBoosterSetParam_R, handle, names(p[i]), p[[i]])
|
.Call(XGBoosterSetParam_R, handle, names(p[i]), p[[i]])
|
||||||
|
|||||||
@ -566,6 +566,33 @@ test_that("'predict' accepts CSR data", {
|
|||||||
expect_equal(p_csc, p_spv)
|
expect_equal(p_csc, p_spv)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("Quantile regression accepts multiple quantiles", {
|
||||||
|
data(mtcars)
|
||||||
|
y <- mtcars[, 1]
|
||||||
|
x <- as.matrix(mtcars[, -1])
|
||||||
|
dm <- xgb.DMatrix(data = x, label = y)
|
||||||
|
model <- xgb.train(
|
||||||
|
data = dm,
|
||||||
|
params = list(
|
||||||
|
objective = "reg:quantileerror",
|
||||||
|
tree_method = "exact",
|
||||||
|
quantile_alpha = c(0.05, 0.5, 0.95),
|
||||||
|
nthread = n_threads
|
||||||
|
),
|
||||||
|
nrounds = 15
|
||||||
|
)
|
||||||
|
pred <- predict(model, x, reshape = TRUE)
|
||||||
|
|
||||||
|
expect_equal(dim(pred)[1], nrow(x))
|
||||||
|
expect_equal(dim(pred)[2], 3)
|
||||||
|
expect_true(all(pred[, 1] <= pred[, 3]))
|
||||||
|
|
||||||
|
cors <- cor(y, pred)
|
||||||
|
expect_true(cors[2] > cors[1])
|
||||||
|
expect_true(cors[2] > cors[3])
|
||||||
|
expect_true(cors[2] > 0.85)
|
||||||
|
})
|
||||||
|
|
||||||
test_that("Can use multi-output labels with built-in objectives", {
|
test_that("Can use multi-output labels with built-in objectives", {
|
||||||
data("mtcars")
|
data("mtcars")
|
||||||
y <- mtcars$mpg
|
y <- mtcars$mpg
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user