[jvm-packages] support multiple validation datasets in Spark (#3910)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* wrap iterators

* enable copartition training and validationset

* add parameters

* converge code path and have init unit test

* enable multi evals for ranking

* unit test and doc

* update example

* fix early stopping

* address the offline comments

* update doc

* test eval metrics

* fix compilation issue

* fix example
Nan Zhu 2018-12-17 21:03:57 -08:00 committed by GitHub
parent c8c7b9649c
commit c055a32609
14 changed files with 477 additions and 136 deletions


@ -200,6 +200,11 @@ In additional to ``num_early_stopping_rounds``, you also need to define ``maximi
After specifying these two parameters, the training would stop when the metrics goes to the other direction against the one specified by ``maximize_evaluation_metrics`` for ``num_early_stopping_rounds`` iterations.
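For illustration only (this snippet is not part of this diff), a parameter map using these two keys might look like the sketch below; the objective and round counts are placeholder assumptions:

val xgbParam = Map(
  "objective" -> "binary:logistic",        // placeholder objective
  "num_round" -> 100,
  "num_early_stopping_rounds" -> 5,        // stop after 5 rounds without improvement
  "maximize_evaluation_metrics" -> false)  // false for metrics that should decrease, e.g. logloss
val xgbClassifier = new XGBoostClassifier(xgbParam)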
Training with Evaluation Sets
-----------------------------
You can also monitor the performance of the model during training with multiple evaluation datasets. By specifying ``eval_sets`` in the parameter map or by calling ``setEvalSets`` on an XGBoostClassifier or XGBoostRegressor, you can pass in multiple evaluation datasets as a Map from String to DataFrame.
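A minimal sketch of this API, based on the ``setEvalSets`` method and the ``eval_sets`` parameter added in this commit; ``trainDF``, ``eval1DF`` and ``eval2DF`` are assumed to be DataFrames prepared by the caller:

val xgbClassifier = new XGBoostClassifier(xgbParam).
  setFeaturesCol("features").
  setLabelCol("classIndex")
// either register the evaluation sets explicitly ...
xgbClassifier.setEvalSets(Map("eval1" -> eval1DF, "eval2" -> eval2DF))
// ... or pass them in the parameter map as "eval_sets" -> Map("eval1" -> eval1DF, "eval2" -> eval2DF)
val model = xgbClassifier.fit(trainDF)
// per-dataset metric histories are then available via model.summary.validationObjectiveHistory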
Prediction
==========


@ -40,7 +40,7 @@ object SparkTraining {
StructField("petal length", DoubleType, true), StructField("petal length", DoubleType, true),
StructField("petal width", DoubleType, true), StructField("petal width", DoubleType, true),
StructField("class", StringType, true))) StructField("class", StringType, true)))
val rawInput = spark.read.schema(schema).csv(args(0)) val rawInput = spark.read.schema(schema).csv(inputPath)
// transform class to index to make xgboost happy // transform class to index to make xgboost happy
val stringIndexer = new StringIndexer() val stringIndexer = new StringIndexer()
@ -55,6 +55,8 @@ object SparkTraining {
val xgbInput = vectorAssembler.transform(labelTransformed).select("features",
"classIndex")
val Array(train, eval1, eval2, test) = xgbInput.randomSplit(Array(0.6, 0.2, 0.1, 0.1))
/**
* setup "timeout_request_workers" -> 60000L to make this application if it cannot get enough resources
* to get 2 workers within 60000 ms
@ -67,12 +69,13 @@ object SparkTraining {
"objective" -> "multi:softprob", "objective" -> "multi:softprob",
"num_class" -> 3, "num_class" -> 3,
"num_round" -> 100, "num_round" -> 100,
"num_workers" -> 2) "num_workers" -> 2,
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
val xgbClassifier = new XGBoostClassifier(xgbParam). val xgbClassifier = new XGBoostClassifier(xgbParam).
setFeaturesCol("features"). setFeaturesCol("features").
setLabelCol("classIndex") setLabelCol("classIndex")
val xgbClassificationModel = xgbClassifier.fit(xgbInput) val xgbClassificationModel = xgbClassifier.fit(train)
val results = xgbClassificationModel.transform(xgbInput) val results = xgbClassificationModel.transform(test)
results.show() results.show()
} }
} }


@ -20,6 +20,11 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.Param
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{FloatType, IntegerType}
object DataUtils extends Serializable {
private[spark] implicit class XGBLabeledPointFeatures(
@ -67,4 +72,38 @@ object DataUtils extends Serializable {
XGBLabeledPoint(0.0f, v.indices, v.values.map(_.toFloat))
}
}
private[spark] def convertDataFrameToXGBLabeledPointRDDs(
labelCol: Column,
featuresCol: Column,
weight: Column,
baseMargin: Column,
group: Option[Column],
dataFrames: DataFrame*): Array[RDD[XGBLabeledPoint]] = {
val selectedColumns = group.map(groupCol => Seq(labelCol.cast(FloatType),
featuresCol,
weight.cast(FloatType),
groupCol.cast(IntegerType),
baseMargin.cast(FloatType))).getOrElse(Seq(labelCol.cast(FloatType),
featuresCol,
weight.cast(FloatType),
baseMargin.cast(FloatType)))
dataFrames.toArray.map {
df => df.select(selectedColumns: _*).rdd.map {
case Row(label: Float, features: Vector, weight: Float, group: Int, baseMargin: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
}
XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
case Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
}
XGBLabeledPoint(label, indices, values, weight, baseMargin = baseMargin)
}
}
}
}


@ -19,6 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import java.nio.file.Files
import scala.collection.mutable.ListBuffer
import scala.collection.{AbstractIterator, mutable}
import scala.util.Random
@ -31,7 +32,7 @@ import org.apache.commons.logging.LogFactory
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.{DataFrame, SparkSession}
/**
@ -114,13 +115,12 @@ object XGBoost extends Serializable {
round: Int,
obj: ObjectiveTrait,
eval: EvalTrait,
prevBooster: Booster)
: Iterator[(Booster, Map[String, Array[Float]])] = {
prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
// to workaround the empty partitions in training dataset,
// this might not be the best efficient implementation, see
// (https://github.com/dmlc/xgboost/issues/1277)
if (watches.train.rowNum == 0) {
if (watches.toMap("train").rowNum == 0) {
throw new XGBoostError(
s"detected an empty partition in the training data, partition ID:" +
s" ${TaskContext.getPartitionId()}")
@ -138,7 +138,7 @@ object XGBoost extends Serializable {
}
}
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
val booster = SXGBoost.train(watches.train, params, round,
val booster = SXGBoost.train(watches.toMap("train"), params, round,
watches.toMap, metrics, obj, eval,
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
@ -175,6 +175,53 @@ object XGBoost extends Serializable {
tracker
}
class IteratorWrapper[T](arrayOfXGBLabeledPoints: Array[(String, Iterator[T])])
extends Iterator[(String, Iterator[T])] {
private var currentIndex = 0
override def hasNext: Boolean = currentIndex <= arrayOfXGBLabeledPoints.length - 1
override def next(): (String, Iterator[T]) = {
currentIndex += 1
arrayOfXGBLabeledPoints(currentIndex - 1)
}
}
private def coPartitionNoGroupSets(
trainingData: RDD[XGBLabeledPoint],
evalSets: Map[String, RDD[XGBLabeledPoint]],
params: Map[String, Any]) = {
val nWorkers = params("num_workers").asInstanceOf[Int]
// eval_sets is supposed to be set by the caller of [[trainDistributed]]
val allDatasets = Map("train" -> trainingData) ++ evalSets
val repartitionedDatasets = allDatasets.map{case (name, rdd) =>
if (rdd.getNumPartitions != nWorkers) {
(name, rdd.repartition(nWorkers))
} else {
(name, rdd)
}
}
repartitionedDatasets.foldLeft(trainingData.sparkContext.parallelize(
Array.fill[(String, Iterator[XGBLabeledPoint])](nWorkers)(null), nWorkers)){
case (rddOfIterWrapper, (name, rddOfIter)) =>
rddOfIterWrapper.zipPartitions(rddOfIter){
(itrWrapper, itr) =>
if (!itr.hasNext) {
logger.error("when specifying eval sets as dataframes, you have to ensure that " +
"the number of elements in each dataframe is larger than the number of workers")
throw new Exception("too few elements in evaluation sets")
}
val itrArray = itrWrapper.toArray
if (itrArray.head != null) {
new IteratorWrapper(itrArray :+ (name -> itr))
} else {
new IteratorWrapper(Array(name -> itr))
}
}
}
}
/**
* Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
* If so, throw an exception unless this safety measure has been explicitly overridden
@ -215,10 +262,16 @@ object XGBoost extends Serializable {
val eval = params.getOrElse("custom_eval", null).asInstanceOf[EvalTrait] val eval = params.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
val missing = params.getOrElse("missing", Float.NaN).asInstanceOf[Float] val missing = params.getOrElse("missing", Float.NaN).asInstanceOf[Float]
validateSparkSslConf(sparkContext) validateSparkSslConf(sparkContext)
if (params.contains("tree_method")) { if (params.contains("tree_method")) {
require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" + require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
" for now") " for now")
} }
if (params.contains("train_test_ratio")) {
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
"'eval_set_names'")
}
require(nWorkers > 0, "you must specify more than 0 workers")
if (obj != null) {
require(params.get("objective_type").isDefined, "parameter \"objective_type\" is not" +
@ -247,17 +300,30 @@ object XGBoost extends Serializable {
params: Map[String, Any],
rabitEnv: java.util.Map[String, String],
checkpointRound: Int,
prevBooster: Booster) = {
prevBooster: Booster,
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
parameterFetchAndValidation(params, trainingData.sparkContext)
val partitionedData = repartitionForTraining(trainingData, nWorkers)
partitionedData.mapPartitions(labeledPoints => {
val watches = Watches.buildWatches(params,
removeMissingValues(labeledPoints, missing),
getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
obj, eval, prevBooster)
}).cache()
if (evalSetsMap.isEmpty) {
partitionedData.mapPartitions(labeledPoints => {
val watches = Watches.buildWatches(params,
removeMissingValues(labeledPoints, missing),
getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
obj, eval, prevBooster)
}).cache()
} else {
coPartitionNoGroupSets(partitionedData, evalSetsMap, params).mapPartitions {
nameAndLabeledPointSets =>
val watches = Watches.buildWatches(
nameAndLabeledPointSets.map {
case (name, iter) => (name, removeMissingValues(iter, missing))},
getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
obj, eval, prevBooster)
}.cache()
}
}
private def trainForRanking(
@ -265,17 +331,30 @@ object XGBoost extends Serializable {
params: Map[String, Any],
rabitEnv: java.util.Map[String, String],
checkpointRound: Int,
prevBooster: Booster) = {
prevBooster: Booster,
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
parameterFetchAndValidation(params, trainingData.sparkContext)
val partitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
partitionedData.mapPartitions(labeledPointGroups => {
val watches = Watches.buildWatchesWithGroup(params,
removeMissingValuesWithGroup(labeledPointGroups, missing),
getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
obj, eval, prevBooster)
}).cache()
val partitionedTrainingSet = repartitionForTrainingGroup(trainingData, nWorkers)
if (evalSetsMap.isEmpty) {
partitionedTrainingSet.mapPartitions(labeledPointGroups => {
val watches = Watches.buildWatchesWithGroup(params,
removeMissingValuesWithGroup(labeledPointGroups, missing),
getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
}).cache()
} else {
coPartitionGroupSets(partitionedTrainingSet, evalSetsMap, params).mapPartitions(
labeledPointGroupSets => {
val watches = Watches.buildWatchesWithGroup(
labeledPointGroupSets.map {
case (name, iter) => (name, removeMissingValuesWithGroup(iter, missing))
},
getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval,
prevBooster)
}).cache()
}
}
/**
@ -285,7 +364,9 @@ object XGBoost extends Serializable {
private[spark] def trainDistributed(
trainingData: RDD[XGBLabeledPoint],
params: Map[String, Any],
hasGroup: Boolean = false): (Booster, Map[String, Array[Float]]) = {
hasGroup: Boolean = false,
evalSetsMap: Map[String, RDD[XGBLabeledPoint]] = Map()):
(Booster, Map[String, Array[Float]]) = {
val (nWorkers, round, _, _, _, _, trackerConf, timeoutRequestWorkers,
checkpointPath, checkpointInterval) = parameterFetchAndValidation(params,
trainingData.sparkContext)
@ -303,10 +384,10 @@ object XGBoost extends Serializable {
val rabitEnv = tracker.getWorkerEnvs
val boostersAndMetrics = if (hasGroup) {
trainForRanking(trainingData, overriddenParams, rabitEnv, checkpointRound,
prevBooster)
prevBooster, evalSetsMap)
} else {
trainForNonRanking(trainingData, overriddenParams, rabitEnv, checkpointRound,
prevBooster)
prevBooster, evalSetsMap)
}
val sparkJobThread = new Thread() {
override def run() {
@ -340,8 +421,7 @@ object XGBoost extends Serializable {
}
}
private[spark] def repartitionForTrainingGroup(
trainingData: RDD[XGBLabeledPoint], nWorkers: Int): RDD[Array[XGBLabeledPoint]] = {
private def aggByGroupInfo(trainingData: RDD[XGBLabeledPoint]) = {
val normalGroups: RDD[Array[XGBLabeledPoint]] = trainingData.mapPartitions(
// LabeledPointGroupIterator returns (Boolean, Array[XGBLabeledPoint])
new LabeledPointGroupIterator(_)).filter(!_.isEdgeGroup).map(_.points)
@ -349,22 +429,61 @@ object XGBoost extends Serializable {
// edge groups with partition id.
val edgeGroups: RDD[(Int, XGBLabeledPointGroup)] = trainingData.mapPartitions(
new LabeledPointGroupIterator(_)).filter(_.isEdgeGroup).map(
group => (TaskContext.getPartitionId(), group))
// group chunks from different partitions together by group id in XGBLabeledPoint.
// use groupBy instead of aggregateBy since all groups within a partition have unique groud ids.
// use groupBy instead of aggregateBy since all groups within a partition have unique group ids.
val stitchedGroups: RDD[Array[XGBLabeledPoint]] = edgeGroups.groupBy(_._2.groupId).map(
groups => {
val it: Iterable[(Int, XGBLabeledPointGroup)] = groups._2
// sorted by partition id and merge list of Array[XGBLabeledPoint] into one array
it.toArray.sortBy(_._1).map(_._2.points).flatten
it.toArray.sortBy(_._1).flatMap(_._2.points)
})
normalGroups.union(stitchedGroups)
}
var allGroups = normalGroups.union(stitchedGroups)
private[spark] def repartitionForTrainingGroup(
trainingData: RDD[XGBLabeledPoint], nWorkers: Int): RDD[Array[XGBLabeledPoint]] = {
val allGroups = aggByGroupInfo(trainingData)
logger.info(s"repartitioning training group set to $nWorkers partitions")
allGroups.repartition(nWorkers)
}
private def coPartitionGroupSets(
aggedTrainingSet: RDD[Array[XGBLabeledPoint]],
evalSets: Map[String, RDD[XGBLabeledPoint]],
params: Map[String, Any]): RDD[(String, Iterator[Array[XGBLabeledPoint]])] = {
val nWorkers = params("num_workers").asInstanceOf[Int]
val repartitionedDatasets = Map("train" -> aggedTrainingSet) ++ evalSets.map {
case (name, rdd) => {
val aggedRdd = aggByGroupInfo(rdd)
if (aggedRdd.getNumPartitions != nWorkers) {
name -> aggedRdd.repartition(nWorkers)
} else {
name -> aggedRdd
}
}
}
repartitionedDatasets.foldLeft(aggedTrainingSet.sparkContext.parallelize(
Array.fill[(String, Iterator[Array[XGBLabeledPoint]])](nWorkers)(null), nWorkers)){
case (rddOfIterWrapper, (name, rddOfIter)) =>
rddOfIterWrapper.zipPartitions(rddOfIter){
(itrWrapper, itr) =>
if (!itr.hasNext) {
logger.error("when specifying eval sets as dataframes, you have to ensure that " +
"the number of elements in each dataframe is larger than the number of workers")
throw new Exception("too few elements in evaluation sets")
}
val itrArray = itrWrapper.toArray
if (itrArray.head != null) {
new IteratorWrapper(itrArray :+ (name -> itr))
} else {
new IteratorWrapper(Array(name -> itr))
}
}
}
}
private def postTrackerReturnProcessing(
trackerReturnVal: Int,
distributedBoostersAndMetrics: RDD[(Booster, Map[String, Array[Float]])],
@ -395,12 +514,13 @@ object XGBoost extends Serializable {
}
private class Watches private(
val train: DMatrix,
val test: DMatrix,
private val cacheDirName: Option[String]) {
def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
.filter { case (_, matrix) => matrix.rowNum > 0 }
val datasets: Array[DMatrix],
val names: Array[String],
val cacheDirName: Option[String]) {
def toMap: Map[String, DMatrix] = {
names.zip(datasets).toMap.filter { case (_, matrix) => matrix.rowNum > 0 }
}
def size: Int = toMap.size
@ -440,6 +560,26 @@ private object Watches {
}
}
def buildWatches(
nameAndLabeledPointSets: Iterator[(String, Iterator[XGBLabeledPoint])],
cachedDirName: Option[String]): Watches = {
val dms = nameAndLabeledPointSets.map {
case (name, labeledPoints) =>
val baseMargins = new mutable.ArrayBuilder.ofFloat
val duplicatedItr = labeledPoints.map(labeledPoint => {
baseMargins += labeledPoint.baseMargin
labeledPoint
})
val dMatrix = new DMatrix(duplicatedItr, cachedDirName.map(_ + s"/$name").orNull)
val baseMargin = fromBaseMarginsToArray(baseMargins.result().iterator)
if (baseMargin.isDefined) {
dMatrix.setBaseMargin(baseMargin.get)
}
(name, dMatrix)
}.toArray
new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
}
def buildWatches(
params: Map[String, Any],
labeledPoints: Iterator[XGBLabeledPoint],
@ -468,7 +608,30 @@ private object Watches {
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
new Watches(trainMatrix, testMatrix, cacheDirName)
new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
}
def buildWatchesWithGroup(
nameAndlabeledPointGroupSets: Iterator[(String, Iterator[Array[XGBLabeledPoint]])],
cachedDirName: Option[String]): Watches = {
val dms = nameAndlabeledPointGroupSets.map {
case (name, labeledPointsGroups) =>
val baseMargins = new mutable.ArrayBuilder.ofFloat
val duplicatedItr = labeledPointsGroups.map(labeledPoints => {
labeledPoints.map { labeledPoint =>
baseMargins += labeledPoint.baseMargin
labeledPoint
}
})
val dMatrix = new DMatrix(duplicatedItr.flatMap(_.iterator),
cachedDirName.map(_ + s"/$name").orNull)
val baseMargin = fromBaseMarginsToArray(baseMargins.result().iterator)
if (baseMargin.isDefined) {
dMatrix.setBaseMargin(baseMargin.get)
}
(name, dMatrix)
}.toArray
new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
}
def buildWatchesWithGroup(
@ -513,7 +676,7 @@ private object Watches {
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
new Watches(trainMatrix, testMatrix, cacheDirName)
new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
}
}
@ -532,7 +695,7 @@ private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
private var isNewGroup = false
override def hasNext: Boolean = {
return base.hasNext || isNewGroup
base.hasNext || isNewGroup
}
override def next(): XGBLabeledPointGroup = {


@ -43,7 +43,7 @@ import org.apache.spark.broadcast.Broadcast
private[spark] trait XGBoostClassifierParams extends GeneralParams with LearningTaskParams
with BoosterParams with HasWeightCol with HasBaseMarginCol with HasNumClass with ParamMapFuncs
with HasLeafPredictionCol with HasContribPredictionCol
with HasLeafPredictionCol with HasContribPredictionCol with NonParamVariables
class XGBoostClassifier (
override val uid: String,
@ -182,23 +182,19 @@ class XGBoostClassifier (
col($(baseMarginCol))
}
val instances: RDD[XGBLabeledPoint] = dataset.select(
col($(featuresCol)),
col($(labelCol)).cast(FloatType),
baseMargin.cast(FloatType),
weight.cast(FloatType)
).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
}
XGBLabeledPoint(label, indices, values, baseMargin = baseMargin, weight = weight)
}
val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
col($(labelCol)), col($(featuresCol)), weight, baseMargin,
None, dataset.asInstanceOf[DataFrame]).head
val evalRDDMap = getEvalSets(xgboostParams).map {
case (name, dataFrame) => (name,
DataUtils.convertDataFrameToXGBLabeledPointRDDs(col($(labelCol)), col($(featuresCol)),
weight, baseMargin, None, dataFrame).head)
}
transformSchema(dataset.schema, logging = true)
val derivedXGBParamMap = MLlib2XGBoostParams
// All non-null param maps in XGBoostClassifier are in derivedXGBParamMap.
val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
hasGroup = false)
val (_booster, _metrics) = XGBoost.trainDistributed(trainingSet, derivedXGBParamMap,
hasGroup = false, evalRDDMap)
val model = new XGBoostClassificationModel(uid, _numClasses, _booster)
val summary = XGBoostTrainingSummary(_metrics)
model.setSummary(summary)


@ -43,7 +43,7 @@ import org.apache.spark.broadcast.Broadcast
private[spark] trait XGBoostRegressorParams extends GeneralParams with BoosterParams
with LearningTaskParams with HasBaseMarginCol with HasWeightCol with HasGroupCol
with ParamMapFuncs with HasLeafPredictionCol with HasContribPredictionCol
with ParamMapFuncs with HasLeafPredictionCol with HasContribPredictionCol with NonParamVariables
class XGBoostRegressor (
override val uid: String,
@ -174,26 +174,19 @@ class XGBoostRegressor (
col($(baseMarginCol))
}
val group = if (!isDefined(groupCol) || $(groupCol).isEmpty) lit(-1) else col($(groupCol))
val instances: RDD[XGBLabeledPoint] = dataset.select(
col($(labelCol)).cast(FloatType),
col($(featuresCol)),
weight.cast(FloatType),
group.cast(IntegerType),
baseMargin.cast(FloatType)
).rdd.map {
case Row(label: Float, features: Vector, weight: Float, group: Int, baseMargin: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
}
XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
}
val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
col($(labelCol)), col($(featuresCol)), weight, baseMargin, Some(group),
dataset.asInstanceOf[DataFrame]).head
val evalRDDMap = getEvalSets(xgboostParams).map {
case (name, dataFrame) => (name,
DataUtils.convertDataFrameToXGBLabeledPointRDDs(col($(labelCol)), col($(featuresCol)),
weight, baseMargin, Some(group), dataFrame).head)
}
transformSchema(dataset.schema, logging = true)
val derivedXGBParamMap = MLlib2XGBoostParams
// All non-null param maps in XGBoostRegressor are in derivedXGBParamMap.
val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
hasGroup = group != lit(-1))
val (_booster, _metrics) = XGBoost.trainDistributed(trainingSet, derivedXGBParamMap,
hasGroup = group != lit(-1), evalRDDMap)
val model = new XGBoostRegressionModel(uid, _booster)
val summary = XGBoostTrainingSummary(_metrics)
model.setSummary(summary)


@ -18,12 +18,17 @@ package ml.dmlc.xgboost4j.scala.spark
class XGBoostTrainingSummary private(
val trainObjectiveHistory: Array[Float],
val testObjectiveHistory: Option[Array[Float]]
) extends Serializable {
val validationObjectiveHistory: (String, Array[Float])*) extends Serializable {
override def toString: String = {
val train = trainObjectiveHistory.toList
val test = testObjectiveHistory.map(_.toList)
s"XGBoostTrainingSummary(trainObjectiveHistory=$train, testObjectiveHistory=$test)"
val train = trainObjectiveHistory.mkString(",")
val vaidationObjectiveHistoryString = {
validationObjectiveHistory.map {
case (name, metrics) =>
s"${name}ObjectiveHistory=${metrics.mkString(",")}"
}.mkString(";")
}
s"XGBoostTrainingSummary(trainObjectiveHistory=$train; $vaidationObjectiveHistoryString)"
}
}
@ -31,6 +36,6 @@ private[xgboost4j] object XGBoostTrainingSummary {
def apply(metrics: Map[String, Array[Float]]): XGBoostTrainingSummary = {
new XGBoostTrainingSummary(
trainObjectiveHistory = metrics("train"),
testObjectiveHistory = metrics.get("test"))
metrics.filter(_._1 != "train").toSeq: _*)
}
}


@ -22,26 +22,7 @@ import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
import org.json4s.jackson.JsonMethods.{compact, parse, render}
import org.apache.spark.ml.param.{Param, ParamPair, Params}
import org.apache.spark.sql.DataFrame
class GroupDataParam(
parent: Params,
name: String,
doc: String) extends Param[Seq[Seq[Int]]](parent, name, doc) {
/** Creates a param pair with the given value (for Java). */
override def w(value: Seq[Seq[Int]]): ParamPair[Seq[Seq[Int]]] = super.w(value)
override def jsonEncode(value: Seq[Seq[Int]]): String = {
import org.json4s.jackson.Serialization
implicit val formats = Serialization.formats(NoTypeHints)
compact(render(Extraction.decompose(value)))
}
override def jsonDecode(json: String): Seq[Seq[Int]] = {
implicit val formats = DefaultFormats
parse(json).extract[Seq[Seq[Int]]]
}
}
class CustomEvalParam(
parent: Params,


@ -22,6 +22,14 @@ import ml.dmlc.xgboost4j.scala.spark.TrackerConf
import org.apache.spark.ml.param._
import scala.collection.mutable
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{FloatType, IntegerType}
private[spark] trait GeneralParams extends Params {
/**
@ -154,8 +162,7 @@ private[spark] trait GeneralParams extends Params {
useExternalMemory -> false, silent -> 0,
customObj -> null, customEval -> null, missing -> Float.NaN,
trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L,
checkpointPath -> "", checkpointInterval -> -1
)
checkpointPath -> "", checkpointInterval -> -1)
}
trait HasLeafPredictionCol extends Params {


@ -0,0 +1,35 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
import org.apache.spark.sql.DataFrame
trait NonParamVariables {
protected var evalSetsMap: Map[String, DataFrame] = Map.empty
def setEvalSets(evalSets: Map[String, DataFrame]): Unit = {
evalSetsMap = evalSets
}
def getEvalSets(params: Map[String, Any]): Map[String, DataFrame] = {
if (params.contains("eval_sets")) {
params("eval_sets").asInstanceOf[Map[String, DataFrame]]
} else {
evalSetsMap
}
}
}


@ -137,7 +137,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
assert(predictionDF.columns.contains("final_prediction") === false)
assert(model.summary.trainObjectiveHistory.length === 5)
assert(model.summary.testObjectiveHistory.isEmpty)
assert(model.summary.validationObjectiveHistory.isEmpty)
}
test("XGBoost and Spark parameters synchronize correctly") {
@ -191,31 +191,6 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
assert(count != 0)
}
test("training summary") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "nWorkers" -> numWorkers)
val trainingDF = buildDataFrame(Classification.train)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(trainingDF)
assert(model.summary.trainObjectiveHistory.length === 5)
assert(model.summary.testObjectiveHistory.isEmpty)
}
test("train/test split") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
"num_round" -> 5, "num_workers" -> numWorkers)
val training = buildDataFrame(Classification.train)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(training)
val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
assert(testObjectiveHistory.length === 5)
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
}
test("test predictionLeaf") { test("test predictionLeaf") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5", "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",


@ -277,4 +277,93 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
assert(booster != null)
}
test("training summary") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "nWorkers" -> numWorkers)
val trainingDF = buildDataFrame(Classification.train)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(trainingDF)
assert(model.summary.trainObjectiveHistory.length === 5)
assert(model.summary.validationObjectiveHistory.isEmpty)
}
test("train/test split") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
"num_round" -> 5, "num_workers" -> numWorkers)
val training = buildDataFrame(Classification.train)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(training)
assert(model.summary.validationObjectiveHistory.length === 1)
assert(model.summary.validationObjectiveHistory(0)._1 === "test")
assert(model.summary.validationObjectiveHistory(0)._2.length === 5)
assert(model.summary.trainObjectiveHistory !== model.summary.validationObjectiveHistory(0))
}
test("train with multiple validation datasets (non-ranking)") {
val training = buildDataFrame(Classification.train)
val Array(train, eval1, eval2) = training.randomSplit(Array(0.6, 0.2, 0.2))
val paramMap1 = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic",
"num_round" -> 5, "num_workers" -> numWorkers)
val xgb1 = new XGBoostClassifier(paramMap1)
xgb1.setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
val model1 = xgb1.fit(train)
assert(model1.summary.validationObjectiveHistory.length === 2)
assert(model1.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
assert(model1.summary.validationObjectiveHistory(0)._2.length === 5)
assert(model1.summary.validationObjectiveHistory(1)._2.length === 5)
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(0))
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(1))
val paramMap2 = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic",
"num_round" -> 5, "num_workers" -> numWorkers,
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
val xgb2 = new XGBoostClassifier(paramMap2)
val model2 = xgb2.fit(train)
assert(model2.summary.validationObjectiveHistory.length === 2)
assert(model2.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
assert(model2.summary.validationObjectiveHistory(0)._2.length === 5)
assert(model2.summary.validationObjectiveHistory(1)._2.length === 5)
assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(0))
assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(1))
}
test("train with multiple validation datasets (ranking)") {
val training = buildDataFrameWithGroup(Ranking.train, 5)
val Array(train, eval1, eval2) = training.randomSplit(Array(0.6, 0.2, 0.2))
val paramMap1 = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "rank:pairwise",
"num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group")
val xgb1 = new XGBoostRegressor(paramMap1)
xgb1.setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
val model1 = xgb1.fit(train)
assert(model1 != null)
assert(model1.summary.validationObjectiveHistory.length === 2)
assert(model1.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
assert(model1.summary.validationObjectiveHistory(0)._2.length === 5)
assert(model1.summary.validationObjectiveHistory(1)._2.length === 5)
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(0))
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(1))
val paramMap2 = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "rank:pairwise",
"num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group",
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
val xgb2 = new XGBoostRegressor(paramMap2)
val model2 = xgb2.fit(train)
assert(model2 != null)
assert(model2.summary.validationObjectiveHistory.length === 2)
assert(model2.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
assert(model2.summary.validationObjectiveHistory(0)._2.length === 5)
assert(model2.summary.validationObjectiveHistory(1)._2.length === 5)
assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(0))
assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(1))
}
}


@ -222,14 +222,19 @@ public class XGBoost {
if (iter < earlyStoppingRounds - 1) {
return true;
}
float[] criterion = metrics[metrics.length - 1];
for (int shift = 0; shift < earlyStoppingRounds - 1; shift++) {
// the initial value of onTrack is false and if the metrics in any of `earlyStoppingRounds`
// iterations goes to the expected direction, we should consider these `earlyStoppingRounds`
// as `onTrack`
onTrack |= maximizeEvaluationMetrics ?
criterion[iter - shift] >= criterion[iter - shift - 1] :
criterion[iter - shift] <= criterion[iter - shift - 1];
for (int metricsId = metrics.length == 1 ? 0 : 1; metricsId < metrics.length; metricsId++) {
float[] criterion = metrics[metricsId];
for (int shift = 0; shift < earlyStoppingRounds - 1; shift++) {
// the initial value of onTrack is false and if the metrics in any of `earlyStoppingRounds`
// iterations goes to the expected direction, we should consider these `earlyStoppingRounds`
// as `onTrack`
onTrack |= maximizeEvaluationMetrics ?
criterion[iter - shift] >= criterion[iter - shift - 1] :
criterion[iter - shift] <= criterion[iter - shift - 1];
}
if (!onTrack) {
return false;
}
}
return onTrack;
}


@ -185,6 +185,51 @@ public class BoosterImplTest {
}
}
@Test
public void testEarlyStoppingForMultipleMetrics() {
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "true");
}
};
int earlyStoppingRound = 3;
int totalIterations = 5;
int numOfMetrics = 3;
float[][] metrics = new float[numOfMetrics][totalIterations];
for (int i = 0; i < numOfMetrics; i++) {
for (int j = 0; j < totalIterations; j++) {
metrics[0][j] = j;
}
}
for (int i = 0; i < totalIterations; i++) {
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
TestCase.assertTrue(onTrack);
}
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
// when we have multiple datasets, the training metrics is not considered
for (int i = 0; i < totalIterations; i++) {
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
TestCase.assertTrue(onTrack);
}
for (int i = 0; i < totalIterations; i++) {
metrics[1][i] = totalIterations - i;
}
for (int i = 0; i < totalIterations; i++) {
// if any metrics off, we need to stop
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
if (i >= earlyStoppingRound - 1) {
TestCase.assertFalse(onTrack);
} else {
TestCase.assertTrue(onTrack);
}
}
}
@Test
public void testDescendMetrics() {
Map<String, Object> paramMap = new HashMap<String, Object>() {