[jvm-packages]support multiple validation datasets in Spark (#3910)
* add back train method but mark as deprecated
* fix scalastyle error
* wrap iterators
* enable copartition training and validationset
* add parameters
* converge code path and have init unit test
* enable multi evals for ranking
* unit test and doc
* update example
* fix early stopping
* address the offline comments
* update doc
* test eval metrics
* fix compilation issue
* fix example
parent c8c7b9649c
commit c055a32609
@@ -200,6 +200,11 @@ In additional to ``num_early_stopping_rounds``, you also need to define ``maximi
After specifying these two parameters, training stops when the metric moves in the direction opposite to the one specified by ``maximize_evaluation_metrics`` for ``num_early_stopping_rounds`` consecutive iterations.
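A minimal sketch of how these two parameters might be set together (the early-stopping parameter names come from the text above; the remaining parameters, values, and column names are illustrative only):

.. code-block:: scala

  // assumes a DataFrame `trainingDF` with "features" and "classIndex" columns
  val xgbParam = Map(
    "eta" -> 0.1,
    "objective" -> "multi:softprob",
    "num_class" -> 3,
    "num_round" -> 100,
    "num_workers" -> 2,
    // stop once the metric has moved the wrong way for 5 consecutive rounds
    "num_early_stopping_rounds" -> 5,
    // true only for metrics that should grow (e.g. auc); false for error-style metrics
    "maximize_evaluation_metrics" -> true)
  val model = new XGBoostClassifier(xgbParam)
    .setFeaturesCol("features")
    .setLabelCol("classIndex")
    .fit(trainingDF)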

Training with Evaluation Sets
-----------------------------

You can also monitor the performance of the model during training with multiple evaluation datasets. By specifying ``eval_sets`` or calling ``setEvalSets`` on an XGBoostClassifier or XGBoostRegressor, you can pass in multiple evaluation datasets typed as a Map from String to DataFrame.

Prediction
==========
@@ -40,7 +40,7 @@ object SparkTraining {
StructField("petal length", DoubleType, true),
StructField("petal width", DoubleType, true),
StructField("class", StringType, true)))
val rawInput = spark.read.schema(schema).csv(args(0))
val rawInput = spark.read.schema(schema).csv(inputPath)

// transform class to index to make xgboost happy
val stringIndexer = new StringIndexer()
@@ -55,6 +55,8 @@ object SparkTraining {
val xgbInput = vectorAssembler.transform(labelTransformed).select("features",
"classIndex")

val Array(train, eval1, eval2, test) = xgbInput.randomSplit(Array(0.6, 0.2, 0.1, 0.1))

/**
* setup "timeout_request_workers" -> 60000L to make this application if it cannot get enough resources
* to get 2 workers within 60000 ms
@@ -67,12 +69,13 @@ object SparkTraining {
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 100,
"num_workers" -> 2)
"num_workers" -> 2,
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
val xgbClassifier = new XGBoostClassifier(xgbParam).
setFeaturesCol("features").
setLabelCol("classIndex")
val xgbClassificationModel = xgbClassifier.fit(xgbInput)
val results = xgbClassificationModel.transform(xgbInput)
val xgbClassificationModel = xgbClassifier.fit(train)
val results = xgbClassificationModel.transform(test)
results.show()
}
}
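A hypothetical follow-up to the example above, showing how the per-dataset metrics collected during training could be read back from the model summary (``trainObjectiveHistory`` and ``validationObjectiveHistory`` are the fields of ``XGBoostTrainingSummary`` as changed later in this commit; everything else is illustrative):

.. code-block:: scala

  // trainObjectiveHistory: Array[Float], one entry per boosting round
  println(xgbClassificationModel.summary.trainObjectiveHistory.mkString(","))

  // validationObjectiveHistory: one (name, history) pair per eval set,
  // e.g. ("eval1", Array(...)) and ("eval2", Array(...)) for the example above
  xgbClassificationModel.summary.validationObjectiveHistory.foreach {
    case (name, history) => println(s"$name: ${history.mkString(",")}")
  }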
@@ -20,6 +20,11 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.Param
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{FloatType, IntegerType}

object DataUtils extends Serializable {
private[spark] implicit class XGBLabeledPointFeatures(
@@ -67,4 +72,38 @@ object DataUtils extends Serializable {
XGBLabeledPoint(0.0f, v.indices, v.values.map(_.toFloat))
}
}

private[spark] def convertDataFrameToXGBLabeledPointRDDs(
labelCol: Column,
featuresCol: Column,
weight: Column,
baseMargin: Column,
group: Option[Column],
dataFrames: DataFrame*): Array[RDD[XGBLabeledPoint]] = {
val selectedColumns = group.map(groupCol => Seq(labelCol.cast(FloatType),
featuresCol,
weight.cast(FloatType),
groupCol.cast(IntegerType),
baseMargin.cast(FloatType))).getOrElse(Seq(labelCol.cast(FloatType),
featuresCol,
weight.cast(FloatType),
baseMargin.cast(FloatType)))
dataFrames.toArray.map {
df => df.select(selectedColumns: _*).rdd.map {
case Row(label: Float, features: Vector, weight: Float, group: Int, baseMargin: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
}
XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
case Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
}
XGBLabeledPoint(label, indices, values, weight, baseMargin = baseMargin)
}
}
}

}
@ -19,6 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
import java.io.File
|
||||
import java.nio.file.Files
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.collection.{AbstractIterator, mutable}
|
||||
import scala.util.Random
|
||||
|
||||
@ -31,7 +32,7 @@ import org.apache.commons.logging.LogFactory
|
||||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql.{DataFrame, SparkSession}
|
||||
|
||||
|
||||
/**
|
||||
@ -114,13 +115,12 @@ object XGBoost extends Serializable {
|
||||
round: Int,
|
||||
obj: ObjectiveTrait,
|
||||
eval: EvalTrait,
|
||||
prevBooster: Booster)
|
||||
: Iterator[(Booster, Map[String, Array[Float]])] = {
|
||||
prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
|
||||
|
||||
// to workaround the empty partitions in training dataset,
|
||||
// this might not be the best efficient implementation, see
|
||||
// (https://github.com/dmlc/xgboost/issues/1277)
|
||||
if (watches.train.rowNum == 0) {
|
||||
if (watches.toMap("train").rowNum == 0) {
|
||||
throw new XGBoostError(
|
||||
s"detected an empty partition in the training data, partition ID:" +
|
||||
s" ${TaskContext.getPartitionId()}")
|
||||
@ -138,7 +138,7 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
}
|
||||
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
|
||||
val booster = SXGBoost.train(watches.train, params, round,
|
||||
val booster = SXGBoost.train(watches.toMap("train"), params, round,
|
||||
watches.toMap, metrics, obj, eval,
|
||||
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
|
||||
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
|
||||
@ -175,6 +175,53 @@ object XGBoost extends Serializable {
|
||||
tracker
|
||||
}
|
||||
|
||||
class IteratorWrapper[T](arrayOfXGBLabeledPoints: Array[(String, Iterator[T])])
|
||||
extends Iterator[(String, Iterator[T])] {
|
||||
|
||||
private var currentIndex = 0
|
||||
|
||||
override def hasNext: Boolean = currentIndex <= arrayOfXGBLabeledPoints.length - 1
|
||||
|
||||
override def next(): (String, Iterator[T]) = {
|
||||
currentIndex += 1
|
||||
arrayOfXGBLabeledPoints(currentIndex - 1)
|
||||
}
|
||||
}
|
||||
|
||||
private def coPartitionNoGroupSets(
|
||||
trainingData: RDD[XGBLabeledPoint],
|
||||
evalSets: Map[String, RDD[XGBLabeledPoint]],
|
||||
params: Map[String, Any]) = {
|
||||
val nWorkers = params("num_workers").asInstanceOf[Int]
|
||||
// eval_sets is supposed to be set by the caller of [[trainDistributed]]
|
||||
val allDatasets = Map("train" -> trainingData) ++ evalSets
|
||||
val repartitionedDatasets = allDatasets.map{case (name, rdd) =>
|
||||
if (rdd.getNumPartitions != nWorkers) {
|
||||
(name, rdd.repartition(nWorkers))
|
||||
} else {
|
||||
(name, rdd)
|
||||
}
|
||||
}
|
||||
repartitionedDatasets.foldLeft(trainingData.sparkContext.parallelize(
|
||||
Array.fill[(String, Iterator[XGBLabeledPoint])](nWorkers)(null), nWorkers)){
|
||||
case (rddOfIterWrapper, (name, rddOfIter)) =>
|
||||
rddOfIterWrapper.zipPartitions(rddOfIter){
|
||||
(itrWrapper, itr) =>
|
||||
if (!itr.hasNext) {
|
||||
logger.error("when specifying eval sets as dataframes, you have to ensure that " +
|
||||
"the number of elements in each dataframe is larger than the number of workers")
|
||||
throw new Exception("too few elements in evaluation sets")
|
||||
}
|
||||
val itrArray = itrWrapper.toArray
|
||||
if (itrArray.head != null) {
|
||||
new IteratorWrapper(itrArray :+ (name -> itr))
|
||||
} else {
|
||||
new IteratorWrapper(Array(name -> itr))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
|
||||
* If so, throw an exception unless this safety measure has been explicitly overridden
|
||||
@ -215,10 +262,16 @@ object XGBoost extends Serializable {
|
||||
val eval = params.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
|
||||
val missing = params.getOrElse("missing", Float.NaN).asInstanceOf[Float]
|
||||
validateSparkSslConf(sparkContext)
|
||||
|
||||
if (params.contains("tree_method")) {
|
||||
require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
|
||||
" for now")
|
||||
}
|
||||
if (params.contains("train_test_ratio")) {
|
||||
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
|
||||
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
|
||||
"'eval_set_names'")
|
||||
}
|
||||
require(nWorkers > 0, "you must specify more than 0 workers")
|
||||
if (obj != null) {
|
||||
require(params.get("objective_type").isDefined, "parameter \"objective_type\" is not" +
|
||||
@ -247,17 +300,30 @@ object XGBoost extends Serializable {
|
||||
params: Map[String, Any],
|
||||
rabitEnv: java.util.Map[String, String],
|
||||
checkpointRound: Int,
|
||||
prevBooster: Booster) = {
|
||||
prevBooster: Booster,
|
||||
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
|
||||
val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
|
||||
parameterFetchAndValidation(params, trainingData.sparkContext)
|
||||
val partitionedData = repartitionForTraining(trainingData, nWorkers)
|
||||
partitionedData.mapPartitions(labeledPoints => {
|
||||
val watches = Watches.buildWatches(params,
|
||||
removeMissingValues(labeledPoints, missing),
|
||||
getCacheDirName(useExternalMemory))
|
||||
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
|
||||
obj, eval, prevBooster)
|
||||
}).cache()
|
||||
if (evalSetsMap.isEmpty) {
|
||||
partitionedData.mapPartitions(labeledPoints => {
|
||||
val watches = Watches.buildWatches(params,
|
||||
removeMissingValues(labeledPoints, missing),
|
||||
getCacheDirName(useExternalMemory))
|
||||
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
|
||||
obj, eval, prevBooster)
|
||||
}).cache()
|
||||
} else {
|
||||
coPartitionNoGroupSets(partitionedData, evalSetsMap, params).mapPartitions {
|
||||
nameAndLabeledPointSets =>
|
||||
val watches = Watches.buildWatches(
|
||||
nameAndLabeledPointSets.map {
|
||||
case (name, iter) => (name, removeMissingValues(iter, missing))},
|
||||
getCacheDirName(useExternalMemory))
|
||||
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
|
||||
obj, eval, prevBooster)
|
||||
}.cache()
|
||||
}
|
||||
}
|
||||
|
||||
private def trainForRanking(
|
||||
@ -265,17 +331,30 @@ object XGBoost extends Serializable {
|
||||
params: Map[String, Any],
|
||||
rabitEnv: java.util.Map[String, String],
|
||||
checkpointRound: Int,
|
||||
prevBooster: Booster) = {
|
||||
prevBooster: Booster,
|
||||
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
|
||||
val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
|
||||
parameterFetchAndValidation(params, trainingData.sparkContext)
|
||||
val partitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
|
||||
partitionedData.mapPartitions(labeledPointGroups => {
|
||||
val watches = Watches.buildWatchesWithGroup(params,
|
||||
removeMissingValuesWithGroup(labeledPointGroups, missing),
|
||||
getCacheDirName(useExternalMemory))
|
||||
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
|
||||
obj, eval, prevBooster)
|
||||
}).cache()
|
||||
val partitionedTrainingSet = repartitionForTrainingGroup(trainingData, nWorkers)
|
||||
if (evalSetsMap.isEmpty) {
|
||||
partitionedTrainingSet.mapPartitions(labeledPointGroups => {
|
||||
val watches = Watches.buildWatchesWithGroup(params,
|
||||
removeMissingValuesWithGroup(labeledPointGroups, missing),
|
||||
getCacheDirName(useExternalMemory))
|
||||
buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
|
||||
}).cache()
|
||||
} else {
|
||||
coPartitionGroupSets(partitionedTrainingSet, evalSetsMap, params).mapPartitions(
|
||||
labeledPointGroupSets => {
|
||||
val watches = Watches.buildWatchesWithGroup(
|
||||
labeledPointGroupSets.map {
|
||||
case (name, iter) => (name, removeMissingValuesWithGroup(iter, missing))
|
||||
},
|
||||
getCacheDirName(useExternalMemory))
|
||||
buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval,
|
||||
prevBooster)
|
||||
}).cache()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@ -285,7 +364,9 @@ object XGBoost extends Serializable {
|
||||
private[spark] def trainDistributed(
|
||||
trainingData: RDD[XGBLabeledPoint],
|
||||
params: Map[String, Any],
|
||||
hasGroup: Boolean = false): (Booster, Map[String, Array[Float]]) = {
|
||||
hasGroup: Boolean = false,
|
||||
evalSetsMap: Map[String, RDD[XGBLabeledPoint]] = Map()):
|
||||
(Booster, Map[String, Array[Float]]) = {
|
||||
val (nWorkers, round, _, _, _, _, trackerConf, timeoutRequestWorkers,
|
||||
checkpointPath, checkpointInterval) = parameterFetchAndValidation(params,
|
||||
trainingData.sparkContext)
|
||||
@ -303,10 +384,10 @@ object XGBoost extends Serializable {
|
||||
val rabitEnv = tracker.getWorkerEnvs
|
||||
val boostersAndMetrics = if (hasGroup) {
|
||||
trainForRanking(trainingData, overriddenParams, rabitEnv, checkpointRound,
|
||||
prevBooster)
|
||||
prevBooster, evalSetsMap)
|
||||
} else {
|
||||
trainForNonRanking(trainingData, overriddenParams, rabitEnv, checkpointRound,
|
||||
prevBooster)
|
||||
prevBooster, evalSetsMap)
|
||||
}
|
||||
val sparkJobThread = new Thread() {
|
||||
override def run() {
|
||||
@ -340,8 +421,7 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] def repartitionForTrainingGroup(
|
||||
trainingData: RDD[XGBLabeledPoint], nWorkers: Int): RDD[Array[XGBLabeledPoint]] = {
|
||||
private def aggByGroupInfo(trainingData: RDD[XGBLabeledPoint]) = {
|
||||
val normalGroups: RDD[Array[XGBLabeledPoint]] = trainingData.mapPartitions(
|
||||
// LabeledPointGroupIterator returns (Boolean, Array[XGBLabeledPoint])
|
||||
new LabeledPointGroupIterator(_)).filter(!_.isEdgeGroup).map(_.points)
|
||||
@ -349,22 +429,61 @@ object XGBoost extends Serializable {
|
||||
// edge groups with partition id.
|
||||
val edgeGroups: RDD[(Int, XGBLabeledPointGroup)] = trainingData.mapPartitions(
|
||||
new LabeledPointGroupIterator(_)).filter(_.isEdgeGroup).map(
|
||||
group => (TaskContext.getPartitionId(), group))
|
||||
group => (TaskContext.getPartitionId(), group))
|
||||
|
||||
// group chunks from different partitions together by group id in XGBLabeledPoint.
|
||||
// use groupBy instead of aggregateBy since all groups within a partition have unique groud ids.
|
||||
// use groupBy instead of aggregateBy since all groups within a partition have unique group ids.
|
||||
val stitchedGroups: RDD[Array[XGBLabeledPoint]] = edgeGroups.groupBy(_._2.groupId).map(
|
||||
groups => {
|
||||
val it: Iterable[(Int, XGBLabeledPointGroup)] = groups._2
|
||||
// sorted by partition id and merge list of Array[XGBLabeledPoint] into one array
|
||||
it.toArray.sortBy(_._1).map(_._2.points).flatten
|
||||
it.toArray.sortBy(_._1).flatMap(_._2.points)
|
||||
})
|
||||
normalGroups.union(stitchedGroups)
|
||||
}
|
||||
|
||||
var allGroups = normalGroups.union(stitchedGroups)
|
||||
private[spark] def repartitionForTrainingGroup(
|
||||
trainingData: RDD[XGBLabeledPoint], nWorkers: Int): RDD[Array[XGBLabeledPoint]] = {
|
||||
val allGroups = aggByGroupInfo(trainingData)
|
||||
logger.info(s"repartitioning training group set to $nWorkers partitions")
|
||||
allGroups.repartition(nWorkers)
|
||||
}
|
||||
|
||||
private def coPartitionGroupSets(
|
||||
aggedTrainingSet: RDD[Array[XGBLabeledPoint]],
|
||||
evalSets: Map[String, RDD[XGBLabeledPoint]],
|
||||
params: Map[String, Any]): RDD[(String, Iterator[Array[XGBLabeledPoint]])] = {
|
||||
val nWorkers = params("num_workers").asInstanceOf[Int]
|
||||
val repartitionedDatasets = Map("train" -> aggedTrainingSet) ++ evalSets.map {
|
||||
case (name, rdd) => {
|
||||
val aggedRdd = aggByGroupInfo(rdd)
|
||||
if (aggedRdd.getNumPartitions != nWorkers) {
|
||||
name -> aggedRdd.repartition(nWorkers)
|
||||
} else {
|
||||
name -> aggedRdd
|
||||
}
|
||||
}
|
||||
}
|
||||
repartitionedDatasets.foldLeft(aggedTrainingSet.sparkContext.parallelize(
|
||||
Array.fill[(String, Iterator[Array[XGBLabeledPoint]])](nWorkers)(null), nWorkers)){
|
||||
case (rddOfIterWrapper, (name, rddOfIter)) =>
|
||||
rddOfIterWrapper.zipPartitions(rddOfIter){
|
||||
(itrWrapper, itr) =>
|
||||
if (!itr.hasNext) {
|
||||
logger.error("when specifying eval sets as dataframes, you have to ensure that " +
|
||||
"the number of elements in each dataframe is larger than the number of workers")
|
||||
throw new Exception("too few elements in evaluation sets")
|
||||
}
|
||||
val itrArray = itrWrapper.toArray
|
||||
if (itrArray.head != null) {
|
||||
new IteratorWrapper(itrArray :+ (name -> itr))
|
||||
} else {
|
||||
new IteratorWrapper(Array(name -> itr))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def postTrackerReturnProcessing(
|
||||
trackerReturnVal: Int,
|
||||
distributedBoostersAndMetrics: RDD[(Booster, Map[String, Array[Float]])],
|
||||
@ -395,12 +514,13 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
|
||||
private class Watches private(
|
||||
val train: DMatrix,
|
||||
val test: DMatrix,
|
||||
private val cacheDirName: Option[String]) {
|
||||
val datasets: Array[DMatrix],
|
||||
val names: Array[String],
|
||||
val cacheDirName: Option[String]) {
|
||||
|
||||
def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
|
||||
.filter { case (_, matrix) => matrix.rowNum > 0 }
|
||||
def toMap: Map[String, DMatrix] = {
|
||||
names.zip(datasets).toMap.filter { case (_, matrix) => matrix.rowNum > 0 }
|
||||
}
|
||||
|
||||
def size: Int = toMap.size
|
||||
|
||||
@ -440,6 +560,26 @@ private object Watches {
|
||||
}
|
||||
}
|
||||
|
||||
def buildWatches(
|
||||
nameAndLabeledPointSets: Iterator[(String, Iterator[XGBLabeledPoint])],
|
||||
cachedDirName: Option[String]): Watches = {
|
||||
val dms = nameAndLabeledPointSets.map {
|
||||
case (name, labeledPoints) =>
|
||||
val baseMargins = new mutable.ArrayBuilder.ofFloat
|
||||
val duplicatedItr = labeledPoints.map(labeledPoint => {
|
||||
baseMargins += labeledPoint.baseMargin
|
||||
labeledPoint
|
||||
})
|
||||
val dMatrix = new DMatrix(duplicatedItr, cachedDirName.map(_ + s"/$name").orNull)
|
||||
val baseMargin = fromBaseMarginsToArray(baseMargins.result().iterator)
|
||||
if (baseMargin.isDefined) {
|
||||
dMatrix.setBaseMargin(baseMargin.get)
|
||||
}
|
||||
(name, dMatrix)
|
||||
}.toArray
|
||||
new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
|
||||
}
|
||||
|
||||
def buildWatches(
|
||||
params: Map[String, Any],
|
||||
labeledPoints: Iterator[XGBLabeledPoint],
|
||||
@ -468,7 +608,30 @@ private object Watches {
|
||||
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
|
||||
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
|
||||
|
||||
new Watches(trainMatrix, testMatrix, cacheDirName)
|
||||
new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
|
||||
}
|
||||
|
||||
def buildWatchesWithGroup(
|
||||
nameAndlabeledPointGroupSets: Iterator[(String, Iterator[Array[XGBLabeledPoint]])],
|
||||
cachedDirName: Option[String]): Watches = {
|
||||
val dms = nameAndlabeledPointGroupSets.map {
|
||||
case (name, labeledPointsGroups) =>
|
||||
val baseMargins = new mutable.ArrayBuilder.ofFloat
|
||||
val duplicatedItr = labeledPointsGroups.map(labeledPoints => {
|
||||
labeledPoints.map { labeledPoint =>
|
||||
baseMargins += labeledPoint.baseMargin
|
||||
labeledPoint
|
||||
}
|
||||
})
|
||||
val dMatrix = new DMatrix(duplicatedItr.flatMap(_.iterator),
|
||||
cachedDirName.map(_ + s"/$name").orNull)
|
||||
val baseMargin = fromBaseMarginsToArray(baseMargins.result().iterator)
|
||||
if (baseMargin.isDefined) {
|
||||
dMatrix.setBaseMargin(baseMargin.get)
|
||||
}
|
||||
(name, dMatrix)
|
||||
}.toArray
|
||||
new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
|
||||
}
|
||||
|
||||
def buildWatchesWithGroup(
|
||||
@ -513,7 +676,7 @@ private object Watches {
|
||||
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
|
||||
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
|
||||
|
||||
new Watches(trainMatrix, testMatrix, cacheDirName)
|
||||
new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
|
||||
}
|
||||
}
|
||||
|
||||
@ -532,7 +695,7 @@ private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
|
||||
private var isNewGroup = false
|
||||
|
||||
override def hasNext: Boolean = {
|
||||
return base.hasNext || isNewGroup
|
||||
base.hasNext || isNewGroup
|
||||
}
|
||||
|
||||
override def next(): XGBLabeledPointGroup = {
|
||||
|
||||
@ -43,7 +43,7 @@ import org.apache.spark.broadcast.Broadcast
|
||||
|
||||
private[spark] trait XGBoostClassifierParams extends GeneralParams with LearningTaskParams
|
||||
with BoosterParams with HasWeightCol with HasBaseMarginCol with HasNumClass with ParamMapFuncs
|
||||
with HasLeafPredictionCol with HasContribPredictionCol
|
||||
with HasLeafPredictionCol with HasContribPredictionCol with NonParamVariables
|
||||
|
||||
class XGBoostClassifier (
|
||||
override val uid: String,
|
||||
@ -182,23 +182,19 @@ class XGBoostClassifier (
|
||||
col($(baseMarginCol))
|
||||
}
|
||||
|
||||
val instances: RDD[XGBLabeledPoint] = dataset.select(
|
||||
col($(featuresCol)),
|
||||
col($(labelCol)).cast(FloatType),
|
||||
baseMargin.cast(FloatType),
|
||||
weight.cast(FloatType)
|
||||
).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>
|
||||
val (indices, values) = features match {
|
||||
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
|
||||
case v: DenseVector => (null, v.values.map(_.toFloat))
|
||||
}
|
||||
XGBLabeledPoint(label, indices, values, baseMargin = baseMargin, weight = weight)
|
||||
val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
|
||||
col($(labelCol)), col($(featuresCol)), weight, baseMargin,
|
||||
None, dataset.asInstanceOf[DataFrame]).head
|
||||
val evalRDDMap = getEvalSets(xgboostParams).map {
|
||||
case (name, dataFrame) => (name,
|
||||
DataUtils.convertDataFrameToXGBLabeledPointRDDs(col($(labelCol)), col($(featuresCol)),
|
||||
weight, baseMargin, None, dataFrame).head)
|
||||
}
|
||||
transformSchema(dataset.schema, logging = true)
|
||||
val derivedXGBParamMap = MLlib2XGBoostParams
|
||||
// All non-null param maps in XGBoostClassifier are in derivedXGBParamMap.
|
||||
val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
|
||||
hasGroup = false)
|
||||
val (_booster, _metrics) = XGBoost.trainDistributed(trainingSet, derivedXGBParamMap,
|
||||
hasGroup = false, evalRDDMap)
|
||||
val model = new XGBoostClassificationModel(uid, _numClasses, _booster)
|
||||
val summary = XGBoostTrainingSummary(_metrics)
|
||||
model.setSummary(summary)
|
||||
|
||||
@ -43,7 +43,7 @@ import org.apache.spark.broadcast.Broadcast
|
||||
|
||||
private[spark] trait XGBoostRegressorParams extends GeneralParams with BoosterParams
|
||||
with LearningTaskParams with HasBaseMarginCol with HasWeightCol with HasGroupCol
|
||||
with ParamMapFuncs with HasLeafPredictionCol with HasContribPredictionCol
|
||||
with ParamMapFuncs with HasLeafPredictionCol with HasContribPredictionCol with NonParamVariables
|
||||
|
||||
class XGBoostRegressor (
|
||||
override val uid: String,
|
||||
@ -174,26 +174,19 @@ class XGBoostRegressor (
|
||||
col($(baseMarginCol))
|
||||
}
|
||||
val group = if (!isDefined(groupCol) || $(groupCol).isEmpty) lit(-1) else col($(groupCol))
|
||||
|
||||
val instances: RDD[XGBLabeledPoint] = dataset.select(
|
||||
col($(labelCol)).cast(FloatType),
|
||||
col($(featuresCol)),
|
||||
weight.cast(FloatType),
|
||||
group.cast(IntegerType),
|
||||
baseMargin.cast(FloatType)
|
||||
).rdd.map {
|
||||
case Row(label: Float, features: Vector, weight: Float, group: Int, baseMargin: Float) =>
|
||||
val (indices, values) = features match {
|
||||
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
|
||||
case v: DenseVector => (null, v.values.map(_.toFloat))
|
||||
}
|
||||
XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
|
||||
val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
|
||||
col($(labelCol)), col($(featuresCol)), weight, baseMargin, Some(group),
|
||||
dataset.asInstanceOf[DataFrame]).head
|
||||
val evalRDDMap = getEvalSets(xgboostParams).map {
|
||||
case (name, dataFrame) => (name,
|
||||
DataUtils.convertDataFrameToXGBLabeledPointRDDs(col($(labelCol)), col($(featuresCol)),
|
||||
weight, baseMargin, Some(group), dataFrame).head)
|
||||
}
|
||||
transformSchema(dataset.schema, logging = true)
|
||||
val derivedXGBParamMap = MLlib2XGBoostParams
|
||||
// All non-null param maps in XGBoostRegressor are in derivedXGBParamMap.
|
||||
val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
|
||||
hasGroup = group != lit(-1))
|
||||
val (_booster, _metrics) = XGBoost.trainDistributed(trainingSet, derivedXGBParamMap,
|
||||
hasGroup = group != lit(-1), evalRDDMap)
|
||||
val model = new XGBoostRegressionModel(uid, _booster)
|
||||
val summary = XGBoostTrainingSummary(_metrics)
|
||||
model.setSummary(summary)
|
||||
|
||||
@@ -18,12 +18,17 @@ package ml.dmlc.xgboost4j.scala.spark

class XGBoostTrainingSummary private(
val trainObjectiveHistory: Array[Float],
val testObjectiveHistory: Option[Array[Float]]
) extends Serializable {
val validationObjectiveHistory: (String, Array[Float])*) extends Serializable {

override def toString: String = {
val train = trainObjectiveHistory.toList
val test = testObjectiveHistory.map(_.toList)
s"XGBoostTrainingSummary(trainObjectiveHistory=$train, testObjectiveHistory=$test)"
val train = trainObjectiveHistory.mkString(",")
val vaidationObjectiveHistoryString = {
validationObjectiveHistory.map {
case (name, metrics) =>
s"${name}ObjectiveHistory=${metrics.mkString(",")}"
}.mkString(";")
}
s"XGBoostTrainingSummary(trainObjectiveHistory=$train; $vaidationObjectiveHistoryString)"
}
}

@@ -31,6 +36,6 @@ private[xgboost4j] object XGBoostTrainingSummary {
def apply(metrics: Map[String, Array[Float]]): XGBoostTrainingSummary = {
new XGBoostTrainingSummary(
trainObjectiveHistory = metrics("train"),
testObjectiveHistory = metrics.get("test"))
metrics.filter(_._1 != "train").toSeq: _*)
}
}
@@ -22,26 +22,7 @@ import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
import org.json4s.jackson.JsonMethods.{compact, parse, render}

import org.apache.spark.ml.param.{Param, ParamPair, Params}

class GroupDataParam(
parent: Params,
name: String,
doc: String) extends Param[Seq[Seq[Int]]](parent, name, doc) {

/** Creates a param pair with the given value (for Java). */
override def w(value: Seq[Seq[Int]]): ParamPair[Seq[Seq[Int]]] = super.w(value)

override def jsonEncode(value: Seq[Seq[Int]]): String = {
import org.json4s.jackson.Serialization
implicit val formats = Serialization.formats(NoTypeHints)
compact(render(Extraction.decompose(value)))
}

override def jsonDecode(json: String): Seq[Seq[Int]] = {
implicit val formats = DefaultFormats
parse(json).extract[Seq[Seq[Int]]]
}
}
import org.apache.spark.sql.DataFrame

class CustomEvalParam(
parent: Params,

@@ -22,6 +22,14 @@ import ml.dmlc.xgboost4j.scala.spark.TrackerConf
import org.apache.spark.ml.param._
import scala.collection.mutable

import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{FloatType, IntegerType}

private[spark] trait GeneralParams extends Params {

/**
@@ -154,8 +162,7 @@ private[spark] trait GeneralParams extends Params {
useExternalMemory -> false, silent -> 0,
customObj -> null, customEval -> null, missing -> Float.NaN,
trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L,
checkpointPath -> "", checkpointInterval -> -1
)
checkpointPath -> "", checkpointInterval -> -1)
}

trait HasLeafPredictionCol extends Params {
@@ -0,0 +1,35 @@
/*
Copyright (c) 2014 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark.params

import org.apache.spark.sql.DataFrame

trait NonParamVariables {
protected var evalSetsMap: Map[String, DataFrame] = Map.empty

def setEvalSets(evalSets: Map[String, DataFrame]): Unit = {
evalSetsMap = evalSets
}

def getEvalSets(params: Map[String, Any]): Map[String, DataFrame] = {
if (params.contains("eval_sets")) {
params("eval_sets").asInstanceOf[Map[String, DataFrame]]
} else {
evalSetsMap
}
}
}
@ -137,7 +137,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
||||
assert(predictionDF.columns.contains("final_prediction") === false)
|
||||
|
||||
assert(model.summary.trainObjectiveHistory.length === 5)
|
||||
assert(model.summary.testObjectiveHistory.isEmpty)
|
||||
assert(model.summary.validationObjectiveHistory.isEmpty)
|
||||
}
|
||||
|
||||
test("XGBoost and Spark parameters synchronize correctly") {
|
||||
@ -191,31 +191,6 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
||||
assert(count != 0)
|
||||
}
|
||||
|
||||
test("training summary") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "nWorkers" -> numWorkers)
|
||||
|
||||
val trainingDF = buildDataFrame(Classification.train)
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
val model = xgb.fit(trainingDF)
|
||||
|
||||
assert(model.summary.trainObjectiveHistory.length === 5)
|
||||
assert(model.summary.testObjectiveHistory.isEmpty)
|
||||
}
|
||||
|
||||
test("train/test split") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
||||
val training = buildDataFrame(Classification.train)
|
||||
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
val model = xgb.fit(training)
|
||||
val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
|
||||
assert(testObjectiveHistory.length === 5)
|
||||
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
|
||||
}
|
||||
|
||||
test("test predictionLeaf") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
||||
|
||||
@ -277,4 +277,93 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
|
||||
assert(booster != null)
|
||||
}
|
||||
|
||||
test("training summary") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "nWorkers" -> numWorkers)
|
||||
|
||||
val trainingDF = buildDataFrame(Classification.train)
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
val model = xgb.fit(trainingDF)
|
||||
|
||||
assert(model.summary.trainObjectiveHistory.length === 5)
|
||||
assert(model.summary.validationObjectiveHistory.isEmpty)
|
||||
}
|
||||
|
||||
test("train/test split") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
||||
val training = buildDataFrame(Classification.train)
|
||||
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
val model = xgb.fit(training)
|
||||
assert(model.summary.validationObjectiveHistory.length === 1)
|
||||
assert(model.summary.validationObjectiveHistory(0)._1 === "test")
|
||||
assert(model.summary.validationObjectiveHistory(0)._2.length === 5)
|
||||
assert(model.summary.trainObjectiveHistory !== model.summary.validationObjectiveHistory(0))
|
||||
}
|
||||
|
||||
test("train with multiple validation datasets (non-ranking)") {
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val Array(train, eval1, eval2) = training.randomSplit(Array(0.6, 0.2, 0.2))
|
||||
val paramMap1 = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic",
|
||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
||||
|
||||
val xgb1 = new XGBoostClassifier(paramMap1)
|
||||
xgb1.setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
|
||||
val model1 = xgb1.fit(train)
|
||||
assert(model1.summary.validationObjectiveHistory.length === 2)
|
||||
assert(model1.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
|
||||
assert(model1.summary.validationObjectiveHistory(0)._2.length === 5)
|
||||
assert(model1.summary.validationObjectiveHistory(1)._2.length === 5)
|
||||
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(0))
|
||||
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(1))
|
||||
|
||||
val paramMap2 = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic",
|
||||
"num_round" -> 5, "num_workers" -> numWorkers,
|
||||
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
|
||||
val xgb2 = new XGBoostClassifier(paramMap2)
|
||||
val model2 = xgb2.fit(train)
|
||||
assert(model2.summary.validationObjectiveHistory.length === 2)
|
||||
assert(model2.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
|
||||
assert(model2.summary.validationObjectiveHistory(0)._2.length === 5)
|
||||
assert(model2.summary.validationObjectiveHistory(1)._2.length === 5)
|
||||
assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(0))
|
||||
assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(1))
|
||||
}
|
||||
|
||||
test("train with multiple validation datasets (ranking)") {
|
||||
val training = buildDataFrameWithGroup(Ranking.train, 5)
|
||||
val Array(train, eval1, eval2) = training.randomSplit(Array(0.6, 0.2, 0.2))
|
||||
val paramMap1 = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "rank:pairwise",
|
||||
"num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group")
|
||||
val xgb1 = new XGBoostRegressor(paramMap1)
|
||||
xgb1.setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
|
||||
val model1 = xgb1.fit(train)
|
||||
assert(model1 != null)
|
||||
assert(model1.summary.validationObjectiveHistory.length === 2)
|
||||
assert(model1.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
|
||||
assert(model1.summary.validationObjectiveHistory(0)._2.length === 5)
|
||||
assert(model1.summary.validationObjectiveHistory(1)._2.length === 5)
|
||||
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(0))
|
||||
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(1))
|
||||
|
||||
val paramMap2 = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "rank:pairwise",
|
||||
"num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group",
|
||||
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
|
||||
val xgb2 = new XGBoostRegressor(paramMap2)
|
||||
val model2 = xgb2.fit(train)
|
||||
assert(model2 != null)
|
||||
assert(model2.summary.validationObjectiveHistory.length === 2)
|
||||
assert(model2.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
|
||||
assert(model2.summary.validationObjectiveHistory(0)._2.length === 5)
|
||||
assert(model2.summary.validationObjectiveHistory(1)._2.length === 5)
|
||||
assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(0))
|
||||
assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(1))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,14 +222,19 @@ public class XGBoost {
if (iter < earlyStoppingRounds - 1) {
return true;
}
float[] criterion = metrics[metrics.length - 1];
for (int shift = 0; shift < earlyStoppingRounds - 1; shift++) {
// the initial value of onTrack is false and if the metrics in any of `earlyStoppingRounds`
// iterations goes to the expected direction, we should consider these `earlyStoppingRounds`
// as `onTrack`
onTrack |= maximizeEvaluationMetrics ?
criterion[iter - shift] >= criterion[iter - shift - 1] :
criterion[iter - shift] <= criterion[iter - shift - 1];
for (int metricsId = metrics.length == 1 ? 0 : 1; metricsId < metrics.length; metricsId++) {
float[] criterion = metrics[metricsId];
for (int shift = 0; shift < earlyStoppingRounds - 1; shift++) {
// the initial value of onTrack is false and if the metrics in any of `earlyStoppingRounds`
// iterations goes to the expected direction, we should consider these `earlyStoppingRounds`
// as `onTrack`
onTrack |= maximizeEvaluationMetrics ?
criterion[iter - shift] >= criterion[iter - shift - 1] :
criterion[iter - shift] <= criterion[iter - shift - 1];
}
if (!onTrack) {
return false;
}
}
return onTrack;
}
@ -185,6 +185,51 @@ public class BoosterImplTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEarlyStoppingForMultipleMetrics() {
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
put("maximize_evaluation_metrics", "true");
|
||||
}
|
||||
};
|
||||
int earlyStoppingRound = 3;
|
||||
int totalIterations = 5;
|
||||
int numOfMetrics = 3;
|
||||
float[][] metrics = new float[numOfMetrics][totalIterations];
|
||||
for (int i = 0; i < numOfMetrics; i++) {
|
||||
for (int j = 0; j < totalIterations; j++) {
|
||||
metrics[0][j] = j;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
|
||||
TestCase.assertTrue(onTrack);
|
||||
}
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = totalIterations - i;
|
||||
}
|
||||
// when we have multiple datasets, the training metrics is not considered
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
|
||||
TestCase.assertTrue(onTrack);
|
||||
}
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[1][i] = totalIterations - i;
|
||||
}
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
// if any metrics off, we need to stop
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
|
||||
if (i >= earlyStoppingRound - 1) {
|
||||
TestCase.assertFalse(onTrack);
|
||||
} else {
|
||||
TestCase.assertTrue(onTrack);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDescendMetrics() {
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
|
||||