[jvm-packages] set the correct objective if user doesn't explicitly set it (#7781)
This commit is contained in:
parent
806c92c80b
commit
5ef33adf68
@ -169,6 +169,23 @@ class XGBoostClassifier (
|
||||
}
|
||||
|
||||
override protected def train(dataset: Dataset[_]): XGBoostClassificationModel = {
|
||||
val _numClasses = getNumClasses(dataset)
|
||||
if (isDefined(numClass) && $(numClass) != _numClasses) {
|
||||
throw new Exception("The number of classes in dataset doesn't match " +
|
||||
"\'num_class\' in xgboost params.")
|
||||
}
|
||||
|
||||
if (_numClasses == 2) {
|
||||
if (!isDefined(objective)) {
|
||||
// If user doesn't set objective, force it to binary:logistic
|
||||
setObjective("binary:logistic")
|
||||
}
|
||||
} else if (_numClasses > 2) {
|
||||
if (!isDefined(objective)) {
|
||||
// If user doesn't set objective, force it to multi:softprob
|
||||
setObjective("multi:softprob")
|
||||
}
|
||||
}
|
||||
|
||||
if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
|
||||
set(evalMetric, setupDefaultEvalMetric())
|
||||
@ -178,12 +195,6 @@ class XGBoostClassifier (
|
||||
set(objectiveType, "classification")
|
||||
}
|
||||
|
||||
val _numClasses = getNumClasses(dataset)
|
||||
if (isDefined(numClass) && $(numClass) != _numClasses) {
|
||||
throw new Exception("The number of classes in dataset doesn't match " +
|
||||
"\'num_class\' in xgboost params.")
|
||||
}
|
||||
|
||||
// Packing with all params plus params user defined
|
||||
val derivedXGBParamMap = xgboostParams ++ MLlib2XGBoostParams
|
||||
val buildTrainingData = PreXGBoost.buildDatasetToRDD(this, dataset, derivedXGBParamMap)
|
||||
|
||||
@ -169,6 +169,11 @@ class XGBoostRegressor (
|
||||
|
||||
override protected def train(dataset: Dataset[_]): XGBoostRegressionModel = {
|
||||
|
||||
if (!isDefined(objective)) {
|
||||
// If user doesn't set objective, force it to reg:squarederror
|
||||
setObjective("reg:squarederror")
|
||||
}
|
||||
|
||||
if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
|
||||
set(evalMetric, setupDefaultEvalMetric())
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -105,7 +105,7 @@ private[spark] trait LearningTaskParams extends Params {
|
||||
|
||||
final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
|
||||
|
||||
setDefault(objective -> "reg:squarederror", baseScore -> 0.5, trainTestRatio -> 1.0,
|
||||
setDefault(baseScore -> 0.5, trainTestRatio -> 1.0,
|
||||
numEarlyStoppingRounds -> 0, cacheTrainingSet -> false)
|
||||
}
|
||||
|
||||
|
||||
@ -138,7 +138,7 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||
"custom_eval" -> new EvalError, "custom_obj" -> new CustomObj(1),
|
||||
"num_round" -> "10", "num_workers" -> numWorkers)
|
||||
"num_round" -> "10", "num_workers" -> numWorkers, "objective" -> "binary:logistic")
|
||||
|
||||
val xgbc = new XGBoostClassifier(paramMap)
|
||||
val xgbcPath = new File(tempDir.toFile, "xgbc").getPath
|
||||
|
||||
@ -112,6 +112,34 @@ class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuit
|
||||
assert(!transformedDf.columns.contains("probability"))
|
||||
}
|
||||
|
||||
test("objective will be set if not specifying it") {
  // Case 1: binary data, no objective supplied.
  // The estimator must start with the objective undefined and report
  // binary:logistic once fit has run.
  val binaryData = buildDataFrame(Classification.train)
  val binaryParams = Map(
    "eta" -> "1",
    "max_depth" -> "6",
    "num_round" -> 5,
    "num_workers" -> numWorkers,
    "tree_method" -> treeMethod)
  val binaryClassifier = new XGBoostClassifier(binaryParams)
  assert(!binaryClassifier.isDefined(binaryClassifier.objective))
  binaryClassifier.fit(binaryData)
  assert(binaryClassifier.getObjective == "binary:logistic")

  // Case 2: multi-class data (num_class declared as 6), no objective supplied.
  // After fit the objective is expected to be multi:softprob.
  val multiClassData = buildDataFrame(MultiClassification.train)
  val multiClassParams = Map(
    "eta" -> "0.1",
    "max_depth" -> "6",
    "silent" -> "1",
    "num_class" -> "6",
    "num_round" -> 5,
    "num_workers" -> numWorkers,
    "tree_method" -> treeMethod)
  val multiClassifier = new XGBoostClassifier(multiClassParams)
  assert(!multiClassifier.isDefined(multiClassifier.objective))
  multiClassifier.fit(multiClassData)
  assert(multiClassifier.getObjective == "multi:softprob")

  // Case 3: same multi-class setup, but the user picks an objective explicitly.
  // The chosen value must be visible before fit and must survive it unchanged.
  val explicitObjectiveParams = multiClassParams + ("objective" -> "multi:softmax")
  val explicitClassifier = new XGBoostClassifier(explicitObjectiveParams)
  assert(explicitClassifier.getObjective == "multi:softmax")
  explicitClassifier.fit(multiClassData)
  assert(explicitClassifier.getObjective == "multi:softmax")
}
|
||||
|
||||
test("use base margin") {
|
||||
val training1 = buildDataFrame(Classification.train)
|
||||
val training2 = training1.withColumn("margin", functions.rand())
|
||||
|
||||
@ -146,6 +146,24 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
||||
prediction.foreach(x => assert(math.abs(x.getAs[Double]("prediction") - first) <= 0.01f))
|
||||
}
|
||||
|
||||
test("objective will be set if not specifying it") {
  // Parameters shared by both scenarios; note that no objective is present here.
  val baseParams = Map(
    "eta" -> "1",
    "max_depth" -> "6",
    "silent" -> "1",
    "num_round" -> 5,
    "num_workers" -> numWorkers,
    "tree_method" -> treeMethod)
  val regressionData = buildDataFrame(Regression.train)

  // Scenario 1: objective left unset.
  // It must be undefined before fit and reg:squarederror afterwards.
  val defaultRegressor = new XGBoostRegressor(baseParams)
  assert(!defaultRegressor.isDefined(defaultRegressor.objective))
  defaultRegressor.fit(regressionData)
  assert(defaultRegressor.getObjective == "reg:squarederror")

  // Scenario 2: the user supplies an objective.
  // The supplied value must be reported both before and after fit.
  val explicitParams = baseParams + ("objective" -> "reg:squaredlogerror")
  val explicitRegressor = new XGBoostRegressor(explicitParams)
  assert(explicitRegressor.getObjective == "reg:squaredlogerror")
  explicitRegressor.fit(regressionData)
  assert(explicitRegressor.getObjective == "reg:squaredlogerror")
}
|
||||
|
||||
test("test predictionLeaf") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user