[R-package] Alter xgb.train() to accept multiple eval metrics as a list (#8657)
This commit is contained in:
parent
0f4d52a864
commit
d29e45371f
@ -321,6 +321,10 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
|
|||||||
if (is.null(evnames) || any(evnames == ""))
|
if (is.null(evnames) || any(evnames == ""))
|
||||||
stop("each element of the watchlist must have a name tag")
|
stop("each element of the watchlist must have a name tag")
|
||||||
}
|
}
|
||||||
|
# Handle multiple evaluation metrics given as a list
|
||||||
|
for (m in params$eval_metric) {
|
||||||
|
params <- c(params, list(eval_metric = m))
|
||||||
|
}
|
||||||
|
|
||||||
# evaluation printing callback
|
# evaluation printing callback
|
||||||
params <- c(params)
|
params <- c(params)
|
||||||
|
|||||||
@ -232,12 +232,20 @@ test_that("train and predict RF with softprob", {
|
|||||||
test_that("use of multiple eval metrics works", {
|
test_that("use of multiple eval metrics works", {
|
||||||
expect_output(
|
expect_output(
|
||||||
bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
|
bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
|
||||||
eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic",
|
eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic",
|
||||||
eval_metric = 'error', eval_metric = 'auc', eval_metric = "logloss")
|
eval_metric = 'error', eval_metric = 'auc', eval_metric = "logloss")
|
||||||
, "train-error.*train-auc.*train-logloss")
|
, "train-error.*train-auc.*train-logloss")
|
||||||
expect_false(is.null(bst$evaluation_log))
|
expect_false(is.null(bst$evaluation_log))
|
||||||
expect_equal(dim(bst$evaluation_log), c(2, 4))
|
expect_equal(dim(bst$evaluation_log), c(2, 4))
|
||||||
expect_equal(colnames(bst$evaluation_log), c("iter", "train_error", "train_auc", "train_logloss"))
|
expect_equal(colnames(bst$evaluation_log), c("iter", "train_error", "train_auc", "train_logloss"))
|
||||||
|
expect_output(
|
||||||
|
bst2 <- xgboost(data = train$data, label = train$label, max_depth = 2,
|
||||||
|
eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic",
|
||||||
|
eval_metric = list("error", "auc", "logloss"))
|
||||||
|
, "train-error.*train-auc.*train-logloss")
|
||||||
|
expect_false(is.null(bst2$evaluation_log))
|
||||||
|
expect_equal(dim(bst2$evaluation_log), c(2, 4))
|
||||||
|
expect_equal(colnames(bst2$evaluation_log), c("iter", "train_error", "train_auc", "train_logloss"))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user