[jvm-packages] automatically set the max/min direction for best score (#9404)
This commit is contained in:
parent
7579905e18
commit
8f0efb4ab3
@ -23,7 +23,6 @@ import scala.util.Random
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
|
||||
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
@ -55,9 +54,6 @@ object TrackerConf {
|
||||
def apply(): TrackerConf = TrackerConf(0L)
|
||||
}
|
||||
|
||||
private[scala] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int,
|
||||
maximizeEvalMetrics: Boolean)
|
||||
|
||||
private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
|
||||
|
||||
private[scala] case class XGBoostExecutionParams(
|
||||
@ -71,7 +67,7 @@ private[scala] case class XGBoostExecutionParams(
|
||||
trackerConf: TrackerConf,
|
||||
checkpointParam: Option[ExternalCheckpointParams],
|
||||
xgbInputParams: XGBoostExecutionInputParams,
|
||||
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
|
||||
earlyStoppingRounds: Int,
|
||||
cacheTrainingSet: Boolean,
|
||||
device: Option[String],
|
||||
isLocal: Boolean,
|
||||
@ -146,15 +142,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
val numEarlyStoppingRounds = overridedParams.getOrElse(
|
||||
"num_early_stopping_rounds", 0).asInstanceOf[Int]
|
||||
overridedParams += "num_early_stopping_rounds" -> numEarlyStoppingRounds
|
||||
if (numEarlyStoppingRounds > 0 &&
|
||||
!overridedParams.contains("maximize_evaluation_metrics")) {
|
||||
if (overridedParams.getOrElse("custom_eval", null) != null) {
|
||||
if (numEarlyStoppingRounds > 0 && overridedParams.getOrElse("custom_eval", null) != null) {
|
||||
throw new IllegalArgumentException("custom_eval does not support early stopping")
|
||||
}
|
||||
val eval_metric = overridedParams("eval_metric").toString
|
||||
val maximize = LearningTaskParams.evalMetricsToMaximize contains eval_metric
|
||||
logger.info("parameter \"maximize_evaluation_metrics\" is set to " + maximize)
|
||||
overridedParams += ("maximize_evaluation_metrics" -> maximize)
|
||||
}
|
||||
overridedParams
|
||||
}
|
||||
@ -213,10 +202,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
|
||||
val earlyStoppingRounds = overridedParams.getOrElse(
|
||||
"num_early_stopping_rounds", 0).asInstanceOf[Int]
|
||||
val maximizeEvalMetrics = overridedParams.getOrElse(
|
||||
"maximize_evaluation_metrics", true).asInstanceOf[Boolean]
|
||||
val xgbExecEarlyStoppingParams = XGBoostExecutionEarlyStoppingParams(earlyStoppingRounds,
|
||||
maximizeEvalMetrics)
|
||||
|
||||
val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
|
||||
.asInstanceOf[Boolean]
|
||||
@ -232,7 +217,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
missing, allowNonZeroForMissing, trackerConf,
|
||||
checkpointParam,
|
||||
inputParams,
|
||||
xgbExecEarlyStoppingParams,
|
||||
earlyStoppingRounds,
|
||||
cacheTrainingSet,
|
||||
device,
|
||||
isLocal,
|
||||
@ -319,7 +304,7 @@ object XGBoost extends Serializable {
|
||||
|
||||
watches = buildWatchesAndCheck(buildWatches)
|
||||
|
||||
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
|
||||
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingRounds
|
||||
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
|
||||
val externalCheckpointParams = xgbExecutionParam.checkpointParam
|
||||
|
||||
|
||||
@ -112,8 +112,4 @@ private[spark] object LearningTaskParams {
|
||||
|
||||
val supportedObjectiveType = HashSet("regression", "classification")
|
||||
|
||||
val evalMetricsToMaximize = HashSet("auc", "aucpr", "ndcg", "map")
|
||||
|
||||
val evalMetricsToMinimize = HashSet("rmse", "rmsle", "mae", "mape", "logloss", "error", "merror",
|
||||
"mlogloss", "gamma-deviance")
|
||||
}
|
||||
|
||||
@ -17,6 +17,8 @@ package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.*;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
@ -30,6 +32,11 @@ import org.apache.hadoop.fs.FileSystem;
|
||||
public class XGBoost {
|
||||
private static final Log logger = LogFactory.getLog(XGBoost.class);
|
||||
|
||||
public static final String[] MAXIMIZ_METRICES = {
|
||||
"auc", "aucpr", "pre", "pre@", "map", "ndcg",
|
||||
"auc@", "aucpr@", "map@", "ndcg@",
|
||||
};
|
||||
|
||||
/**
|
||||
* load model from modelPath
|
||||
*
|
||||
@ -158,7 +165,7 @@ public class XGBoost {
|
||||
//collect eval matrixs
|
||||
String[] evalNames;
|
||||
DMatrix[] evalMats;
|
||||
float bestScore;
|
||||
float bestScore = 1;
|
||||
int bestIteration;
|
||||
List<String> names = new ArrayList<String>();
|
||||
List<DMatrix> mats = new ArrayList<DMatrix>();
|
||||
@ -175,11 +182,7 @@ public class XGBoost {
|
||||
|
||||
evalNames = names.toArray(new String[names.size()]);
|
||||
evalMats = mats.toArray(new DMatrix[mats.size()]);
|
||||
if (isMaximizeEvaluation(params)) {
|
||||
bestScore = -Float.MAX_VALUE;
|
||||
} else {
|
||||
bestScore = Float.MAX_VALUE;
|
||||
}
|
||||
|
||||
bestIteration = 0;
|
||||
metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics;
|
||||
|
||||
@ -210,6 +213,9 @@ public class XGBoost {
|
||||
checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
|
||||
}
|
||||
|
||||
boolean initial_best_score_flag = false;
|
||||
boolean max_direction = false;
|
||||
|
||||
// begin to train
|
||||
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
|
||||
if (booster.getVersion() % 2 == 0) {
|
||||
@ -231,6 +237,18 @@ public class XGBoost {
|
||||
} else {
|
||||
evalInfo = booster.evalSet(evalMats, evalNames, iter, metricsOut);
|
||||
}
|
||||
|
||||
if (!initial_best_score_flag) {
|
||||
if (isMaximizeEvaluation(evalInfo, evalNames, params)) {
|
||||
max_direction = true;
|
||||
bestScore = -Float.MAX_VALUE;
|
||||
} else {
|
||||
max_direction = false;
|
||||
bestScore = Float.MAX_VALUE;
|
||||
}
|
||||
initial_best_score_flag = true;
|
||||
}
|
||||
|
||||
for (int i = 0; i < metricsOut.length; i++) {
|
||||
metrics[i][iter] = metricsOut[i];
|
||||
}
|
||||
@ -238,7 +256,7 @@ public class XGBoost {
|
||||
// If there is more than one evaluation datasets, the last one would be used
|
||||
// to determinate early stop.
|
||||
float score = metricsOut[metricsOut.length - 1];
|
||||
if (isMaximizeEvaluation(params)) {
|
||||
if (max_direction) {
|
||||
// Update best score if the current score is better (no update when equal)
|
||||
if (score > bestScore) {
|
||||
bestScore = score;
|
||||
@ -264,9 +282,7 @@ public class XGBoost {
|
||||
break;
|
||||
}
|
||||
if (Communicator.getRank() == 0 && shouldPrint(params, iter)) {
|
||||
if (shouldPrint(params, iter)){
|
||||
Communicator.communicatorPrint(evalInfo + '\n');
|
||||
}
|
||||
Communicator.communicatorPrint(evalInfo + '\n');
|
||||
}
|
||||
}
|
||||
booster.saveRabitCheckpoint();
|
||||
@ -360,16 +376,50 @@ public class XGBoost {
|
||||
return iter - bestIteration >= earlyStoppingRounds;
|
||||
}
|
||||
|
||||
private static boolean isMaximizeEvaluation(Map<String, Object> params) {
|
||||
try {
|
||||
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
|
||||
assert(maximize != null);
|
||||
return Boolean.valueOf(maximize);
|
||||
} catch (Exception ex) {
|
||||
logger.error("maximize_evaluation_metrics has to be specified for enabling early stop," +
|
||||
" allowed value: true/false", ex);
|
||||
throw ex;
|
||||
private static String getMetricNameFromlog(String evalInfo, String[] evalNames) {
|
||||
String regexPattern = Pattern.quote(evalNames[0]) + "-(.*):";
|
||||
Pattern pattern = Pattern.compile(regexPattern);
|
||||
Matcher matcher = pattern.matcher(evalInfo);
|
||||
|
||||
String metricName = null;
|
||||
if (matcher.find()) {
|
||||
metricName = matcher.group(1);
|
||||
logger.debug("Got the metric name: " + metricName);
|
||||
}
|
||||
return metricName;
|
||||
}
|
||||
|
||||
// visiable for testing
|
||||
public static boolean isMaximizeEvaluation(String evalInfo,
|
||||
String[] evalNames,
|
||||
Map<String, Object> params) {
|
||||
|
||||
String metricName;
|
||||
|
||||
if (params.get("maximize_evaluation_metrics") != null) {
|
||||
// user has forced the direction no matter what is the metric name.
|
||||
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
|
||||
return Boolean.valueOf(maximize);
|
||||
}
|
||||
|
||||
if (params.get("eval_metric") != null) {
|
||||
// user has special metric name
|
||||
metricName = String.valueOf(params.get("eval_metric"));
|
||||
} else {
|
||||
// infer the metric name from log
|
||||
metricName = getMetricNameFromlog(evalInfo, evalNames);
|
||||
}
|
||||
|
||||
assert metricName != null;
|
||||
|
||||
if (!"mape".equals(metricName)) {
|
||||
for (String x : MAXIMIZ_METRICES) {
|
||||
if (metricName.startsWith(x)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -0,0 +1,121 @@
|
||||
/*
|
||||
Copyright (c) 2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
|
||||
public class XGBoostTest {
|
||||
|
||||
private String composeEvalInfo(String metric, String evalName) {
|
||||
return "[0]\t" + evalName + "-" + metric + ":" + "\ttest";
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIsMaximizeEvaluation() {
|
||||
String[] minimum_metrics = {"mape", "logloss", "error", "others"};
|
||||
String[] evalNames = {"set-abc"};
|
||||
|
||||
HashMap<String, Object> params = new HashMap<>();
|
||||
|
||||
// test1, infer the metric from faked log
|
||||
for (String x : XGBoost.MAXIMIZ_METRICES) {
|
||||
String evalInfo = composeEvalInfo(x, evalNames[0]);
|
||||
TestCase.assertTrue(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params));
|
||||
}
|
||||
|
||||
// test2, the direction for mape should be minimum
|
||||
String evalInfo = composeEvalInfo("mape", evalNames[0]);
|
||||
TestCase.assertFalse(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params));
|
||||
|
||||
// test3, force maximize_evaluation_metrics
|
||||
params.clear();
|
||||
params.put("maximize_evaluation_metrics", true);
|
||||
// auc should be max,
|
||||
evalInfo = composeEvalInfo("auc", evalNames[0]);
|
||||
TestCase.assertTrue(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params));
|
||||
|
||||
params.clear();
|
||||
params.put("maximize_evaluation_metrics", false);
|
||||
// auc should be min,
|
||||
evalInfo = composeEvalInfo("auc", evalNames[0]);
|
||||
TestCase.assertFalse(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params));
|
||||
|
||||
// test4, set the metric manually
|
||||
for (String x : XGBoost.MAXIMIZ_METRICES) {
|
||||
params.clear();
|
||||
params.put("eval_metric", x);
|
||||
evalInfo = composeEvalInfo(x, evalNames[0]);
|
||||
TestCase.assertTrue(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params));
|
||||
}
|
||||
|
||||
// test5, set the metric manually
|
||||
for (String x : minimum_metrics) {
|
||||
params.clear();
|
||||
params.put("eval_metric", x);
|
||||
evalInfo = composeEvalInfo(x, evalNames[0]);
|
||||
TestCase.assertFalse(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEarlyStop() throws XGBoostError {
|
||||
Random random = new Random(1);
|
||||
|
||||
java.util.ArrayList<Float> labelall = new java.util.ArrayList<Float>();
|
||||
int nrep = 3000;
|
||||
java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
|
||||
for (int i = 0; i < nrep; ++i) {
|
||||
LabeledPoint p = new LabeledPoint(
|
||||
i % 2, 4,
|
||||
new int[]{0, 1, 2, 3},
|
||||
new float[]{random.nextFloat(), random.nextFloat(), random.nextFloat(), random.nextFloat()});
|
||||
blist.add(p);
|
||||
labelall.add(p.label());
|
||||
}
|
||||
|
||||
DMatrix dmat = new DMatrix(blist.iterator(), null);
|
||||
|
||||
int round = 50;
|
||||
int earlyStop = 2;
|
||||
|
||||
HashMap<String, Object> mapParams = new HashMap<>();
|
||||
mapParams.put("eta", 0.1);
|
||||
mapParams.put("objective", "binary:logistic");
|
||||
mapParams.put("max_depth", 3);
|
||||
mapParams.put("eval_metric", "auc");
|
||||
mapParams.put("silent", 0);
|
||||
|
||||
HashMap<String, DMatrix> mapWatches = new HashMap<>();
|
||||
mapWatches.put("selTrain-*", dmat);
|
||||
|
||||
try {
|
||||
Booster booster = XGBoost.train(dmat, mapParams, round, mapWatches, null, null, null, earlyStop);
|
||||
Map<String, String> attrs = booster.getAttrs();
|
||||
TestCase.assertTrue(Integer.valueOf(attrs.get("best_iteration")) < round - 1);
|
||||
} catch (Exception e) {
|
||||
TestCase.assertFalse(false);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user