[jvm-packages] automatically set the max/min direction for best score (#9404)
This commit is contained in:
@@ -23,7 +23,6 @@ import scala.util.Random
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
|
||||
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
@@ -55,9 +54,6 @@ object TrackerConf {
|
||||
def apply(): TrackerConf = TrackerConf(0L)
|
||||
}
|
||||
|
||||
private[scala] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int,
|
||||
maximizeEvalMetrics: Boolean)
|
||||
|
||||
private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
|
||||
|
||||
private[scala] case class XGBoostExecutionParams(
|
||||
@@ -71,7 +67,7 @@ private[scala] case class XGBoostExecutionParams(
|
||||
trackerConf: TrackerConf,
|
||||
checkpointParam: Option[ExternalCheckpointParams],
|
||||
xgbInputParams: XGBoostExecutionInputParams,
|
||||
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
|
||||
earlyStoppingRounds: Int,
|
||||
cacheTrainingSet: Boolean,
|
||||
device: Option[String],
|
||||
isLocal: Boolean,
|
||||
@@ -146,15 +142,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
val numEarlyStoppingRounds = overridedParams.getOrElse(
|
||||
"num_early_stopping_rounds", 0).asInstanceOf[Int]
|
||||
overridedParams += "num_early_stopping_rounds" -> numEarlyStoppingRounds
|
||||
if (numEarlyStoppingRounds > 0 &&
|
||||
!overridedParams.contains("maximize_evaluation_metrics")) {
|
||||
if (overridedParams.getOrElse("custom_eval", null) != null) {
|
||||
if (numEarlyStoppingRounds > 0 && overridedParams.getOrElse("custom_eval", null) != null) {
|
||||
throw new IllegalArgumentException("custom_eval does not support early stopping")
|
||||
}
|
||||
val eval_metric = overridedParams("eval_metric").toString
|
||||
val maximize = LearningTaskParams.evalMetricsToMaximize contains eval_metric
|
||||
logger.info("parameter \"maximize_evaluation_metrics\" is set to " + maximize)
|
||||
overridedParams += ("maximize_evaluation_metrics" -> maximize)
|
||||
}
|
||||
overridedParams
|
||||
}
|
||||
@@ -213,10 +202,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
|
||||
val earlyStoppingRounds = overridedParams.getOrElse(
|
||||
"num_early_stopping_rounds", 0).asInstanceOf[Int]
|
||||
val maximizeEvalMetrics = overridedParams.getOrElse(
|
||||
"maximize_evaluation_metrics", true).asInstanceOf[Boolean]
|
||||
val xgbExecEarlyStoppingParams = XGBoostExecutionEarlyStoppingParams(earlyStoppingRounds,
|
||||
maximizeEvalMetrics)
|
||||
|
||||
val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
|
||||
.asInstanceOf[Boolean]
|
||||
@@ -232,7 +217,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
missing, allowNonZeroForMissing, trackerConf,
|
||||
checkpointParam,
|
||||
inputParams,
|
||||
xgbExecEarlyStoppingParams,
|
||||
earlyStoppingRounds,
|
||||
cacheTrainingSet,
|
||||
device,
|
||||
isLocal,
|
||||
@@ -319,7 +304,7 @@ object XGBoost extends Serializable {
|
||||
|
||||
watches = buildWatchesAndCheck(buildWatches)
|
||||
|
||||
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
|
||||
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingRounds
|
||||
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
|
||||
val externalCheckpointParams = xgbExecutionParam.checkpointParam
|
||||
|
||||
|
||||
@@ -112,8 +112,4 @@ private[spark] object LearningTaskParams {
|
||||
|
||||
val supportedObjectiveType = HashSet("regression", "classification")
|
||||
|
||||
val evalMetricsToMaximize = HashSet("auc", "aucpr", "ndcg", "map")
|
||||
|
||||
val evalMetricsToMinimize = HashSet("rmse", "rmsle", "mae", "mape", "logloss", "error", "merror",
|
||||
"mlogloss", "gamma-deviance")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user