[jvm-packges] set the correct objective if user doesn't explicitly set it (#7781)

This commit is contained in:
Bobby Wang
2022-05-18 14:05:18 +08:00
committed by GitHub
parent 806c92c80b
commit 5ef33adf68
6 changed files with 71 additions and 9 deletions

View File

@@ -169,6 +169,23 @@ class XGBoostClassifier (
}
override protected def train(dataset: Dataset[_]): XGBoostClassificationModel = {
val _numClasses = getNumClasses(dataset)
if (isDefined(numClass) && $(numClass) != _numClasses) {
throw new Exception("The number of classes in dataset doesn't match " +
"\'num_class\' in xgboost params.")
}
if (_numClasses == 2) {
if (!isDefined(objective)) {
// If user doesn't set objective, force it to binary:logistic
setObjective("binary:logistic")
}
} else if (_numClasses > 2) {
if (!isDefined(objective)) {
// If user doesn't set objective, force it to multi:softprob
setObjective("multi:softprob")
}
}
if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
set(evalMetric, setupDefaultEvalMetric())
@@ -178,12 +195,6 @@ class XGBoostClassifier (
set(objectiveType, "classification")
}
val _numClasses = getNumClasses(dataset)
if (isDefined(numClass) && $(numClass) != _numClasses) {
throw new Exception("The number of classes in dataset doesn't match " +
"\'num_class\' in xgboost params.")
}
// Packing with all params plus params user defined
val derivedXGBParamMap = xgboostParams ++ MLlib2XGBoostParams
val buildTrainingData = PreXGBoost.buildDatasetToRDD(this, dataset, derivedXGBParamMap)

View File

@@ -169,6 +169,11 @@ class XGBoostRegressor (
override protected def train(dataset: Dataset[_]): XGBoostRegressionModel = {
if (!isDefined(objective)) {
// If user doesn't set objective, force it to reg:squarederror
setObjective("reg:squarederror")
}
if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
set(evalMetric, setupDefaultEvalMetric())
}

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -105,7 +105,7 @@ private[spark] trait LearningTaskParams extends Params {
final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
setDefault(objective -> "reg:squarederror", baseScore -> 0.5, trainTestRatio -> 1.0,
setDefault(baseScore -> 0.5, trainTestRatio -> 1.0,
numEarlyStoppingRounds -> 0, cacheTrainingSet -> false)
}