Deprecate `reg:linear` in favor of `reg:squarederror`. (#4267)
* Deprecate `reg:linear` in favor of `reg:squarederror`.
* Replace the use of `reg:linear`.
* Replace the use of `silent`.
parent: cf8d5b9b76
commit: 29a1356669
@@ -209,13 +209,14 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
   if (exists('objective', where = params) &&
       is.character(params$objective)) {
     # If 'objective' provided in params, assume that y is a classification label
-    # unless objective is reg:linear
-    if (params$objective != 'reg:linear')
+    # unless objective is reg:squarederror
+    if (params$objective != 'reg:squarederror')
       y <- factor(y)
   } else {
-    # If no 'objective' given in params, it means that user either wants to use
-    # the default 'reg:linear' objective or has provided a custom obj function.
-    # Here, assume classification setting when y has 5 or less unique values:
+    # If no 'objective' given in params, it means that user either wants to
+    # use the default 'reg:squarederror' objective or has provided a custom
+    # obj function. Here, assume classification setting when y has 5 or less
+    # unique values:
     if (length(unique(y)) <= 5)
       y <- factor(y)
   }
@@ -6,7 +6,7 @@
 #' \itemize{
 #'   \item \code{objective} objective function, common ones are
 #'   \itemize{
-#'     \item \code{reg:linear} linear regression
+#'     \item \code{reg:squarederror} Regression with squared loss
 #'     \item \code{binary:logistic} logistic regression for classification
 #'   }
 #'   \item \code{eta} step size of each boosting step

@@ -42,7 +42,7 @@
 #' \itemize{
 #'   \item \code{objective} specify the learning task and the corresponding learning objective, users can pass a self-defined function to it. The default objective options are below:
 #'   \itemize{
-#'     \item \code{reg:linear} linear regression (Default).
+#'     \item \code{reg:squarederror} Regression with squared loss (Default).
 #'     \item \code{reg:logistic} logistic regression.
 #'     \item \code{binary:logistic} logistic regression for binary classification. Output probability.
 #'     \item \code{binary:logitraw} logistic regression for binary classification, output score before logistic transformation.
@@ -16,7 +16,7 @@ xgb.cv(params = list(), data, nrounds, nfold, label = NULL,
 \itemize{
   \item \code{objective} objective function, common ones are
   \itemize{
-    \item \code{reg:linear} linear regression
+    \item \code{reg:squarederror} Regression with squared loss.
     \item \code{binary:logistic} logistic regression for classification
   }
   \item \code{eta} step size of each boosting step

@@ -56,7 +56,7 @@ xgboost(data = NULL, label = NULL, missing = NA, weight = NULL,
 \itemize{
   \item \code{objective} specify the learning task and the corresponding learning objective, users can pass a self-defined function to it. The default objective options are below:
   \itemize{
-    \item \code{reg:linear} linear regression (Default).
+    \item \code{reg:squarederror} Regression with squared loss (Default).
     \item \code{reg:logistic} logistic regression.
     \item \code{binary:logistic} logistic regression for binary classification. Output probability.
     \item \code{binary:logitraw} logistic regression for binary classification, output score before logistic transformation.
@@ -210,7 +210,7 @@ dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label)
 watchlist <- list(train = dtrain, eval = dtest)
 
 ## A simple xgb.train example:
-param <- list(max_depth = 2, eta = 1, silent = 1, nthread = 2,
+param <- list(max_depth = 2, eta = 1, verbosity = 0, nthread = 2,
               objective = "binary:logistic", eval_metric = "auc")
 bst <- xgb.train(param, dtrain, nrounds = 2, watchlist)

@@ -231,12 +231,12 @@ evalerror <- function(preds, dtrain) {
 
 # These functions could be used by passing them either:
 # as 'objective' and 'eval_metric' parameters in the params list:
-param <- list(max_depth = 2, eta = 1, silent = 1, nthread = 2,
+param <- list(max_depth = 2, eta = 1, verbosity = 0, nthread = 2,
               objective = logregobj, eval_metric = evalerror)
 bst <- xgb.train(param, dtrain, nrounds = 2, watchlist)
 
 # or through the ... arguments:
-param <- list(max_depth = 2, eta = 1, silent = 1, nthread = 2)
+param <- list(max_depth = 2, eta = 1, verbosity = 0, nthread = 2)
 bst <- xgb.train(param, dtrain, nrounds = 2, watchlist,
                  objective = logregobj, eval_metric = evalerror)
@@ -246,7 +246,7 @@ bst <- xgb.train(param, dtrain, nrounds = 2, watchlist,
 
 
 ## An xgb.train example of using variable learning rates at each iteration:
-param <- list(max_depth = 2, eta = 1, silent = 1, nthread = 2,
+param <- list(max_depth = 2, eta = 1, verbosity = 0, nthread = 2,
               objective = "binary:logistic", eval_metric = "auc")
 my_etas <- list(eta = c(0.5, 0.1))
 bst <- xgb.train(param, dtrain, nrounds = 2, watchlist,

@@ -98,7 +98,7 @@ test_that("SHAP contribution values are not NAN", {
   fit <- xgboost(
     verbose = 0,
     params = list(
-      objective = "reg:linear",
+      objective = "reg:squarederror",
       eval_metric = "rmse"),
     data = as.matrix(subset(d, fold == 2)[, ivs]),
     label = subset(d, fold == 2)$y,
@@ -6,9 +6,9 @@ Using XGBoost for regression is very similar to using it for binary classificati
 The dataset we used is the [computer hardware dataset from UCI repository](https://archive.ics.uci.edu/ml/datasets/Computer+Hardware). The demo for regression is almost the same as the [binary classification demo](../binary_classification), except a little difference in general parameter:
 ```
 # General parameter
-# this is the only difference with classification, use reg:linear to do linear regression
+# this is the only difference with classification, use reg:squarederror to do linear regression
 # when labels are in [0,1] we can also use reg:logistic
-objective = reg:linear
+objective = reg:squarederror
 ...
 
 ```
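For readers of this demo who use the Python package rather than the CLI configuration file, a minimal sketch of the same objective under its new name is shown below. This is illustrative only and not part of the commit; the synthetic data and parameter values are invented.

```python
import numpy as np
import xgboost as xgb

# Synthetic regression data (placeholder for the UCI computer hardware dataset).
X = np.random.rand(200, 4)
y = X @ np.array([1.0, -2.0, 0.5, 3.0])

dtrain = xgb.DMatrix(X, label=y)
# "reg:squarederror" is the new spelling; "reg:linear" still works but is deprecated.
params = {"objective": "reg:squarederror", "max_depth": 2, "eta": 1, "verbosity": 0}
bst = xgb.train(params, dtrain, num_boost_round=10)
```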
@@ -1,9 +1,9 @@
 # General Parameters, see comment for each definition
 # choose the tree booster, can also change to gblinear
 booster = gbtree
-# this is the only difference with classification, use reg:linear to do linear classification
+# this is the only difference with classification, use reg:squarederror to do linear classification
 # when labels are in [0,1] we can also use reg:logistic
-objective = reg:linear
+objective = reg:squarederror
 
 # Tree Booster Parameters
 # step size shrinkage
@@ -1,9 +1,9 @@
 # General Parameters, see comment for each definition
 # choose the tree booster, can also change to gblinear
 booster = gbtree
-# this is the only difference with classification, use reg:linear to do linear classification
+# this is the only difference with classification, use reg:squarederror to do linear classification
 # when labels are in [0,1] we can also use reg:logistic
-objective = reg:linear
+objective = reg:squarederror
 
 # Tree Booster Parameters
 # step size shrinkage

@@ -27,4 +27,3 @@ data = "yearpredMSD.libsvm.train"
 eval[test] = "yearpredMSD.libsvm.test"
 # The path of test data
 #test:data = "yearpredMSD.libsvm.test"
@@ -92,7 +92,7 @@ Most of the objective functions implemented in XGBoost can be run on GPU. Follo
 +-----------------+-------------+
 | Objectives      | GPU support |
 +-----------------+-------------+
-| reg:linear      | |tick|      |
+| reg:squarederror| |tick|      |
 +-----------------+-------------+
 | reg:logistic    | |tick|      |
 +-----------------+-------------+
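As a rough illustration of the table above (not part of the diff), the renamed objective can be combined with the GPU histogram updater. This sketch assumes a CUDA-enabled XGBoost build; the data is invented.

```python
import numpy as np
import xgboost as xgb

dtrain = xgb.DMatrix(np.random.rand(1000, 10), label=np.random.rand(1000))
# reg:squarederror is GPU-supported per the table above; gpu_hist selects the GPU updater.
params = {"objective": "reg:squarederror", "tree_method": "gpu_hist", "verbosity": 0}
bst = xgb.train(params, dtrain, num_boost_round=20)
```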
@@ -293,9 +293,9 @@ Learning Task Parameters
 ************************
 Specify the learning task and the corresponding learning objective. The objective options are below:
 
-* ``objective`` [default=reg:linear]
+* ``objective`` [default=reg:squarederror]
 
-  - ``reg:linear``: linear regression
+  - ``reg:squarederror``: regression with squared loss
   - ``reg:logistic``: logistic regression
   - ``binary:logistic``: logistic regression for binary classification, output probability
   - ``binary:logitraw``: logistic regression for binary classification, output score before logistic transformation
@@ -36,7 +36,7 @@ The following parameters must be set to enable random forest training.
 
 
 Other parameters should be set in a similar way they are set for gradient boosting. For
-instance, ``objective`` will typically be ``reg:linear`` for regression and
+instance, ``objective`` will typically be ``reg:squarederror`` for regression and
 ``binary:logistic`` for classification, ``lambda`` should be set according to a desired
 regularization weight, etc.
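A hedged sketch of what such a random forest configuration might look like from the Python package (not taken from this commit; the data and the specific values are invented):

```python
import numpy as np
import xgboost as xgb

dtrain = xgb.DMatrix(np.random.rand(500, 8), label=np.random.rand(500))
params = {
    "objective": "reg:squarederror",   # was "reg:linear" before this change
    "learning_rate": 1,                # no shrinkage in random forest mode
    "subsample": 0.8,
    "colsample_bynode": 0.8,
    "num_parallel_tree": 100,          # size of the forest
}
# a single boosting round grows the whole forest
bst = xgb.train(params, dtrain, num_boost_round=1)
```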
@@ -24,8 +24,8 @@ private[spark] trait LearningTaskParams extends Params {
 
   /**
    * Specify the learning task and the corresponding learning objective.
-   * options: reg:linear, reg:logistic, binary:logistic, binary:logitraw, count:poisson,
-   * multi:softmax, multi:softprob, rank:pairwise, reg:gamma. default: reg:linear
+   * options: reg:squarederror, reg:logistic, binary:logistic, binary:logitraw, count:poisson,
+   * multi:softmax, multi:softprob, rank:pairwise, reg:gamma. default: reg:squarederror
    */
   final val objective = new Param[String](this, "objective", "objective function used for " +
     s"training, options: {${LearningTaskParams.supportedObjective.mkString(",")}",

@@ -94,12 +94,12 @@ private[spark] trait LearningTaskParams extends Params {
 
   final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
 
-  setDefault(objective -> "reg:linear", baseScore -> 0.5,
+  setDefault(objective -> "reg:squarederror", baseScore -> 0.5,
     trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0)
 }
 
 private[spark] object LearningTaskParams {
-  val supportedObjective = HashSet("reg:linear", "reg:logistic", "binary:logistic",
+  val supportedObjective = HashSet("reg:squarederror", "reg:logistic", "binary:logistic",
     "binary:logitraw", "count:poisson", "multi:softmax", "multi:softprob", "rank:pairwise",
     "rank:ndcg", "rank:map", "reg:gamma", "reg:tweedie")
@@ -96,7 +96,7 @@ class PersistenceSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     val testDM = new DMatrix(Regression.test.iterator)
 
     val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "reg:linear", "num_round" -> "10", "num_workers" -> numWorkers)
+      "objective" -> "reg:squarederror", "num_round" -> "10", "num_workers" -> numWorkers)
     val xgbr = new XGBoostRegressor(paramMap)
     val xgbrPath = new File(tempDir, "xgbr").getPath
     xgbr.write.overwrite().save(xgbrPath)
@@ -36,7 +36,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
       "eta" -> "1",
       "max_depth" -> "6",
       "silent" -> "1",
-      "objective" -> "reg:linear")
+      "objective" -> "reg:squarederror")
 
     val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
     val prediction1 = model1.predict(testDM)

@@ -69,7 +69,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
       "eta" -> "1",
       "max_depth" -> "6",
       "silent" -> "1",
-      "objective" -> "reg:linear",
+      "objective" -> "reg:squarederror",
       "num_round" -> round,
       "num_workers" -> numWorkers)

@@ -80,7 +80,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
       .setEta(1)
       .setMaxDepth(6)
       .setSilent(1)
-      .setObjective("reg:linear")
+      .setObjective("reg:squarederror")
       .setNumRound(round)
       .setNumWorkers(numWorkers)
       .fit(trainingDF)

@@ -108,7 +108,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
 
   test("use weight") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
+      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
 
     val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f}, DataTypes.FloatType)
     val trainingDF = buildDataFrame(Regression.train)

@@ -123,7 +123,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
 
   test("test predictionLeaf") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
+      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
     val training = buildDataFrame(Regression.train)
     val testDF = buildDataFrame(Regression.test)
     val groundTruth = testDF.count()

@@ -137,7 +137,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
 
   test("test predictionLeaf with empty column name") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
+      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
     val training = buildDataFrame(Regression.train)
     val testDF = buildDataFrame(Regression.test)
     val xgb = new XGBoostRegressor(paramMap)

@@ -149,7 +149,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
 
   test("test predictionContrib") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
+      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
     val training = buildDataFrame(Regression.train)
     val testDF = buildDataFrame(Regression.test)
     val groundTruth = testDF.count()

@@ -163,7 +163,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
 
   test("test predictionContrib with empty column name") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
+      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
     val training = buildDataFrame(Regression.train)
     val testDF = buildDataFrame(Regression.test)
     val xgb = new XGBoostRegressor(paramMap)

@@ -175,7 +175,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
 
   test("test predictionLeaf and predictionContrib") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
+      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
     val training = buildDataFrame(Regression.train)
     val testDF = buildDataFrame(Regression.test)
     val groundTruth = testDF.count()
@@ -128,8 +128,8 @@ class RegLossObj : public ObjFunction {
 // register the objective functions
 DMLC_REGISTER_PARAMETER(RegLossParam);
 
-XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear")
-.describe("Linear regression.")
+XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegression, "reg:squarederror")
+.describe("Regression with squared error.")
 .set_body([]() { return new RegLossObj<LinearSquareLoss>(); });
 
 XGBOOST_REGISTER_OBJECTIVE(LogisticRegression, "reg:logistic")

@@ -145,7 +145,13 @@ XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, "binary:logitraw")
           "before logistic transformation.")
 .set_body([]() { return new RegLossObj<LogisticRaw>(); });
 
-// Deprecated GPU functions
+// Deprecated functions
+XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear")
+.describe("Regression with squared error.")
+.set_body([]() {
+    LOG(WARNING) << "reg:linear is now deprecated in favor of reg:squarederror.";
+    return new RegLossObj<LinearSquareLoss>(); });
+
 XGBOOST_REGISTER_OBJECTIVE(GPULinearRegression, "gpu:reg:linear")
 .describe("Deprecated. Linear regression (computed on GPU).")
 .set_body([]() {
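The deprecated registration above keeps old configurations working. A quick way to see the backward-compatibility path from the Python package (not part of the diff; the data is synthetic):

```python
import numpy as np
import xgboost as xgb

dtrain = xgb.DMatrix(np.random.rand(50, 3), label=np.random.rand(50))
# Training with the old name still succeeds, but the log should contain
# "reg:linear is now deprecated in favor of reg:squarederror."
bst = xgb.train({"objective": "reg:linear", "verbosity": 1}, dtrain, num_boost_round=1)
```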
@@ -1,4 +1,5 @@
 // Copyright by Contributors
 #include <gtest/gtest.h>
+#include <xgboost/objective.h>
 
 #include "../helpers.h"

@@ -6,7 +7,7 @@
 TEST(Objective, UnknownFunction) {
   xgboost::ObjFunction* obj = nullptr;
   EXPECT_ANY_THROW(obj = xgboost::ObjFunction::Create("unknown_name"));
-  EXPECT_NO_THROW(obj = xgboost::ObjFunction::Create("reg:linear"));
+  EXPECT_NO_THROW(obj = xgboost::ObjFunction::Create("reg:squarederror"));
   if (obj) {
     delete obj;
   }
@@ -1,12 +1,13 @@
 /*!
- * Copyright 2017-2018 XGBoost contributors
+ * Copyright 2017-2019 XGBoost contributors
  */
 #include <gtest/gtest.h>
+#include <xgboost/objective.h>
 
 #include "../helpers.h"
 
 TEST(Objective, DeclareUnifiedTest(LinearRegressionGPair)) {
-  xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:linear");
+  xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:squarederror");
   std::vector<std::pair<std::string, std::string> > args;
   obj->Configure(args);
   CheckObjFunction(obj,
@@ -132,15 +132,16 @@ def run_suite(param, num_rounds=10, select_datasets=None, scale_features=False):
     Run the given parameters on a range of datasets. Objective and eval metric will be automatically set
     """
     datasets = [
-        Dataset("Boston", get_boston, "reg:linear", "rmse"),
+        Dataset("Boston", get_boston, "reg:squarederror", "rmse"),
         Dataset("Digits", get_digits, "multi:softmax", "merror"),
         Dataset("Cancer", get_cancer, "binary:logistic", "error"),
-        Dataset("Sparse regression", get_sparse, "reg:linear", "rmse"),
+        Dataset("Sparse regression", get_sparse, "reg:squarederror", "rmse"),
         Dataset("Sparse regression with weights", get_sparse_weights,
-                "reg:linear", "rmse", has_weights=True),
+                "reg:squarederror", "rmse", has_weights=True),
         Dataset("Small weights regression", get_small_weights,
-                "reg:linear", "rmse", has_weights=True),
-        Dataset("Boston External Memory", get_boston, "reg:linear", "rmse",
+                "reg:squarederror", "rmse", has_weights=True),
+        Dataset("Boston External Memory", get_boston,
+                "reg:squarederror", "rmse",
                 use_external_memory=True)
     ]
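The Python test updates that follow all make the same substitution: the old boolean-style `silent` flag is replaced by the integer `verbosity` level. A small sketch of the mapping (not from the commit itself):

```python
# verbosity levels: 0 = silent, 1 = warning, 2 = info, 3 = debug
old_style = {'max_depth': 2, 'eta': 1, 'silent': 1}        # deprecated
new_style = {'max_depth': 2, 'eta': 1, 'verbosity': 0}     # preferred
```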
@@ -38,7 +38,7 @@ class TestBasic(unittest.TestCase):
     def test_basic(self):
         dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
         dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
-        param = {'max_depth': 2, 'eta': 1, 'silent': 1,
+        param = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
                  'objective': 'binary:logistic'}
         # specify validations set to watch performance
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]

@@ -85,7 +85,7 @@ class TestBasic(unittest.TestCase):
     def test_record_results(self):
         dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
         dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
-        param = {'max_depth': 2, 'eta': 1, 'silent': 1,
+        param = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
                  'objective': 'binary:logistic'}
         # specify validations set to watch performance
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]

@@ -102,7 +102,7 @@ class TestBasic(unittest.TestCase):
     def test_multiclass(self):
         dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
         dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
-        param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'num_class': 2}
+        param = {'max_depth': 2, 'eta': 1, 'verbosity': 0, 'num_class': 2}
         # specify validations set to watch performance
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]
         num_round = 2

@@ -273,7 +273,7 @@ class TestBasic(unittest.TestCase):
 
     def test_cv(self):
         dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
-        params = {'max_depth': 2, 'eta': 1, 'silent': 1,
+        params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
                   'objective': 'binary:logistic'}
 
         # return np.ndarray

@@ -283,7 +283,7 @@ class TestBasic(unittest.TestCase):
 
     def test_cv_no_shuffle(self):
         dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
-        params = {'max_depth': 2, 'eta': 1, 'silent': 1,
+        params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
                   'objective': 'binary:logistic'}
 
         # return np.ndarray

@@ -294,7 +294,7 @@ class TestBasic(unittest.TestCase):
 
     def test_cv_explicit_fold_indices(self):
         dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
-        params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective':
+        params = {'max_depth': 2, 'eta': 1, 'verbosity': 0, 'objective':
                   'binary:logistic'}
         folds = [
             # Train Test

@@ -310,7 +310,7 @@ class TestBasic(unittest.TestCase):
 
     def test_cv_explicit_fold_indices_labels(self):
         params = {'max_depth': 2, 'eta': 1, 'verbosity': 0, 'objective':
-                  'reg:linear'}
+                  'reg:squarederror'}
         N = 100
         F = 3
         dm = xgb.DMatrix(data=np.random.randn(N, F), label=np.arange(N))
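For context (illustrative only, not part of the diff), the kind of call the updated test above exercises: cross-validation with explicit fold indices and the renamed regression objective. The fold split below is invented.

```python
import numpy as np
import xgboost as xgb

params = {'max_depth': 2, 'eta': 1, 'verbosity': 0, 'objective': 'reg:squarederror'}
dm = xgb.DMatrix(np.random.randn(100, 3), label=np.arange(100))
folds = [(np.arange(80), np.arange(80, 100))]   # (train indices, test indices)
results = xgb.cv(params, dm, num_boost_round=5, folds=folds, as_pandas=False)
```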
@@ -11,7 +11,7 @@ rng = np.random.RandomState(1994)
 
 class TestModels(unittest.TestCase):
     def test_glm(self):
-        param = {'silent': 1, 'objective': 'binary:logistic',
+        param = {'verbosity': 0, 'objective': 'binary:logistic',
                  'booster': 'gblinear', 'alpha': 0.0001, 'lambda': 1, 'nthread': 1}
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]
         num_round = 4

@@ -26,7 +26,7 @@ class TestModels(unittest.TestCase):
     def test_dart(self):
         dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
         dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
-        param = {'max_depth': 5, 'objective': 'binary:logistic', 'booster': 'dart', 'silent': False}
+        param = {'max_depth': 5, 'objective': 'binary:logistic', 'booster': 'dart', 'verbosity': 1}
         # specify validations set to watch performance
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]
         num_round = 2

@@ -51,7 +51,7 @@ class TestModels(unittest.TestCase):
 
         # check whether sample_type and normalize_type work
         num_round = 50
-        param['silent'] = True
+        param['verbosity'] = 0
         param['learning_rate'] = 0.1
         param['rate_drop'] = 0.1
         preds_list = []

@@ -74,7 +74,8 @@ class TestModels(unittest.TestCase):
 
         # learning_rates as a list
        # init eta with 0 to check whether learning_rates work
-        param = {'max_depth': 2, 'eta': 0, 'silent': 1, 'objective': 'binary:logistic'}
+        param = {'max_depth': 2, 'eta': 0, 'verbosity': 0,
+                 'objective': 'binary:logistic'}
         evals_result = {}
         bst = xgb.train(param, dtrain, num_round, watchlist, learning_rates=[0.8, 0.7, 0.6, 0.5],
                         evals_result=evals_result)

@@ -84,7 +85,8 @@ class TestModels(unittest.TestCase):
         assert eval_errors[0] > eval_errors[-1]
 
         # init learning_rate with 0 to check whether learning_rates work
-        param = {'max_depth': 2, 'learning_rate': 0, 'silent': 1, 'objective': 'binary:logistic'}
+        param = {'max_depth': 2, 'learning_rate': 0, 'verbosity': 0,
+                 'objective': 'binary:logistic'}
         evals_result = {}
         bst = xgb.train(param, dtrain, num_round, watchlist, learning_rates=[0.8, 0.7, 0.6, 0.5],
                         evals_result=evals_result)

@@ -94,7 +96,7 @@ class TestModels(unittest.TestCase):
         assert eval_errors[0] > eval_errors[-1]
 
         # check if learning_rates override default value of eta/learning_rate
-        param = {'max_depth': 2, 'silent': 1, 'objective': 'binary:logistic'}
+        param = {'max_depth': 2, 'verbosity': 0, 'objective': 'binary:logistic'}
         evals_result = {}
         bst = xgb.train(param, dtrain, num_round, watchlist, learning_rates=[0, 0, 0, 0],
                         evals_result=evals_result)

@@ -111,7 +113,7 @@ class TestModels(unittest.TestCase):
         assert isinstance(bst, xgb.core.Booster)
 
     def test_custom_objective(self):
-        param = {'max_depth': 2, 'eta': 1, 'silent': 1}
+        param = {'max_depth': 2, 'eta': 1, 'verbosity': 0}
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]
         num_round = 2

@@ -152,7 +154,8 @@ class TestModels(unittest.TestCase):
 
     def test_multi_eval_metric(self):
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]
-        param = {'max_depth': 2, 'eta': 0.2, 'silent': 1, 'objective': 'binary:logistic'}
+        param = {'max_depth': 2, 'eta': 0.2, 'verbosity': 0,
+                 'objective': 'binary:logistic'}
         param['eval_metric'] = ["auc", "logloss", 'error']
         evals_result = {}
         bst = xgb.train(param, dtrain, 4, watchlist, evals_result=evals_result)

@@ -161,7 +164,7 @@ class TestModels(unittest.TestCase):
         assert set(evals_result['eval'].keys()) == {'auc', 'error', 'logloss'}
 
     def test_fpreproc(self):
-        param = {'max_depth': 2, 'eta': 1, 'silent': 1,
+        param = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
                  'objective': 'binary:logistic'}
         num_round = 2

@@ -175,7 +178,7 @@ class TestModels(unittest.TestCase):
                metrics={'auc'}, seed=0, fpreproc=fpreproc)
 
     def test_show_stdv(self):
-        param = {'max_depth': 2, 'eta': 1, 'silent': 1,
+        param = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
                  'objective': 'binary:logistic'}
         num_round = 2
         xgb.cv(param, dtrain, num_round, nfold=5,
@@ -52,7 +52,7 @@ class TestEarlyStopping(unittest.TestCase):
         X = digits['data']
         y = digits['target']
         dm = xgb.DMatrix(X, label=y)
-        params = {'max_depth': 2, 'eta': 1, 'silent': 1,
+        params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
                   'objective': 'binary:logistic'}
 
         cv = xgb.cv(params, dm, num_boost_round=10, nfold=10,
@@ -9,25 +9,25 @@ rng = np.random.RandomState(1337)
 
 class TestEvalMetrics(unittest.TestCase):
     xgb_params_01 = {
-        'silent': 1,
+        'verbosity': 0,
         'nthread': 1,
         'eval_metric': 'error'
     }
 
     xgb_params_02 = {
-        'silent': 1,
+        'verbosity': 0,
         'nthread': 1,
         'eval_metric': ['error']
     }
 
     xgb_params_03 = {
-        'silent': 1,
+        'verbosity': 0,
         'nthread': 1,
         'eval_metric': ['rmse', 'error']
     }
 
     xgb_params_04 = {
-        'silent': 1,
+        'verbosity': 0,
         'nthread': 1,
         'eval_metric': ['error', 'rmse']
     }
@@ -18,7 +18,7 @@ class TestInteractionConstraints(unittest.TestCase):
         X = np.column_stack((x1, x2, x3))
         dtrain = xgboost.DMatrix(X, label=y)
 
-        params = {'max_depth': 3, 'eta': 0.1, 'nthread': 2, 'silent': 1,
+        params = {'max_depth': 3, 'eta': 0.1, 'nthread': 2, 'verbosity': 0,
                   'interaction_constraints': '[[0, 1]]'}
         num_boost_round = 100
         # Fit a model that only allows interaction between x1 and x2
@@ -30,7 +30,7 @@ def xgb_get_weights(bst):
 
 def assert_regression_result(results, tol):
     regression_results = [r for r in results if
-                          r["param"]["objective"] == "reg:linear"]
+                          r["param"]["objective"] == "reg:squarederror"]
     for res in regression_results:
         X = scale(res["dataset"].X,
                   with_mean=isinstance(res["dataset"].X, np.ndarray))

@@ -52,7 +52,7 @@ def assert_regression_result(results, tol):
 # TODO: More robust classification tests
 def assert_classification_result(results):
     classification_results = [r for r in results if
-                              r["param"]["objective"] != "reg:linear"]
+                              r["param"]["objective"] != "reg:squarederror"]
     for res in classification_results:
         # Check accuracy is reasonable
         assert res["eval"][-1] < 0.5, (res["dataset"].name, res["eval"][-1])
@@ -16,7 +16,8 @@ class TestTreesToDataFrame(unittest.TestCase):
 
     def build_model(self, max_depth, num_round):
         dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
-        param = {'max_depth': max_depth, 'objective': 'binary:logistic', 'silent': False}
+        param = {'max_depth': max_depth, 'objective': 'binary:logistic',
+                 'verbosity': 1}
         num_round = num_round
         bst = xgb.train(param, dtrain, num_round)
         return bst
@@ -51,7 +51,7 @@ class TestSHAP(unittest.TestCase):
 
         def fn(max_depth, num_rounds):
             # train
-            params = {'max_depth': max_depth, 'eta': 1, 'silent': 1}
+            params = {'max_depth': max_depth, 'eta': 1, 'verbosity': 0}
             bst = xgb.train(params, dtrain, num_boost_round=num_rounds)
 
             # predict
@@ -4,7 +4,7 @@ from scipy.sparse import rand
 
 rng = np.random.RandomState(1)
 
-param = {'max_depth': 3, 'objective': 'binary:logistic', 'silent': 1}
+param = {'max_depth': 3, 'objective': 'binary:logistic', 'verbosity': 0}
 
 
 def test_sparse_dmatrix_csr():
@@ -11,18 +11,18 @@ class TestTrainingContinuation(unittest.TestCase):
     num_parallel_tree = 3
 
     xgb_params_01 = {
-        'silent': 1,
+        'verbosity': 0,
         'nthread': 1,
     }
 
     xgb_params_02 = {
-        'silent': 1,
+        'verbosity': 0,
         'nthread': 1,
         'num_parallel_tree': num_parallel_tree
     }
 
     xgb_params_03 = {
-        'silent': 1,
+        'verbosity': 0,
         'nthread': 1,
         'num_class': 5,
         'num_parallel_tree': num_parallel_tree
@@ -10,7 +10,8 @@ train_data = xgb.DMatrix(np.array([[1]]), label=np.array([1]))
 class TestTreeRegularization(unittest.TestCase):
     def test_alpha(self):
         params = {
-            'tree_method': 'exact', 'silent': 1, 'objective': 'reg:linear',
+            'tree_method': 'exact', 'verbosity': 0,
+            'objective': 'reg:squarederror',
             'eta': 1,
             'lambda': 0,
             'alpha': 0.1

@@ -27,7 +28,8 @@ class TestTreeRegularization(unittest.TestCase):
 
     def test_lambda(self):
         params = {
-            'tree_method': 'exact', 'silent': 1, 'objective': 'reg:linear',
+            'tree_method': 'exact', 'verbosity': 0,
+            'objective': 'reg:squarederror',
             'eta': 1,
             'lambda': 1,
             'alpha': 0

@@ -44,7 +46,8 @@ class TestTreeRegularization(unittest.TestCase):
 
     def test_alpha_and_lambda(self):
         params = {
-            'tree_method': 'exact', 'silent': 1, 'objective': 'reg:linear',
+            'tree_method': 'exact', 'verbosity': 1,
+            'objective': 'reg:squarederror',
             'eta': 1,
             'lambda': 1,
             'alpha': 0.1
@@ -33,7 +33,7 @@ class TestUpdaters(unittest.TestCase):
                           'max_bin': [2, 256],
                           'grow_policy': ['depthwise', 'lossguide'],
                           'max_leaves': [64, 0],
-                          'silent': [1]}
+                          'verbosity': [0]}
         for param in parameter_combinations(variable_param):
             result = run_suite(param)
             assert_results_non_increasing(result, 1e-2)

@@ -45,7 +45,7 @@ class TestUpdaters(unittest.TestCase):
         ag_param = {'max_depth': 2,
                     'tree_method': 'hist',
                     'eta': 1,
-                    'silent': 1,
+                    'verbosity': 0,
                     'objective': 'binary:logistic',
                     'eval_metric': 'auc'}
         hist_res = {}
@@ -120,7 +120,7 @@ class TestPandas(unittest.TestCase):
 
     def test_cv_as_pandas(self):
         dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
-        params = {'max_depth': 2, 'eta': 1, 'silent': 1,
+        params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
                   'objective': 'binary:logistic'}
 
         cv = xgb.cv(params, dm, num_boost_round=10, nfold=10)

@@ -143,19 +143,19 @@ class TestPandas(unittest.TestCase):
                         u'train-error-mean', u'train-error-std'])
         assert cv.columns.equals(exp)
 
-        params = {'max_depth': 2, 'eta': 1, 'silent': 1,
+        params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
                   'objective': 'binary:logistic', 'eval_metric': 'auc'}
         cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=True)
         assert 'eval_metric' in params
         assert 'auc' in cv.columns[0]
 
-        params = {'max_depth': 2, 'eta': 1, 'silent': 1,
+        params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
                   'objective': 'binary:logistic', 'eval_metric': ['auc']}
         cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=True)
         assert 'eval_metric' in params
         assert 'auc' in cv.columns[0]
 
-        params = {'max_depth': 2, 'eta': 1, 'silent': 1,
+        params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
                   'objective': 'binary:logistic', 'eval_metric': ['auc']}
         cv = xgb.cv(params, dm, num_boost_round=10, nfold=10,
                     as_pandas=True, early_stopping_rounds=1)

@@ -163,19 +163,19 @@ class TestPandas(unittest.TestCase):
         assert 'auc' in cv.columns[0]
         assert cv.shape[0] < 10
 
-        params = {'max_depth': 2, 'eta': 1, 'silent': 1,
+        params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
                   'objective': 'binary:logistic'}
         cv = xgb.cv(params, dm, num_boost_round=10, nfold=10,
                     as_pandas=True, metrics='auc')
         assert 'auc' in cv.columns[0]
 
-        params = {'max_depth': 2, 'eta': 1, 'silent': 1,
+        params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
                   'objective': 'binary:logistic'}
         cv = xgb.cv(params, dm, num_boost_round=10, nfold=10,
                     as_pandas=True, metrics=['auc'])
         assert 'auc' in cv.columns[0]
 
-        params = {'max_depth': 2, 'eta': 1, 'silent': 1,
+        params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
                   'objective': 'binary:logistic', 'eval_metric': ['auc']}
         cv = xgb.cv(params, dm, num_boost_round=10, nfold=10,
                     as_pandas=True, metrics='error')
@@ -603,7 +603,8 @@ def test_RFECV():
     # Regression
     X, y = load_boston(return_X_y=True)
     bst = xgb.XGBClassifier(booster='gblinear', learning_rate=0.1,
-                            n_estimators=10, n_jobs=1, objective='reg:linear',
+                            n_estimators=10, n_jobs=1,
+                            objective='reg:squarederror',
                             random_state=0, verbosity=0)
     rfecv = RFECV(
         estimator=bst, step=1, cv=3, scoring='neg_mean_squared_error')