[jvm-packges] set the correct objective if user doesn't explicitly set it (#7781)

This commit is contained in:
Bobby Wang
2022-05-18 14:05:18 +08:00
committed by GitHub
parent 806c92c80b
commit 5ef33adf68
6 changed files with 71 additions and 9 deletions

View File

@@ -138,7 +138,7 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"custom_eval" -> new EvalError, "custom_obj" -> new CustomObj(1),
"num_round" -> "10", "num_workers" -> numWorkers)
"num_round" -> "10", "num_workers" -> numWorkers, "objective" -> "binary:logistic")
val xgbc = new XGBoostClassifier(paramMap)
val xgbcPath = new File(tempDir.toFile, "xgbc").getPath

View File

@@ -112,6 +112,34 @@ class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuit
assert(!transformedDf.columns.contains("probability"))
}
test("objective will be set if not specifying it") {
val training = buildDataFrame(Classification.train)
val paramMap = Map("eta" -> "1", "max_depth" -> "6",
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
val xgb = new XGBoostClassifier(paramMap)
assert(!xgb.isDefined(xgb.objective))
xgb.fit(training)
assert(xgb.getObjective == "binary:logistic")
val trainingDF = buildDataFrame(MultiClassification.train)
val paramMap1 = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"num_class" -> "6", "num_round" -> 5, "num_workers" -> numWorkers,
"tree_method" -> treeMethod)
val xgb1 = new XGBoostClassifier(paramMap1)
assert(!xgb1.isDefined(xgb1.objective))
xgb1.fit(trainingDF)
assert(xgb1.getObjective == "multi:softprob")
// shouldn't change user's objective setting
val paramMap2 = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"num_class" -> "6", "num_round" -> 5, "num_workers" -> numWorkers,
"tree_method" -> treeMethod, "objective" -> "multi:softmax")
val xgb2 = new XGBoostClassifier(paramMap2)
assert(xgb2.getObjective == "multi:softmax")
xgb2.fit(trainingDF)
assert(xgb2.getObjective == "multi:softmax")
}
test("use base margin") {
val training1 = buildDataFrame(Classification.train)
val training2 = training1.withColumn("margin", functions.rand())

View File

@@ -146,6 +146,24 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
prediction.foreach(x => assert(math.abs(x.getAs[Double]("prediction") - first) <= 0.01f))
}
test("objective will be set if not specifying it") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
val training = buildDataFrame(Regression.train)
val xgb = new XGBoostRegressor(paramMap)
assert(!xgb.isDefined(xgb.objective))
xgb.fit(training)
assert(xgb.getObjective == "reg:squarederror")
val paramMap1 = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod,
"objective" -> "reg:squaredlogerror")
val xgb1 = new XGBoostRegressor(paramMap1)
assert(xgb1.getObjective == "reg:squaredlogerror")
xgb1.fit(training)
assert(xgb1.getObjective == "reg:squaredlogerror")
}
test("test predictionLeaf") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,