[jvm-packages] automatically set the max/min direction for best score (#9404)

This commit is contained in:
Bobby Wang
2023-07-27 11:09:55 +08:00
committed by GitHub
parent 7579905e18
commit 8f0efb4ab3
4 changed files with 194 additions and 42 deletions

View File

@@ -17,6 +17,8 @@ package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -30,6 +32,11 @@ import org.apache.hadoop.fs.FileSystem;
public class XGBoost {
private static final Log logger = LogFactory.getLog(XGBoost.class);
public static final String[] MAXIMIZ_METRICES = {
"auc", "aucpr", "pre", "pre@", "map", "ndcg",
"auc@", "aucpr@", "map@", "ndcg@",
};
/**
* load model from modelPath
*
@@ -158,7 +165,7 @@ public class XGBoost {
//collect eval matrixs
String[] evalNames;
DMatrix[] evalMats;
float bestScore;
float bestScore = 1;
int bestIteration;
List<String> names = new ArrayList<String>();
List<DMatrix> mats = new ArrayList<DMatrix>();
@@ -175,11 +182,7 @@ public class XGBoost {
evalNames = names.toArray(new String[names.size()]);
evalMats = mats.toArray(new DMatrix[mats.size()]);
if (isMaximizeEvaluation(params)) {
bestScore = -Float.MAX_VALUE;
} else {
bestScore = Float.MAX_VALUE;
}
bestIteration = 0;
metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics;
@@ -210,6 +213,9 @@ public class XGBoost {
checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
}
boolean initial_best_score_flag = false;
boolean max_direction = false;
// begin to train
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
if (booster.getVersion() % 2 == 0) {
@@ -231,6 +237,18 @@ public class XGBoost {
} else {
evalInfo = booster.evalSet(evalMats, evalNames, iter, metricsOut);
}
if (!initial_best_score_flag) {
if (isMaximizeEvaluation(evalInfo, evalNames, params)) {
max_direction = true;
bestScore = -Float.MAX_VALUE;
} else {
max_direction = false;
bestScore = Float.MAX_VALUE;
}
initial_best_score_flag = true;
}
for (int i = 0; i < metricsOut.length; i++) {
metrics[i][iter] = metricsOut[i];
}
@@ -238,7 +256,7 @@ public class XGBoost {
// If there is more than one evaluation datasets, the last one would be used
// to determinate early stop.
float score = metricsOut[metricsOut.length - 1];
if (isMaximizeEvaluation(params)) {
if (max_direction) {
// Update best score if the current score is better (no update when equal)
if (score > bestScore) {
bestScore = score;
@@ -264,9 +282,7 @@ public class XGBoost {
break;
}
if (Communicator.getRank() == 0 && shouldPrint(params, iter)) {
if (shouldPrint(params, iter)){
Communicator.communicatorPrint(evalInfo + '\n');
}
Communicator.communicatorPrint(evalInfo + '\n');
}
}
booster.saveRabitCheckpoint();
@@ -360,16 +376,50 @@ public class XGBoost {
return iter - bestIteration >= earlyStoppingRounds;
}
private static boolean isMaximizeEvaluation(Map<String, Object> params) {
try {
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
assert(maximize != null);
return Boolean.valueOf(maximize);
} catch (Exception ex) {
logger.error("maximize_evaluation_metrics has to be specified for enabling early stop," +
" allowed value: true/false", ex);
throw ex;
private static String getMetricNameFromlog(String evalInfo, String[] evalNames) {
String regexPattern = Pattern.quote(evalNames[0]) + "-(.*):";
Pattern pattern = Pattern.compile(regexPattern);
Matcher matcher = pattern.matcher(evalInfo);
String metricName = null;
if (matcher.find()) {
metricName = matcher.group(1);
logger.debug("Got the metric name: " + metricName);
}
return metricName;
}
// visiable for testing
public static boolean isMaximizeEvaluation(String evalInfo,
String[] evalNames,
Map<String, Object> params) {
String metricName;
if (params.get("maximize_evaluation_metrics") != null) {
// user has forced the direction no matter what is the metric name.
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
return Boolean.valueOf(maximize);
}
if (params.get("eval_metric") != null) {
// user has special metric name
metricName = String.valueOf(params.get("eval_metric"));
} else {
// infer the metric name from log
metricName = getMetricNameFromlog(evalInfo, evalNames);
}
assert metricName != null;
if (!"mape".equals(metricName)) {
for (String x : MAXIMIZ_METRICES) {
if (metricName.startsWith(x)) {
return true;
}
}
}
return false;
}
/**