[jvm-packages] automatically set the max/min direction for best score (#9404)
This commit is contained in:
parent
7579905e18
commit
8f0efb4ab3
@ -23,7 +23,6 @@ import scala.util.Random
|
|||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
|
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
|
||||||
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
|
|
||||||
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
|
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
|
||||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||||
@ -55,9 +54,6 @@ object TrackerConf {
|
|||||||
def apply(): TrackerConf = TrackerConf(0L)
|
def apply(): TrackerConf = TrackerConf(0L)
|
||||||
}
|
}
|
||||||
|
|
||||||
private[scala] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int,
|
|
||||||
maximizeEvalMetrics: Boolean)
|
|
||||||
|
|
||||||
private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
|
private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
|
||||||
|
|
||||||
private[scala] case class XGBoostExecutionParams(
|
private[scala] case class XGBoostExecutionParams(
|
||||||
@ -71,7 +67,7 @@ private[scala] case class XGBoostExecutionParams(
|
|||||||
trackerConf: TrackerConf,
|
trackerConf: TrackerConf,
|
||||||
checkpointParam: Option[ExternalCheckpointParams],
|
checkpointParam: Option[ExternalCheckpointParams],
|
||||||
xgbInputParams: XGBoostExecutionInputParams,
|
xgbInputParams: XGBoostExecutionInputParams,
|
||||||
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
|
earlyStoppingRounds: Int,
|
||||||
cacheTrainingSet: Boolean,
|
cacheTrainingSet: Boolean,
|
||||||
device: Option[String],
|
device: Option[String],
|
||||||
isLocal: Boolean,
|
isLocal: Boolean,
|
||||||
@ -146,16 +142,9 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
val numEarlyStoppingRounds = overridedParams.getOrElse(
|
val numEarlyStoppingRounds = overridedParams.getOrElse(
|
||||||
"num_early_stopping_rounds", 0).asInstanceOf[Int]
|
"num_early_stopping_rounds", 0).asInstanceOf[Int]
|
||||||
overridedParams += "num_early_stopping_rounds" -> numEarlyStoppingRounds
|
overridedParams += "num_early_stopping_rounds" -> numEarlyStoppingRounds
|
||||||
if (numEarlyStoppingRounds > 0 &&
|
if (numEarlyStoppingRounds > 0 && overridedParams.getOrElse("custom_eval", null) != null) {
|
||||||
!overridedParams.contains("maximize_evaluation_metrics")) {
|
|
||||||
if (overridedParams.getOrElse("custom_eval", null) != null) {
|
|
||||||
throw new IllegalArgumentException("custom_eval does not support early stopping")
|
throw new IllegalArgumentException("custom_eval does not support early stopping")
|
||||||
}
|
}
|
||||||
val eval_metric = overridedParams("eval_metric").toString
|
|
||||||
val maximize = LearningTaskParams.evalMetricsToMaximize contains eval_metric
|
|
||||||
logger.info("parameter \"maximize_evaluation_metrics\" is set to " + maximize)
|
|
||||||
overridedParams += ("maximize_evaluation_metrics" -> maximize)
|
|
||||||
}
|
|
||||||
overridedParams
|
overridedParams
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -213,10 +202,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
|
|
||||||
val earlyStoppingRounds = overridedParams.getOrElse(
|
val earlyStoppingRounds = overridedParams.getOrElse(
|
||||||
"num_early_stopping_rounds", 0).asInstanceOf[Int]
|
"num_early_stopping_rounds", 0).asInstanceOf[Int]
|
||||||
val maximizeEvalMetrics = overridedParams.getOrElse(
|
|
||||||
"maximize_evaluation_metrics", true).asInstanceOf[Boolean]
|
|
||||||
val xgbExecEarlyStoppingParams = XGBoostExecutionEarlyStoppingParams(earlyStoppingRounds,
|
|
||||||
maximizeEvalMetrics)
|
|
||||||
|
|
||||||
val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
|
val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
|
||||||
.asInstanceOf[Boolean]
|
.asInstanceOf[Boolean]
|
||||||
@ -232,7 +217,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
missing, allowNonZeroForMissing, trackerConf,
|
missing, allowNonZeroForMissing, trackerConf,
|
||||||
checkpointParam,
|
checkpointParam,
|
||||||
inputParams,
|
inputParams,
|
||||||
xgbExecEarlyStoppingParams,
|
earlyStoppingRounds,
|
||||||
cacheTrainingSet,
|
cacheTrainingSet,
|
||||||
device,
|
device,
|
||||||
isLocal,
|
isLocal,
|
||||||
@ -319,7 +304,7 @@ object XGBoost extends Serializable {
|
|||||||
|
|
||||||
watches = buildWatchesAndCheck(buildWatches)
|
watches = buildWatchesAndCheck(buildWatches)
|
||||||
|
|
||||||
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
|
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingRounds
|
||||||
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
|
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
|
||||||
val externalCheckpointParams = xgbExecutionParam.checkpointParam
|
val externalCheckpointParams = xgbExecutionParam.checkpointParam
|
||||||
|
|
||||||
|
|||||||
@ -112,8 +112,4 @@ private[spark] object LearningTaskParams {
|
|||||||
|
|
||||||
val supportedObjectiveType = HashSet("regression", "classification")
|
val supportedObjectiveType = HashSet("regression", "classification")
|
||||||
|
|
||||||
val evalMetricsToMaximize = HashSet("auc", "aucpr", "ndcg", "map")
|
|
||||||
|
|
||||||
val evalMetricsToMinimize = HashSet("rmse", "rmsle", "mae", "mape", "logloss", "error", "merror",
|
|
||||||
"mlogloss", "gamma-deviance")
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -17,6 +17,8 @@ package ml.dmlc.xgboost4j.java;
|
|||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
import java.util.regex.Matcher;
|
||||||
|
import java.util.regex.Pattern;
|
||||||
|
|
||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
@ -30,6 +32,11 @@ import org.apache.hadoop.fs.FileSystem;
|
|||||||
public class XGBoost {
|
public class XGBoost {
|
||||||
private static final Log logger = LogFactory.getLog(XGBoost.class);
|
private static final Log logger = LogFactory.getLog(XGBoost.class);
|
||||||
|
|
||||||
|
public static final String[] MAXIMIZ_METRICES = {
|
||||||
|
"auc", "aucpr", "pre", "pre@", "map", "ndcg",
|
||||||
|
"auc@", "aucpr@", "map@", "ndcg@",
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* load model from modelPath
|
* load model from modelPath
|
||||||
*
|
*
|
||||||
@ -158,7 +165,7 @@ public class XGBoost {
|
|||||||
//collect eval matrixs
|
//collect eval matrixs
|
||||||
String[] evalNames;
|
String[] evalNames;
|
||||||
DMatrix[] evalMats;
|
DMatrix[] evalMats;
|
||||||
float bestScore;
|
float bestScore = 1;
|
||||||
int bestIteration;
|
int bestIteration;
|
||||||
List<String> names = new ArrayList<String>();
|
List<String> names = new ArrayList<String>();
|
||||||
List<DMatrix> mats = new ArrayList<DMatrix>();
|
List<DMatrix> mats = new ArrayList<DMatrix>();
|
||||||
@ -175,11 +182,7 @@ public class XGBoost {
|
|||||||
|
|
||||||
evalNames = names.toArray(new String[names.size()]);
|
evalNames = names.toArray(new String[names.size()]);
|
||||||
evalMats = mats.toArray(new DMatrix[mats.size()]);
|
evalMats = mats.toArray(new DMatrix[mats.size()]);
|
||||||
if (isMaximizeEvaluation(params)) {
|
|
||||||
bestScore = -Float.MAX_VALUE;
|
|
||||||
} else {
|
|
||||||
bestScore = Float.MAX_VALUE;
|
|
||||||
}
|
|
||||||
bestIteration = 0;
|
bestIteration = 0;
|
||||||
metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics;
|
metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics;
|
||||||
|
|
||||||
@ -210,6 +213,9 @@ public class XGBoost {
|
|||||||
checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
|
checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
boolean initial_best_score_flag = false;
|
||||||
|
boolean max_direction = false;
|
||||||
|
|
||||||
// begin to train
|
// begin to train
|
||||||
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
|
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
|
||||||
if (booster.getVersion() % 2 == 0) {
|
if (booster.getVersion() % 2 == 0) {
|
||||||
@ -231,6 +237,18 @@ public class XGBoost {
|
|||||||
} else {
|
} else {
|
||||||
evalInfo = booster.evalSet(evalMats, evalNames, iter, metricsOut);
|
evalInfo = booster.evalSet(evalMats, evalNames, iter, metricsOut);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!initial_best_score_flag) {
|
||||||
|
if (isMaximizeEvaluation(evalInfo, evalNames, params)) {
|
||||||
|
max_direction = true;
|
||||||
|
bestScore = -Float.MAX_VALUE;
|
||||||
|
} else {
|
||||||
|
max_direction = false;
|
||||||
|
bestScore = Float.MAX_VALUE;
|
||||||
|
}
|
||||||
|
initial_best_score_flag = true;
|
||||||
|
}
|
||||||
|
|
||||||
for (int i = 0; i < metricsOut.length; i++) {
|
for (int i = 0; i < metricsOut.length; i++) {
|
||||||
metrics[i][iter] = metricsOut[i];
|
metrics[i][iter] = metricsOut[i];
|
||||||
}
|
}
|
||||||
@ -238,7 +256,7 @@ public class XGBoost {
|
|||||||
// If there is more than one evaluation datasets, the last one would be used
|
// If there is more than one evaluation datasets, the last one would be used
|
||||||
// to determinate early stop.
|
// to determinate early stop.
|
||||||
float score = metricsOut[metricsOut.length - 1];
|
float score = metricsOut[metricsOut.length - 1];
|
||||||
if (isMaximizeEvaluation(params)) {
|
if (max_direction) {
|
||||||
// Update best score if the current score is better (no update when equal)
|
// Update best score if the current score is better (no update when equal)
|
||||||
if (score > bestScore) {
|
if (score > bestScore) {
|
||||||
bestScore = score;
|
bestScore = score;
|
||||||
@ -264,11 +282,9 @@ public class XGBoost {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (Communicator.getRank() == 0 && shouldPrint(params, iter)) {
|
if (Communicator.getRank() == 0 && shouldPrint(params, iter)) {
|
||||||
if (shouldPrint(params, iter)){
|
|
||||||
Communicator.communicatorPrint(evalInfo + '\n');
|
Communicator.communicatorPrint(evalInfo + '\n');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
booster.saveRabitCheckpoint();
|
booster.saveRabitCheckpoint();
|
||||||
}
|
}
|
||||||
return booster;
|
return booster;
|
||||||
@ -360,16 +376,50 @@ public class XGBoost {
|
|||||||
return iter - bestIteration >= earlyStoppingRounds;
|
return iter - bestIteration >= earlyStoppingRounds;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static boolean isMaximizeEvaluation(Map<String, Object> params) {
|
private static String getMetricNameFromlog(String evalInfo, String[] evalNames) {
|
||||||
try {
|
String regexPattern = Pattern.quote(evalNames[0]) + "-(.*):";
|
||||||
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
|
Pattern pattern = Pattern.compile(regexPattern);
|
||||||
assert(maximize != null);
|
Matcher matcher = pattern.matcher(evalInfo);
|
||||||
return Boolean.valueOf(maximize);
|
|
||||||
} catch (Exception ex) {
|
String metricName = null;
|
||||||
logger.error("maximize_evaluation_metrics has to be specified for enabling early stop," +
|
if (matcher.find()) {
|
||||||
" allowed value: true/false", ex);
|
metricName = matcher.group(1);
|
||||||
throw ex;
|
logger.debug("Got the metric name: " + metricName);
|
||||||
}
|
}
|
||||||
|
return metricName;
|
||||||
|
}
|
||||||
|
|
||||||
|
// visiable for testing
|
||||||
|
public static boolean isMaximizeEvaluation(String evalInfo,
|
||||||
|
String[] evalNames,
|
||||||
|
Map<String, Object> params) {
|
||||||
|
|
||||||
|
String metricName;
|
||||||
|
|
||||||
|
if (params.get("maximize_evaluation_metrics") != null) {
|
||||||
|
// user has forced the direction no matter what is the metric name.
|
||||||
|
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
|
||||||
|
return Boolean.valueOf(maximize);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.get("eval_metric") != null) {
|
||||||
|
// user has special metric name
|
||||||
|
metricName = String.valueOf(params.get("eval_metric"));
|
||||||
|
} else {
|
||||||
|
// infer the metric name from log
|
||||||
|
metricName = getMetricNameFromlog(evalInfo, evalNames);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert metricName != null;
|
||||||
|
|
||||||
|
if (!"mape".equals(metricName)) {
|
||||||
|
for (String x : MAXIMIZ_METRICES) {
|
||||||
|
if (metricName.startsWith(x)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -0,0 +1,121 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2023 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
|
import junit.framework.TestCase;
|
||||||
|
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
|
public class XGBoostTest {
|
||||||
|
|
||||||
|
private String composeEvalInfo(String metric, String evalName) {
|
||||||
|
return "[0]\t" + evalName + "-" + metric + ":" + "\ttest";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testIsMaximizeEvaluation() {
|
||||||
|
String[] minimum_metrics = {"mape", "logloss", "error", "others"};
|
||||||
|
String[] evalNames = {"set-abc"};
|
||||||
|
|
||||||
|
HashMap<String, Object> params = new HashMap<>();
|
||||||
|
|
||||||
|
// test1, infer the metric from faked log
|
||||||
|
for (String x : XGBoost.MAXIMIZ_METRICES) {
|
||||||
|
String evalInfo = composeEvalInfo(x, evalNames[0]);
|
||||||
|
TestCase.assertTrue(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params));
|
||||||
|
}
|
||||||
|
|
||||||
|
// test2, the direction for mape should be minimum
|
||||||
|
String evalInfo = composeEvalInfo("mape", evalNames[0]);
|
||||||
|
TestCase.assertFalse(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params));
|
||||||
|
|
||||||
|
// test3, force maximize_evaluation_metrics
|
||||||
|
params.clear();
|
||||||
|
params.put("maximize_evaluation_metrics", true);
|
||||||
|
// auc should be max,
|
||||||
|
evalInfo = composeEvalInfo("auc", evalNames[0]);
|
||||||
|
TestCase.assertTrue(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params));
|
||||||
|
|
||||||
|
params.clear();
|
||||||
|
params.put("maximize_evaluation_metrics", false);
|
||||||
|
// auc should be min,
|
||||||
|
evalInfo = composeEvalInfo("auc", evalNames[0]);
|
||||||
|
TestCase.assertFalse(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params));
|
||||||
|
|
||||||
|
// test4, set the metric manually
|
||||||
|
for (String x : XGBoost.MAXIMIZ_METRICES) {
|
||||||
|
params.clear();
|
||||||
|
params.put("eval_metric", x);
|
||||||
|
evalInfo = composeEvalInfo(x, evalNames[0]);
|
||||||
|
TestCase.assertTrue(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params));
|
||||||
|
}
|
||||||
|
|
||||||
|
// test5, set the metric manually
|
||||||
|
for (String x : minimum_metrics) {
|
||||||
|
params.clear();
|
||||||
|
params.put("eval_metric", x);
|
||||||
|
evalInfo = composeEvalInfo(x, evalNames[0]);
|
||||||
|
TestCase.assertFalse(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testEarlyStop() throws XGBoostError {
|
||||||
|
Random random = new Random(1);
|
||||||
|
|
||||||
|
java.util.ArrayList<Float> labelall = new java.util.ArrayList<Float>();
|
||||||
|
int nrep = 3000;
|
||||||
|
java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
|
||||||
|
for (int i = 0; i < nrep; ++i) {
|
||||||
|
LabeledPoint p = new LabeledPoint(
|
||||||
|
i % 2, 4,
|
||||||
|
new int[]{0, 1, 2, 3},
|
||||||
|
new float[]{random.nextFloat(), random.nextFloat(), random.nextFloat(), random.nextFloat()});
|
||||||
|
blist.add(p);
|
||||||
|
labelall.add(p.label());
|
||||||
|
}
|
||||||
|
|
||||||
|
DMatrix dmat = new DMatrix(blist.iterator(), null);
|
||||||
|
|
||||||
|
int round = 50;
|
||||||
|
int earlyStop = 2;
|
||||||
|
|
||||||
|
HashMap<String, Object> mapParams = new HashMap<>();
|
||||||
|
mapParams.put("eta", 0.1);
|
||||||
|
mapParams.put("objective", "binary:logistic");
|
||||||
|
mapParams.put("max_depth", 3);
|
||||||
|
mapParams.put("eval_metric", "auc");
|
||||||
|
mapParams.put("silent", 0);
|
||||||
|
|
||||||
|
HashMap<String, DMatrix> mapWatches = new HashMap<>();
|
||||||
|
mapWatches.put("selTrain-*", dmat);
|
||||||
|
|
||||||
|
try {
|
||||||
|
Booster booster = XGBoost.train(dmat, mapParams, round, mapWatches, null, null, null, earlyStop);
|
||||||
|
Map<String, String> attrs = booster.getAttrs();
|
||||||
|
TestCase.assertTrue(Integer.valueOf(attrs.get("best_iteration")) < round - 1);
|
||||||
|
} catch (Exception e) {
|
||||||
|
TestCase.assertFalse(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user