[R-package] Alter xgb.train() to accept multiple eval metrics as a list (#8657)

This commit is contained in:
Philip Hyunsu Cho 2023-01-24 17:14:14 -08:00 committed by GitHub
parent 0f4d52a864
commit d29e45371f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 2 deletions

View File

@ -321,6 +321,10 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
if (is.null(evnames) || any(evnames == "")) if (is.null(evnames) || any(evnames == ""))
stop("each element of the watchlist must have a name tag") stop("each element of the watchlist must have a name tag")
} }
# Handle multiple evaluation metrics given as a list
for (m in params$eval_metric) {
params <- c(params, list(eval_metric = m))
}
# evaluation printing callback # evaluation printing callback
params <- c(params) params <- c(params)

View File

@ -232,12 +232,20 @@ test_that("train and predict RF with softprob", {
test_that("use of multiple eval metrics works", { test_that("use of multiple eval metrics works", {
expect_output( expect_output(
bst <- xgboost(data = train$data, label = train$label, max_depth = 2, bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic", eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic",
eval_metric = 'error', eval_metric = 'auc', eval_metric = "logloss") eval_metric = 'error', eval_metric = 'auc', eval_metric = "logloss")
, "train-error.*train-auc.*train-logloss") , "train-error.*train-auc.*train-logloss")
expect_false(is.null(bst$evaluation_log)) expect_false(is.null(bst$evaluation_log))
expect_equal(dim(bst$evaluation_log), c(2, 4)) expect_equal(dim(bst$evaluation_log), c(2, 4))
expect_equal(colnames(bst$evaluation_log), c("iter", "train_error", "train_auc", "train_logloss")) expect_equal(colnames(bst$evaluation_log), c("iter", "train_error", "train_auc", "train_logloss"))
expect_output(
bst2 <- xgboost(data = train$data, label = train$label, max_depth = 2,
eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic",
eval_metric = list("error", "auc", "logloss"))
, "train-error.*train-auc.*train-logloss")
expect_false(is.null(bst2$evaluation_log))
expect_equal(dim(bst2$evaluation_log), c(2, 4))
expect_equal(colnames(bst2$evaluation_log), c("iter", "train_error", "train_auc", "train_logloss"))
}) })