[jvm-packages] Fix #3489: Spark repartitionForData can potentially shuffle all data and lose ordering required for ranking objectives (#3654)

weitian 2018-10-03 08:43:55 -07:00 committed by Nan Zhu
parent d594b11f35
commit efc4f85505
5 changed files with 274 additions and 109 deletions
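Background on the bug: ranking objectives require the rows of each query group to stay contiguous within a partition, because group boundaries are later rebuilt by run-length counting over consecutive group ids. A plain row-level repartition shuffles rows independently and interleaves the groups. A minimal plain-Scala sketch of the failure mode (a simplified (groupId, row) pair stands in for the real XGBLabeledPoint):

import scala.util.Random

object ShuffleLosesGroups {
  def main(args: Array[String]): Unit = {
    // (groupId, row): ranking needs rows of the same group to stay contiguous.
    val rows = Seq((1, "a"), (1, "b"), (2, "c"), (2, "d"))
    // A row-level shuffle (what repartition effectively does) interleaves groups:
    val shuffled = Random.shuffle(rows)
    println(shuffled) // e.g. List((2,c), (1,a), (2,d), (1,b))
    // Group boundaries can no longer be recovered by run-length counting,
    // which is why this commit repartitions whole Array[XGBLabeledPoint] groups.
  }
}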


@@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.io.File
 import java.nio.file.Files
-import scala.collection.mutable
+import scala.collection.{AbstractIterator, mutable}
 import scala.util.Random

 import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
@@ -53,6 +53,17 @@ object TrackerConf {
   def apply(): TrackerConf = TrackerConf(0L, "python")
 }

+/**
+ * A training data group within an RDD partition.
+ * @param groupId the group id
+ * @param points the XGBLabeledPoints belonging to the same group
+ * @param isEdgeGroup whether it is the first or the last group in an RDD partition
+ */
+private[spark] case class XGBLabeledPointGroup(
+    groupId: Int,
+    points: Array[XGBLabeledPoint],
+    isEdgeGroup: Boolean)
+
 object XGBoost extends Serializable {
   private val logger = LogFactory.getLog("XGBoostSpark")
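To see why isEdgeGroup exists, consider a group split by a partition boundary. A self-contained sketch (the Point type here is a simplified stand-in, not the real XGBLabeledPoint):

object EdgeGroupSketch {
  case class Point(group: Int)

  def main(args: Array[String]): Unit = {
    // Group 2 straddles the boundary between partition 0 and partition 1,
    // so each partition only sees a chunk of it.
    val partition0 = Seq(Point(1), Point(1), Point(2))
    val partition1 = Seq(Point(2), Point(2), Point(3))
    // Only interior groups are guaranteed complete; the first and last group
    // of each partition may be partial and must be stitched with their
    // counterparts from neighboring partitions.
    def edgeGroupIds(partition: Seq[Point]): Set[Int] =
      Set(partition.head.group, partition.last.group)
    println(edgeGroupIds(partition0)) // Set(1, 2)
    println(edgeGroupIds(partition1)) // Set(2, 3)
  }
}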
@@ -74,78 +85,62 @@ object XGBoost extends Serializable {
     }
   }

-  private def fromBaseMarginsToArray(baseMargins: Iterator[Float]): Option[Array[Float]] = {
-    val builder = new mutable.ArrayBuilder.ofFloat()
-    var nTotal = 0
-    var nUndefined = 0
-    while (baseMargins.hasNext) {
-      nTotal += 1
-      val baseMargin = baseMargins.next()
-      if (baseMargin.isNaN) {
-        nUndefined += 1 // don't waste space for all-NaNs.
-      } else {
-        builder += baseMargin
+  private def removeMissingValuesWithGroup(
+      xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
+      missing: Float): Iterator[Array[XGBLabeledPoint]] = {
+    if (!missing.isNaN) {
+      xgbLabelPointGroups.map {
+        labeledPoints => XGBoost.removeMissingValues(labeledPoints.iterator, missing).toArray
       }
-    }
-    if (nUndefined == nTotal) {
-      None
-    } else if (nUndefined == 0) {
-      Some(builder.result())
     } else {
-      throw new IllegalArgumentException(
-        s"Encountered a partition with $nUndefined NaN base margin values. " +
-          s"If you want to specify base margin, ensure all values are non-NaN.")
+      xgbLabelPointGroups
     }
   }

-  private[spark] def buildDistributedBoosters(
-      data: RDD[XGBLabeledPoint],
+  private def getCacheDirName(useExternalMemory: Boolean): Option[String] = {
+    val taskId = TaskContext.getPartitionId().toString
+    if (useExternalMemory) {
+      val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
+      Some(dir.toAbsolutePath.toString)
+    } else {
+      None
+    }
+  }
+
+  private def buildDistributedBooster(
+      watches: Watches,
       params: Map[String, Any],
       rabitEnv: java.util.Map[String, String],
       round: Int,
       obj: ObjectiveTrait,
       eval: EvalTrait,
-      useExternalMemory: Boolean,
-      missing: Float,
-      prevBooster: Booster
-    ): RDD[(Booster, Map[String, Array[Float]])] = {
-    val partitionedBaseMargin = data.map(_.baseMargin)
+      prevBooster: Booster)
+    : Iterator[(Booster, Map[String, Array[Float]])] = {
     // to work around the empty partitions in the training dataset;
     // this might not be the most efficient implementation, see
     // (https://github.com/dmlc/xgboost/issues/1277)
-    data.zipPartitions(partitionedBaseMargin) { (labeledPoints, baseMargins) =>
-      if (labeledPoints.isEmpty) {
-        throw new XGBoostError(
-          s"detected an empty partition in the training data, partition ID:" +
-            s" ${TaskContext.getPartitionId()}")
-      }
-      val taskId = TaskContext.getPartitionId().toString
-      val cacheDirName = if (useExternalMemory) {
-        val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
-        Some(dir.toAbsolutePath.toString)
-      } else {
-        None
-      }
-      rabitEnv.put("DMLC_TASK_ID", taskId)
-      Rabit.init(rabitEnv)
-      val watches = Watches(params,
-        removeMissingValues(labeledPoints, missing),
-        fromBaseMarginsToArray(baseMargins), cacheDirName)
+    if (watches.train.rowNum == 0) {
+      throw new XGBoostError(
+        s"detected an empty partition in the training data, partition ID:" +
+          s" ${TaskContext.getPartitionId()}")
+    }
+    val taskId = TaskContext.getPartitionId().toString
+    rabitEnv.put("DMLC_TASK_ID", taskId)
+    Rabit.init(rabitEnv)
     try {
       val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
         .map(_.toString.toInt).getOrElse(0)
       val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
       val booster = SXGBoost.train(watches.train, params, round,
         watches.toMap, metrics, obj, eval,
         earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
       Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
     } finally {
       Rabit.shutdown()
       watches.delete()
     }
-    }.cache()
   }

   private def overrideParamsAccordingToTaskCPUs(
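The group-aware variant of the missing-value pass only has to preserve the Iterator[Array[_]] shape so the group sizes stay valid. A minimal generic sketch of that wrapper, where clean stands in for XGBoost.removeMissingValues (which strips feature entries equal to missing from each sparse point):

object CleanByGroupSketch {
  // Apply a per-point cleaning step inside each group; because the cleaning
  // never crosses group boundaries, the grouping structure (and therefore the
  // sizes later handed to DMatrix.setGroup) survives intact.
  def cleanByGroup[A](
      groups: Iterator[Array[A]],
      clean: Iterator[A] => Iterator[A]): Iterator[Array[A]] =
    groups.map(g => clean(g.iterator).toArray)

  def main(args: Array[String]): Unit = {
    val groups = Iterator(Array(1, -1, 2), Array(3, -1))
    // Drop the sentinel -1 inside each group; the two-group shape is preserved.
    val cleaned = cleanByGroup[Int](groups, _.filter(_ != -1))
    println(cleaned.map(_.mkString(",")).toList) // List(1,2, 3)
  }
}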
@@ -219,7 +214,8 @@ object XGBoost extends Serializable {
       obj: ObjectiveTrait = null,
       eval: EvalTrait = null,
       useExternalMemory: Boolean = false,
-      missing: Float = Float.NaN): (Booster, Map[String, Array[Float]]) = {
+      missing: Float = Float.NaN,
+      hasGroup: Boolean = false): (Booster, Map[String, Array[Float]]) = {
     validateSparkSslConf(trainingData.context)
     if (params.contains("tree_method")) {
       require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
@@ -244,7 +240,6 @@ object XGBoost extends Serializable {
         " an instance of Long.")
     }
     val (checkpointPath, checkpointInterval) = CheckpointManager.extractParams(params)
-    val partitionedData = repartitionForTraining(trainingData, nWorkers)
     val sc = trainingData.sparkContext
     val checkpointManager = new CheckpointManager(sc, checkpointPath)
@@ -258,9 +253,29 @@ object XGBoost extends Serializable {
     try {
       val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
       val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
-      val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
-        tracker.getWorkerEnvs, checkpointRound, obj, eval, useExternalMemory, missing,
-        prevBooster)
+      val rabitEnv = tracker.getWorkerEnvs
+      val boostersAndMetrics = hasGroup match {
+        case true => {
+          val partitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
+          partitionedData.mapPartitions(labeledPointGroups => {
+            val watches = Watches.buildWatchesWithGroup(params,
+              removeMissingValuesWithGroup(labeledPointGroups, missing),
+              getCacheDirName(useExternalMemory))
+            buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
+              obj, eval, prevBooster)
+          }).cache()
+        }
+        case false => {
+          val partitionedData = repartitionForTraining(trainingData, nWorkers)
+          partitionedData.mapPartitions(labeledPoints => {
+            val watches = Watches.buildWatches(params,
+              removeMissingValues(labeledPoints, missing),
+              getCacheDirName(useExternalMemory))
+            buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
+              obj, eval, prevBooster)
+          }).cache()
+        }
+      }
       val sparkJobThread = new Thread() {
         override def run() {
           // force the job
@@ -278,13 +293,12 @@ object XGBoost extends Serializable {
           checkpointManager.updateCheckpoint(prevBooster)
         }
         (booster, metrics)
       } finally {
         tracker.stop()
       }
     }.last
   }

   private[spark] def repartitionForTraining(trainingData: RDD[XGBLabeledPoint], nWorkers: Int) = {
     if (trainingData.getNumPartitions != nWorkers) {
       logger.info(s"repartitioning training set to $nWorkers partitions")
@@ -294,6 +308,31 @@ object XGBoost extends Serializable {
     }
   }

+  private[spark] def repartitionForTrainingGroup(
+      trainingData: RDD[XGBLabeledPoint], nWorkers: Int): RDD[Array[XGBLabeledPoint]] = {
+    // complete groups within a partition (LabeledPointGroupIterator returns
+    // XGBLabeledPointGroup, flagging the first and last group of each partition)
+    val normalGroups: RDD[Array[XGBLabeledPoint]] = trainingData.mapPartitions(
+      new LabeledPointGroupIterator(_)).filter(!_.isEdgeGroup).map(_.points)
+
+    // edge groups, keyed by partition id.
+    val edgeGroups: RDD[(Int, XGBLabeledPointGroup)] = trainingData.mapPartitions(
+      new LabeledPointGroupIterator(_)).filter(_.isEdgeGroup).map(
+      group => (TaskContext.getPartitionId(), group))
+
+    // stitch chunks from different partitions together by the group id in XGBLabeledPoint.
+    // use groupBy instead of aggregateBy since all groups within a partition have unique group ids.
+    val stitchedGroups: RDD[Array[XGBLabeledPoint]] = edgeGroups.groupBy(_._2.groupId).map(
+      groups => {
+        val it: Iterable[(Int, XGBLabeledPointGroup)] = groups._2
+        // sort by partition id and merge the chunks of Array[XGBLabeledPoint] into one array
+        it.toArray.sortBy(_._1).map(_._2.points).flatten
+      })
+
+    var allGroups = normalGroups.union(stitchedGroups)
+    logger.info(s"repartitioning training group set to $nWorkers partitions")
+    allGroups.repartition(nWorkers)
+  }
+
   private def postTrackerReturnProcessing(
       trackerReturnVal: Int,
       distributedBoostersAndMetrics: RDD[(Booster, Map[String, Array[Float]])],
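The stitching step above can be replayed with plain collections. A sketch with a hypothetical Chunk type (one partial group emitted at a partition edge), mirroring the groupBy / sortBy / flatten chain:

object StitchSketch {
  // A partial group emitted at a partition edge.
  case class Chunk(partitionId: Int, groupId: Int, rows: Array[String])

  def main(args: Array[String]): Unit = {
    val edgeChunks = Seq(
      Chunk(partitionId = 0, groupId = 2, rows = Array("c")),
      Chunk(partitionId = 1, groupId = 2, rows = Array("d", "e")))
    // groupBy is safe here because group ids are unique within a partition,
    // so at most one chunk per (partition, group) pair exists.
    val stitched: Map[Int, Array[String]] =
      edgeChunks.groupBy(_.groupId).map { case (gid, chunks) =>
        gid -> chunks.sortBy(_.partitionId).flatMap(_.rows).toArray
      }
    println(stitched(2).mkString) // "cde": group 2 is whole again
  }
}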
@@ -321,9 +360,9 @@
 }

 private class Watches private(
     val train: DMatrix,
     val test: DMatrix,
     private val cacheDirName: Option[String]) {

   def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
     .filter { case (_, matrix) => matrix.rowNum > 0 }
@@ -342,59 +381,152 @@ private class Watches private(
 private object Watches {

-  def buildGroups(groups: Seq[Int]): Seq[Int] = {
-    val output = mutable.ArrayBuffer.empty[Int]
-    var count = 1
-    var lastGroup = groups.head
-    for (group <- groups.tail) {
-      if (group != lastGroup) {
-        lastGroup = group
-        output += count
-        count = 1
-      } else {
-        count += 1
-      }
-    }
-    output += count
-    output
+  private def fromBaseMarginsToArray(baseMargins: Iterator[Float]): Option[Array[Float]] = {
+    val builder = new mutable.ArrayBuilder.ofFloat()
+    var nTotal = 0
+    var nUndefined = 0
+    while (baseMargins.hasNext) {
+      nTotal += 1
+      val baseMargin = baseMargins.next()
+      if (baseMargin.isNaN) {
+        nUndefined += 1 // don't waste space for all-NaNs.
+      } else {
+        builder += baseMargin
+      }
+    }
+    if (nUndefined == nTotal) {
+      None
+    } else if (nUndefined == 0) {
+      Some(builder.result())
+    } else {
+      throw new IllegalArgumentException(
+        s"Encountered a partition with $nUndefined NaN base margin values. " +
+          s"If you want to specify base margin, ensure all values are non-NaN.")
+    }
   }

-  def apply(
+  def buildWatches(
       params: Map[String, Any],
       labeledPoints: Iterator[XGBLabeledPoint],
-      baseMarginsOpt: Option[Array[Float]],
       cacheDirName: Option[String]): Watches = {
     val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0)
     val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
     val r = new Random(seed)
     val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
+    val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
+    val testBaseMargins = new mutable.ArrayBuilder.ofFloat
     val trainPoints = labeledPoints.filter { labeledPoint =>
       val accepted = r.nextDouble() <= trainTestRatio
       if (!accepted) {
         testPoints += labeledPoint
+        testBaseMargins += labeledPoint.baseMargin
+      } else {
+        trainBaseMargins += labeledPoint.baseMargin
       }
       accepted
     }
-    val (trainIter1, trainIter2) = trainPoints.duplicate
-    val trainMatrix = new DMatrix(trainIter1, cacheDirName.map(_ + "/train").orNull)
-    val trainGroups = buildGroups(trainIter2.map(_.group).toSeq).toArray
-    trainMatrix.setGroup(trainGroups)
+    val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
     val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
-    if (trainTestRatio < 1.0) {
-      val testGroups = buildGroups(testPoints.map(_.group)).toArray
-      testMatrix.setGroup(testGroups)
-    }
-    r.setSeed(seed)
-    for (baseMargins <- baseMarginsOpt) {
-      val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)
-      trainMatrix.setBaseMargin(trainMargin)
-      testMatrix.setBaseMargin(testMargin)
-    }
+
+    val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
+    val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator)
+    if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
+    if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
+
+    new Watches(trainMatrix, testMatrix, cacheDirName)
+  }
+
+  def buildWatchesWithGroup(
+      params: Map[String, Any],
+      labeledPointGroups: Iterator[Array[XGBLabeledPoint]],
+      cacheDirName: Option[String]): Watches = {
+    val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0)
+    val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
+    val r = new Random(seed)
+    val testPoints = mutable.ArrayBuilder.make[XGBLabeledPoint]
+    val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
+    val testBaseMargins = new mutable.ArrayBuilder.ofFloat
+    val trainGroups = new mutable.ArrayBuilder.ofInt
+    val testGroups = new mutable.ArrayBuilder.ofInt
+    val trainLabelPointGroups = labeledPointGroups.filter { labeledPointGroup =>
+      val accepted = r.nextDouble() <= trainTestRatio
+      if (!accepted) {
+        labeledPointGroup.foreach(labeledPoint => {
+          testPoints += labeledPoint
+          testBaseMargins += labeledPoint.baseMargin
+        })
+        testGroups += labeledPointGroup.length
+      } else {
+        labeledPointGroup.foreach(trainBaseMargins += _.baseMargin)
+        trainGroups += labeledPointGroup.length
+      }
+      accepted
+    }
+
+    val trainPoints = trainLabelPointGroups.flatMap(_.iterator)
+    val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
+    trainMatrix.setGroup(trainGroups.result())
+
+    val testMatrix = new DMatrix(testPoints.result().iterator, cacheDirName.map(_ + "/test").orNull)
+    if (trainTestRatio < 1.0) {
+      testMatrix.setGroup(testGroups.result())
+    }
+
+    val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
+    val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator)
+    if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
+    if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
+
     new Watches(trainMatrix, testMatrix, cacheDirName)
   }
 }
+
+/**
+ * Within each RDD partition, group the <code>XGBLabeledPoint</code>s by group id.
+ * The first and the last group in a partition may be incomplete, because the
+ * data partitioning can split a group across partitions.
+ * <code>LabeledPointGroupIterator</code> organizes the data as
+ * <code>XGBLabeledPointGroup</code>s and flags such edge groups, so that they
+ * can be stitched together across partitions later.
+ * @param base collection of <code>XGBLabeledPoint</code>
+ */
+private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
+  extends AbstractIterator[XGBLabeledPointGroup] {
+
+  private var firstPointOfNextGroup: XGBLabeledPoint = null
+  private var isNewGroup = true
+
+  override def hasNext: Boolean = {
+    base.hasNext || isNewGroup
+  }
+
+  override def next(): XGBLabeledPointGroup = {
+    val builder = mutable.ArrayBuilder.make[XGBLabeledPoint]
+    var isFirstGroup = true
+    if (firstPointOfNextGroup != null) {
+      builder += firstPointOfNextGroup
+      isFirstGroup = false
+    }
+
+    isNewGroup = false
+    while (!isNewGroup && base.hasNext) {
+      val point = base.next()
+      val groupId = if (firstPointOfNextGroup != null) firstPointOfNextGroup.group else point.group
+      firstPointOfNextGroup = point
+      if (point.group == groupId) {
+        // add to the current group
+        builder += point
+      } else {
+        // start a new group
+        isNewGroup = true
+      }
+    }
+
+    val isLastGroup = !isNewGroup
+    val result = builder.result()
+    val group = XGBLabeledPointGroup(result(0).group, result, isFirstGroup || isLastGroup)
+    group
+  }
+}
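A compact way to check the iterator's contract without Spark: split a partition's points into runs of equal consecutive group ids and flag the outermost runs. A self-contained sketch with a simplified point type (not the real XGBLabeledPoint); note that the run sizes are exactly what DMatrix.setGroup consumes:

object GroupRunsSketch {
  case class P(group: Int)

  // Split into runs of equal consecutive group ids; the first and the last run
  // are edge groups because they touch the partition boundaries.
  def groupRuns(points: Seq[P]): Seq[(Int, Seq[P], Boolean)] = {
    val runs = points.foldLeft(Vector.empty[Vector[P]]) { (acc, p) =>
      if (acc.nonEmpty && acc.last.head.group == p.group) acc.init :+ (acc.last :+ p)
      else acc :+ Vector(p)
    }
    runs.zipWithIndex.map { case (run, i) =>
      (run.head.group, run, i == 0 || i == runs.size - 1)
    }
  }

  def main(args: Array[String]): Unit = {
    val runs = groupRuns(Seq(P(5), P(5), P(6), P(7), P(7)))
    println(runs.map { case (g, r, edge) => (g, r.size, edge) })
    // Vector((5,2,true), (6,1,false), (7,2,true))
    println(runs.map(_._2.size)) // Vector(2, 1, 2): the setGroup sizes
  }
}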


@@ -196,7 +196,7 @@ class XGBoostClassifier (
     // All non-null param maps in XGBoostClassifier are in derivedXGBParamMap.
     val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
       $(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
-      $(missing))
+      $(missing), hasGroup = false)
     val model = new XGBoostClassificationModel(uid, _numClasses, _booster)
     val summary = XGBoostTrainingSummary(_metrics)
     model.setSummary(summary)
@@ -517,3 +517,4 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] {
     }
   }
 }


@@ -191,7 +191,7 @@ class XGBoostRegressor (
     // All non-null param maps in XGBoostRegressor are in derivedXGBParamMap.
     val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
       $(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
-      $(missing))
+      $(missing), hasGroup = group != lit(-1))
     val model = new XGBoostRegressionModel(uid, _booster)
     val summary = XGBoostTrainingSummary(_metrics)
     model.setSummary(summary)
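For context, a hypothetical end-user call that would take the hasGroup path above. The group-column setter name and the input DataFrame are assumptions (setGroupCol and a trainingDF with "features", "label" and a query-id column "group"), not part of this diff:

// Hypothetical usage sketch, under the assumptions stated above.
val ranker = new XGBoostRegressor(Map(
  "objective" -> "rank:pairwise",
  "num_round" -> 10,
  "num_workers" -> 2))
ranker.setGroupCol("group") // marks the query-group column, enabling hasGroup
val model = ranker.fit(trainingDF) // trainingDF is an assumed DataFrame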


@@ -173,7 +173,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
     val training2 = training1.withColumn("margin", functions.rand())
     val test = buildDataFrame(Classification.test)
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "binary:logistic", "test_train_split" -> "0.5",
+      "objective" -> "binary:logistic", "train_test_ratio" -> "1.0",
       "num_round" -> 5, "num_workers" -> numWorkers)
     val xgb = new XGBoostClassifier(paramMap)


@@ -19,6 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.nio.file.Files
 import java.util.concurrent.LinkedBlockingDeque

 import ml.dmlc.xgboost4j.java.Rabit
+import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import ml.dmlc.xgboost4j.scala.DMatrix
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
@@ -71,18 +72,16 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     assert(collectedAllReduceResults.poll().sameElements(maxVec))
   }

-  test("build RDD containing boosters with the specified worker number") {
+  test("distributed training with the specified worker number") {
     val trainingRDD = sc.parallelize(Classification.train)
-    val partitionedRDD = XGBoost.repartitionForTraining(trainingRDD, 2)
-    val boosterRDD = XGBoost.buildDistributedBoosters(
-      partitionedRDD,
+    val (booster, metrics) = XGBoost.trainDistributed(
+      trainingRDD,
       List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
         "objective" -> "binary:logistic").toMap,
-      new java.util.HashMap[String, String](),
-      round = 5, eval = null, obj = null, useExternalMemory = true,
-      missing = Float.NaN, prevBooster = null)
-    val boosterCount = boosterRDD.count()
-    assert(boosterCount === 2)
+      round = 5, nWorkers = numWorkers, eval = null, obj = null, useExternalMemory = false,
+      hasGroup = false, missing = Float.NaN)
+    assert(booster != null)
   }

   test("training with external memory cache") {
@@ -235,4 +234,37 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     assert(error(prevModel._booster) > error(nextModel._booster))
     assert(error(nextModel._booster) < 0.1)
   }
+
+  test("repartitionForTrainingGroup with group data") {
+    // test different splits to cover the corner cases.
+    for (split <- 1 to 20) {
+      val trainingRDD = sc.parallelize(Ranking.train, split)
+      val trainingGroupsRDD = XGBoost.repartitionForTrainingGroup(trainingRDD, 4)
+      val trainingGroups: Array[Array[XGBLabeledPoint]] = trainingGroupsRDD.collect()
+      // Ranking.train has 20 groups
+      assert(trainingGroups.length == 20)
+
+      // compare all points, ordered by group id, against the original input
+      val allPoints = trainingGroups.sortBy(_(0).group).flatten
+      assert(allPoints.length == Ranking.train.size)
+      for (i <- 0 until Ranking.train.size) {
+        assert(allPoints(i).group == Ranking.train(i).group)
+        assert(allPoints(i).label == Ranking.train(i).label)
+        assert(allPoints(i).values.sameElements(Ranking.train(i).values))
+      }
+    }
+  }
+
+  test("distributed training with group data") {
+    val trainingRDD = sc.parallelize(Ranking.train, 2)
+    val (booster, metrics) = XGBoost.trainDistributed(
+      trainingRDD,
+      List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+        "objective" -> "binary:logistic").toMap,
+      round = 5, nWorkers = numWorkers, eval = null, obj = null, useExternalMemory = false,
+      hasGroup = true, missing = Float.NaN)
+    assert(booster != null)
+  }
 }