Change obj name to reg:squarederror in learner. (#4427)

* Change memory dump size in R test.
This commit is contained in:
Jiaming Yuan 2019-05-06 21:35:35 +08:00 committed by GitHub
parent 8d1098a983
commit 5de7e12704
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 5 deletions

View File

@@ -31,7 +31,7 @@ num_round <- 2
 test_that("custom objective works", {
   bst <- xgb.train(param, dtrain, num_round, watchlist)
   expect_equal(class(bst), "xgb.Booster")
-  expect_equal(length(bst$raw), 1094)
+  expect_equal(length(bst$raw), 1100)
   expect_false(is.null(bst$evaluation_log))
   expect_false(is.null(bst$evaluation_log$eval_error))
   expect_lt(bst$evaluation_log[num_round, eval_error], 0.03)
@@ -45,7 +45,7 @@ test_that("custom objective in CV works", {
 })

 test_that("custom objective using DMatrix attr works", {
   attr(dtrain, 'label') <- getinfo(dtrain, 'label')
   logregobjattr <- function(preds, dtrain) {
@@ -58,5 +58,5 @@ test_that("custom objective using DMatrix attr works", {
   param$objective = logregobjattr
   bst <- xgb.train(param, dtrain, num_round, watchlist)
   expect_equal(class(bst), "xgb.Booster")
-  expect_equal(length(bst$raw), 1094)
+  expect_equal(length(bst$raw), 1100)
 })

View File

@@ -171,7 +171,7 @@ class LearnerImpl : public Learner {
   explicit LearnerImpl(std::vector<std::shared_ptr<DMatrix> > cache)
       : cache_(std::move(cache)) {
     // boosted tree
-    name_obj_ = "reg:linear";
+    name_obj_ = "reg:squarederror";
     name_gbm_ = "gbtree";
   }
@@ -281,7 +281,7 @@ class LearnerImpl : public Learner {
   }
   if (cfg_.count("objective") == 0) {
-    cfg_["objective"] = "reg:squarederror";
+    cfg_["objective"] = "reg:squarederror";
   }
   if (cfg_.count("booster") == 0) {
     cfg_["booster"] = "gbtree";