[Breaking] Change default evaluation metric for classification to logloss / mlogloss (#6183)
* Change DefaultEvalMetric of classification from error to logloss
* Change default binary metric in plugin/example/custom_obj.cc
* Set old error metric in python tests
* Set old error metric in R tests
* Fix missed eval metrics and typos in R tests
* Fix setting eval_metric twice in R tests
* Add warning for empty eval_metric for classification
* Fix Dask tests

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
parent e0e4f15d0e
commit cf4f019ed6
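For callers that relied on the old defaults, the migration is to pin eval_metric explicitly, exactly as the test updates below do. A minimal hedged sketch in Python (file paths and parameter values are illustrative, not taken from this commit):

    import xgboost as xgb

    # Assumed demo files; any binary-classification DMatrix works here.
    dtrain = xgb.DMatrix('agaricus.txt.train')
    dtest = xgb.DMatrix('agaricus.txt.test')

    # After this change, 'binary:logistic' reports 'logloss' by default.
    # Setting eval_metric restores the pre-change 'error' metric and also
    # avoids the new warning added in LearnerImpl below.
    params = {'objective': 'binary:logistic', 'eval_metric': 'error',
              'max_depth': 2, 'eta': 1}
    evals_result = {}
    xgb.train(params, dtrain, num_boost_round=2,
              evals=[(dtest, 'eval'), (dtrain, 'train')],
              evals_result=evals_result)
    print(list(evals_result['train'].keys()))  # ['error'] rather than ['logloss']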
@@ -17,7 +17,8 @@ test_that("train and predict binary classification", {
  nrounds <- 2
  expect_output(
  bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
- eta = 1, nthread = 2, nrounds = nrounds, objective = "binary:logistic")
+ eta = 1, nthread = 2, nrounds = nrounds, objective = "binary:logistic",
+ eval_metric = "error")
  , "train-error")
  expect_equal(class(bst), "xgb.Booster")
  expect_equal(bst$niter, nrounds)

@@ -122,7 +123,7 @@ test_that("train and predict softprob", {
  expect_output(
  bst <- xgboost(data = as.matrix(iris[, -5]), label = lb,
  max_depth = 3, eta = 0.5, nthread = 2, nrounds = 5,
- objective = "multi:softprob", num_class = 3)
+ objective = "multi:softprob", num_class = 3, eval_metric = "merror")
  , "train-merror")
  expect_false(is.null(bst$evaluation_log))
  expect_lt(bst$evaluation_log[, min(train_merror)], 0.025)

@@ -150,7 +151,7 @@ test_that("train and predict softmax", {
  expect_output(
  bst <- xgboost(data = as.matrix(iris[, -5]), label = lb,
  max_depth = 3, eta = 0.5, nthread = 2, nrounds = 5,
- objective = "multi:softmax", num_class = 3)
+ objective = "multi:softmax", num_class = 3, eval_metric = "merror")
  , "train-merror")
  expect_false(is.null(bst$evaluation_log))
  expect_lt(bst$evaluation_log[, min(train_merror)], 0.025)

@@ -167,7 +168,7 @@ test_that("train and predict RF", {
  lb <- train$label
  # single iteration
  bst <- xgboost(data = train$data, label = lb, max_depth = 5,
- nthread = 2, nrounds = 1, objective = "binary:logistic",
+ nthread = 2, nrounds = 1, objective = "binary:logistic", eval_metric = "error",
  num_parallel_tree = 20, subsample = 0.6, colsample_bytree = 0.1)
  expect_equal(bst$niter, 1)
  expect_equal(xgb.ntree(bst), 20)

@@ -193,7 +194,8 @@ test_that("train and predict RF with softprob", {
  set.seed(11)
  bst <- xgboost(data = as.matrix(iris[, -5]), label = lb,
  max_depth = 3, eta = 0.9, nthread = 2, nrounds = nrounds,
- objective = "multi:softprob", num_class = 3, verbose = 0,
+ objective = "multi:softprob", eval_metric = "merror",
+ num_class = 3, verbose = 0,
  num_parallel_tree = 4, subsample = 0.5, colsample_bytree = 0.5)
  expect_equal(bst$niter, 15)
  expect_equal(xgb.ntree(bst), 15 * 3 * 4)

@@ -274,7 +276,7 @@ test_that("xgb.cv works", {
  expect_output(
  cv <- xgb.cv(data = train$data, label = train$label, max_depth = 2, nfold = 5,
  eta = 1., nthread = 2, nrounds = 2, objective = "binary:logistic",
- verbose = TRUE)
+ eval_metric = "error", verbose = TRUE)
  , "train-error:")
  expect_is(cv, 'xgb.cv.synchronous')
  expect_false(is.null(cv$evaluation_log))

@@ -299,7 +301,7 @@ test_that("xgb.cv works with stratified folds", {
  eta = 1., nthread = 2, nrounds = 2, objective = "binary:logistic",
  verbose = TRUE, stratified = TRUE)
  # Stratified folds should result in a different evaluation logs
- expect_true(all(cv$evaluation_log[, test_error_mean] != cv2$evaluation_log[, test_error_mean]))
+ expect_true(all(cv$evaluation_log[, test_logloss_mean] != cv2$evaluation_log[, test_logloss_mean]))
  })

  test_that("train and predict with non-strict classes", {

@@ -26,7 +26,8 @@ watchlist <- list(train = dtrain, test = dtest)

  err <- function(label, pr) sum((pr > 0.5) != label) / length(label)

- param <- list(objective = "binary:logistic", max_depth = 2, nthread = 2)
+ param <- list(objective = "binary:logistic", eval_metric = "error",
+ max_depth = 2, nthread = 2)


  test_that("cb.print.evaluation works as expected", {

@@ -105,7 +106,8 @@ test_that("cb.evaluation.log works as expected", {
  })


- param <- list(objective = "binary:logistic", max_depth = 4, nthread = 2)
+ param <- list(objective = "binary:logistic", eval_metric = "error",
+ max_depth = 4, nthread = 2)

  test_that("can store evaluation_log without printing", {
  expect_silent(

@@ -236,7 +238,7 @@ test_that("early stopping xgb.train works", {
  test_that("early stopping using a specific metric works", {
  set.seed(11)
  expect_output(
- bst <- xgb.train(param, dtrain, nrounds = 20, watchlist, eta = 0.6,
+ bst <- xgb.train(param[-2], dtrain, nrounds = 20, watchlist, eta = 0.6,
  eval_metric = "logloss", eval_metric = "auc",
  callbacks = list(cb.early.stop(stopping_rounds = 3, maximize = FALSE,
  metric_name = 'test_logloss')))

@@ -8,7 +8,7 @@ test_that("gblinear works", {
  dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)
  dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label)

- param <- list(objective = "binary:logistic", booster = "gblinear",
+ param <- list(objective = "binary:logistic", eval_metric = "error", booster = "gblinear",
  nthread = 2, eta = 0.8, alpha = 0.0001, lambda = 0.0001)
  watchlist <- list(eval = dtest, train = dtrain)

@@ -142,7 +142,8 @@ def main(args):

  native_results = {}
  # Use the same objective function defined in XGBoost.
- booster_native = xgb.train({'num_class': kClasses},
+ booster_native = xgb.train({'num_class': kClasses,
+ 'eval_metric': 'merror'},
  m,
  num_boost_round=kRounds,
  evals_result=native_results,

@@ -376,7 +376,7 @@ Specify the learning task and the corresponding learning objective. The objectiv

  * ``eval_metric`` [default according to objective]

- - Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and error for classification, mean average precision for ranking)
+ - Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, mean average precision for ranking)
  - User can add multiple evaluation metrics. Python users: remember to pass the metrics in as list of parameters pairs instead of map, so that latter ``eval_metric`` won't override previous one
  - The choices are listed below:

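The documentation hunk above repeats the caveat that Python users who want several metrics must pass the parameters as a list of pairs rather than a dict. A small hedged illustration (dtrain and dtest are assumed to exist from earlier setup):

    # A dict would keep only the last 'eval_metric' key; a list of
    # (name, value) pairs preserves both metrics, as the doc describes.
    params = [('objective', 'binary:logistic'),
              ('eval_metric', 'logloss'),
              ('eval_metric', 'auc')]
    booster = xgb.train(params, dtrain, num_boost_round=10,
                        evals=[(dtest, 'eval')])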
@@ -154,10 +154,10 @@ class XGBoostClassifier (
  require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
  if ($(objective).startsWith("multi")) {
  // multi
- "merror"
+ "mlogloss"
  } else {
  // binary
- "error"
+ "logloss"
  }
  }

@@ -56,7 +56,7 @@ class MyLogistic : public ObjFunction {
  }
  }
  const char* DefaultEvalMetric() const override {
- return "error";
+ return "logloss";
  }
  void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
  // transform margin value to probability.

@@ -103,7 +103,7 @@ struct LogisticRegressionOneAPI {

  // logistic loss for binary classification task
  struct LogisticClassificationOneAPI : public LogisticRegressionOneAPI {
- static const char* DefaultEvalMetric() { return "error"; }
+ static const char* DefaultEvalMetric() { return "logloss"; }
  static const char* Name() { return "binary:logistic_oneapi"; }
  };

@@ -1031,6 +1031,18 @@ class LearnerImpl : public LearnerIO {
  std::ostringstream os;
  os << '[' << iter << ']' << std::setiosflags(std::ios::fixed);
  if (metrics_.size() == 0 && tparam_.disable_default_eval_metric <= 0) {
+ auto warn_default_eval_metric = [](const std::string& objective, const std::string& before,
+ const std::string& after) {
+ LOG(WARNING) << "Starting in XGBoost 1.3.0, the default evaluation metric used with the "
+ << "objective '" << objective << "' was changed from '" << before
+ << "' to '" << after << "'. Explicitly set eval_metric if you'd like to "
+ << "restore the old behavior.";
+ };
+ if (tparam_.objective == "binary:logistic") {
+ warn_default_eval_metric(tparam_.objective, "error", "logloss");
+ } else if ((tparam_.objective == "multi:softmax" || tparam_.objective == "multi:softprob")) {
+ warn_default_eval_metric(tparam_.objective, "merror", "mlogloss");
+ }
  metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric(), &generic_parameters_));
  metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
  }

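The LearnerImpl hunk above emits the new warning only when no eval_metric has been configured and the default metric has not been disabled. A hedged sketch of two ways to keep the log quiet (parameter names taken from the condition in the hunk; dtrain is assumed from the earlier example):

    # Option 1: configure a metric explicitly (old default shown here).
    xgb.train({'objective': 'binary:logistic', 'eval_metric': 'error'},
              dtrain, num_boost_round=2, evals=[(dtrain, 'train')])

    # Option 2: disable the default metric entirely, which skips this branch.
    xgb.train({'objective': 'binary:logistic', 'disable_default_eval_metric': 1},
              dtrain, num_boost_round=2, evals=[(dtrain, 'train')])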
@@ -125,7 +125,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
  this->Transform(io_preds, true);
  }
  const char* DefaultEvalMetric() const override {
- return "merror";
+ return "mlogloss";
  }

  inline void Transform(HostDeviceVector<bst_float> *io_preds, bool prob) {

@@ -131,7 +131,7 @@ struct PseudoHuberError {

  // logistic loss for binary classification task
  struct LogisticClassification : public LogisticRegression {
- static const char* DefaultEvalMetric() { return "error"; }
+ static const char* DefaultEvalMetric() { return "logloss"; }
  static const char* Name() { return "binary:logistic"; }
  };

@@ -8,7 +8,7 @@ namespace xgboost {
  TEST(Plugin, ExampleObjective) {
  xgboost::GenericParameter tparam = CreateEmptyGenericParam(GPUIDX);
  auto * obj = xgboost::ObjFunction::Create("mylogistic", &tparam);
- ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"error"});
+ ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"logloss"});
  delete obj;
  }

@@ -81,7 +81,7 @@ class TestBasic(unittest.TestCase):
  dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
  dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
  param = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
- 'objective': 'binary:logistic'}
+ 'objective': 'binary:logistic', 'eval_metric': 'error'}
  # specify validations set to watch performance
  watchlist = [(dtest, 'eval'), (dtrain, 'train')]
  num_round = 2

@@ -117,7 +117,8 @@ class TestModels(unittest.TestCase):
  # learning_rates as a list
  # init eta with 0 to check whether learning_rates work
  param = {'max_depth': 2, 'eta': 0, 'verbosity': 0,
- 'objective': 'binary:logistic', 'tree_method': tree_method}
+ 'objective': 'binary:logistic', 'eval_metric': 'error',
+ 'tree_method': tree_method}
  evals_result = {}
  bst = xgb.train(param, dtrain, num_round, watchlist,
  callbacks=[xgb.callback.reset_learning_rate([

@@ -131,7 +132,8 @@ class TestModels(unittest.TestCase):

  # init learning_rate with 0 to check whether learning_rates work
  param = {'max_depth': 2, 'learning_rate': 0, 'verbosity': 0,
- 'objective': 'binary:logistic', 'tree_method': tree_method}
+ 'objective': 'binary:logistic', 'eval_metric': 'error',
+ 'tree_method': tree_method}
  evals_result = {}
  bst = xgb.train(param, dtrain, num_round, watchlist,
  callbacks=[xgb.callback.reset_learning_rate(

@@ -145,7 +147,7 @@ class TestModels(unittest.TestCase):
  # check if learning_rates override default value of eta/learning_rate
  param = {
  'max_depth': 2, 'verbosity': 0, 'objective': 'binary:logistic',
- 'tree_method': tree_method
+ 'eval_metric': 'error', 'tree_method': tree_method
  }
  evals_result = {}
  bst = xgb.train(param, dtrain, num_round, watchlist,

@@ -115,7 +115,9 @@ class TestDMatrix(unittest.TestCase):

  eval_res_0 = {}
  booster = xgb.train(
- {'num_class': 3, 'objective': 'multi:softprob'}, d,
+ {'num_class': 3, 'objective': 'multi:softprob',
+ 'eval_metric': 'merror'},
+ d,
  num_boost_round=2, evals=[(d, 'd')], evals_result=eval_res_0)

  predt = booster.predict(d)

@@ -130,9 +132,11 @@ class TestDMatrix(unittest.TestCase):
  assert sliced_margin.shape[0] == len(ridxs) * 3

  eval_res_1 = {}
- xgb.train({'num_class': 3, 'objective': 'multi:softprob'}, sliced,
- num_boost_round=2, evals=[(sliced, 'd')],
- evals_result=eval_res_1)
+ xgb.train(
+ {'num_class': 3, 'objective': 'multi:softprob',
+ 'eval_metric': 'merror'},
+ sliced,
+ num_boost_round=2, evals=[(sliced, 'd')], evals_result=eval_res_1)

  eval_res_0 = eval_res_0['d']['merror']
  eval_res_1 = eval_res_1['d']['merror']

@@ -58,7 +58,7 @@ class TestEarlyStopping(unittest.TestCase):
  y = digits['target']
  dm = xgb.DMatrix(X, label=y)
  params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
- 'objective': 'binary:logistic'}
+ 'objective': 'binary:logistic', 'eval_metric': 'error'}

  cv = xgb.cv(params, dm, num_boost_round=10, nfold=10,
  early_stopping_rounds=10)

@@ -274,7 +274,7 @@ def test_dask_classifier():
  X, y = generate_array()
  y = (y * 10).astype(np.int32)
  classifier = xgb.dask.DaskXGBClassifier(
- verbosity=1, n_estimators=2)
+ verbosity=1, n_estimators=2, eval_metric='merror')
  classifier.client = client
  classifier.fit(X, y, eval_set=[(X, y)])
  prediction = classifier.predict(X)

@@ -386,6 +386,7 @@ def run_empty_dmatrix_cls(client, parameters):
  y = dd.from_array(np.random.randint(low=0, high=n_classes, size=kRows))
  dtrain = xgb.dask.DaskDMatrix(client, X, y)
  parameters['objective'] = 'multi:softprob'
+ parameters['eval_metric'] = 'merror'
  parameters['num_class'] = n_classes

  out = xgb.dask.train(client, parameters,

@@ -482,7 +483,7 @@ async def run_dask_classifier_asyncio(scheduler_address):
  X, y = generate_array()
  y = (y * 10).astype(np.int32)
  classifier = await xgb.dask.DaskXGBClassifier(
- verbosity=1, n_estimators=2)
+ verbosity=1, n_estimators=2, eval_metric='merror')
  classifier.client = client
  await classifier.fit(X, y, eval_set=[(X, y)])
  prediction = await classifier.predict(X)

@@ -174,7 +174,7 @@ class TestPandas(unittest.TestCase):
  def test_cv_as_pandas(self):
  dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
  params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
- 'objective': 'binary:logistic'}
+ 'objective': 'binary:logistic', 'eval_metric': 'error'}

  cv = xgb.cv(params, dm, num_boost_round=10, nfold=10)
  assert isinstance(cv, pd.DataFrame)