[jvm-packages] automatically set the max/min direction for best score (#9404)
This commit is contained in:
@@ -17,6 +17,8 @@ package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.*;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
@@ -30,6 +32,11 @@ import org.apache.hadoop.fs.FileSystem;
|
||||
public class XGBoost {
|
||||
private static final Log logger = LogFactory.getLog(XGBoost.class);
|
||||
|
||||
public static final String[] MAXIMIZ_METRICES = {
|
||||
"auc", "aucpr", "pre", "pre@", "map", "ndcg",
|
||||
"auc@", "aucpr@", "map@", "ndcg@",
|
||||
};
|
||||
|
||||
/**
|
||||
* load model from modelPath
|
||||
*
|
||||
@@ -158,7 +165,7 @@ public class XGBoost {
|
||||
//collect eval matrixs
|
||||
String[] evalNames;
|
||||
DMatrix[] evalMats;
|
||||
float bestScore;
|
||||
float bestScore = 1;
|
||||
int bestIteration;
|
||||
List<String> names = new ArrayList<String>();
|
||||
List<DMatrix> mats = new ArrayList<DMatrix>();
|
||||
@@ -175,11 +182,7 @@ public class XGBoost {
|
||||
|
||||
evalNames = names.toArray(new String[names.size()]);
|
||||
evalMats = mats.toArray(new DMatrix[mats.size()]);
|
||||
if (isMaximizeEvaluation(params)) {
|
||||
bestScore = -Float.MAX_VALUE;
|
||||
} else {
|
||||
bestScore = Float.MAX_VALUE;
|
||||
}
|
||||
|
||||
bestIteration = 0;
|
||||
metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics;
|
||||
|
||||
@@ -210,6 +213,9 @@ public class XGBoost {
|
||||
checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
|
||||
}
|
||||
|
||||
boolean initial_best_score_flag = false;
|
||||
boolean max_direction = false;
|
||||
|
||||
// begin to train
|
||||
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
|
||||
if (booster.getVersion() % 2 == 0) {
|
||||
@@ -231,6 +237,18 @@ public class XGBoost {
|
||||
} else {
|
||||
evalInfo = booster.evalSet(evalMats, evalNames, iter, metricsOut);
|
||||
}
|
||||
|
||||
if (!initial_best_score_flag) {
|
||||
if (isMaximizeEvaluation(evalInfo, evalNames, params)) {
|
||||
max_direction = true;
|
||||
bestScore = -Float.MAX_VALUE;
|
||||
} else {
|
||||
max_direction = false;
|
||||
bestScore = Float.MAX_VALUE;
|
||||
}
|
||||
initial_best_score_flag = true;
|
||||
}
|
||||
|
||||
for (int i = 0; i < metricsOut.length; i++) {
|
||||
metrics[i][iter] = metricsOut[i];
|
||||
}
|
||||
@@ -238,7 +256,7 @@ public class XGBoost {
|
||||
// If there is more than one evaluation datasets, the last one would be used
|
||||
// to determinate early stop.
|
||||
float score = metricsOut[metricsOut.length - 1];
|
||||
if (isMaximizeEvaluation(params)) {
|
||||
if (max_direction) {
|
||||
// Update best score if the current score is better (no update when equal)
|
||||
if (score > bestScore) {
|
||||
bestScore = score;
|
||||
@@ -264,9 +282,7 @@ public class XGBoost {
|
||||
break;
|
||||
}
|
||||
if (Communicator.getRank() == 0 && shouldPrint(params, iter)) {
|
||||
if (shouldPrint(params, iter)){
|
||||
Communicator.communicatorPrint(evalInfo + '\n');
|
||||
}
|
||||
Communicator.communicatorPrint(evalInfo + '\n');
|
||||
}
|
||||
}
|
||||
booster.saveRabitCheckpoint();
|
||||
@@ -360,16 +376,50 @@ public class XGBoost {
|
||||
return iter - bestIteration >= earlyStoppingRounds;
|
||||
}
|
||||
|
||||
private static boolean isMaximizeEvaluation(Map<String, Object> params) {
|
||||
try {
|
||||
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
|
||||
assert(maximize != null);
|
||||
return Boolean.valueOf(maximize);
|
||||
} catch (Exception ex) {
|
||||
logger.error("maximize_evaluation_metrics has to be specified for enabling early stop," +
|
||||
" allowed value: true/false", ex);
|
||||
throw ex;
|
||||
private static String getMetricNameFromlog(String evalInfo, String[] evalNames) {
|
||||
String regexPattern = Pattern.quote(evalNames[0]) + "-(.*):";
|
||||
Pattern pattern = Pattern.compile(regexPattern);
|
||||
Matcher matcher = pattern.matcher(evalInfo);
|
||||
|
||||
String metricName = null;
|
||||
if (matcher.find()) {
|
||||
metricName = matcher.group(1);
|
||||
logger.debug("Got the metric name: " + metricName);
|
||||
}
|
||||
return metricName;
|
||||
}
|
||||
|
||||
// visiable for testing
|
||||
public static boolean isMaximizeEvaluation(String evalInfo,
|
||||
String[] evalNames,
|
||||
Map<String, Object> params) {
|
||||
|
||||
String metricName;
|
||||
|
||||
if (params.get("maximize_evaluation_metrics") != null) {
|
||||
// user has forced the direction no matter what is the metric name.
|
||||
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
|
||||
return Boolean.valueOf(maximize);
|
||||
}
|
||||
|
||||
if (params.get("eval_metric") != null) {
|
||||
// user has special metric name
|
||||
metricName = String.valueOf(params.get("eval_metric"));
|
||||
} else {
|
||||
// infer the metric name from log
|
||||
metricName = getMetricNameFromlog(evalInfo, evalNames);
|
||||
}
|
||||
|
||||
assert metricName != null;
|
||||
|
||||
if (!"mape".equals(metricName)) {
|
||||
for (String x : MAXIMIZ_METRICES) {
|
||||
if (metricName.startsWith(x)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user