[jvm-packages] Fix #3489: Spark repartitionForData can potentially shuffle all data and lose ordering required for ranking objectives (#3654)
parent d594b11f35
commit efc4f85505
@@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import java.nio.file.Files

import scala.collection.mutable
import scala.collection.{AbstractIterator, mutable}
import scala.util.Random

import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
@@ -53,6 +53,17 @@ object TrackerConf {
  def apply(): TrackerConf = TrackerConf(0L, "python")
}

/**
 * Training data group in an RDD partition.
 * @param groupId The group id
 * @param points Array of XGBLabeledPoint within the same group.
 * @param isEdgeGroup whether it is the first or last group in an RDD partition.
 */
private[spark] case class XGBLabeledPointGroup(
    groupId: Int,
    points: Array[XGBLabeledPoint],
    isEdgeGroup: Boolean)

object XGBoost extends Serializable {
  private val logger = LogFactory.getLog("XGBoostSpark")

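An illustrative sketch (not part of this diff) of what isEdgeGroup marks: a query group that straddles a partition boundary shows up as two XGBLabeledPointGroup values sharing a groupId, one at the tail of the first partition and one at the head of the next. The points p1..p3 below are null stand-ins for real XGBLabeledPoint instances.

// Hypothetical values for illustration only; nulls stand in for real points.
val p1, p2, p3: XGBLabeledPoint = null
val tailOfPartition0 = XGBLabeledPointGroup(groupId = 7, points = Array(p1), isEdgeGroup = true)
val headOfPartition1 = XGBLabeledPointGroup(groupId = 7, points = Array(p2, p3), isEdgeGroup = true)
// repartitionForTrainingGroup later stitches such fragments back into one complete group:
val restoredGroup: Array[XGBLabeledPoint] = tailOfPartition0.points ++ headOfPartition1.points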
@@ -74,78 +85,62 @@ object XGBoost extends Serializable {
    }
  }

  private def fromBaseMarginsToArray(baseMargins: Iterator[Float]): Option[Array[Float]] = {
    val builder = new mutable.ArrayBuilder.ofFloat()
    var nTotal = 0
    var nUndefined = 0
    while (baseMargins.hasNext) {
      nTotal += 1
      val baseMargin = baseMargins.next()
      if (baseMargin.isNaN) {
        nUndefined += 1 // don't waste space for all-NaNs.
      } else {
        builder += baseMargin
  private def removeMissingValuesWithGroup(
      xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
      missing: Float): Iterator[Array[XGBLabeledPoint]] = {
    if (!missing.isNaN) {
      xgbLabelPointGroups.map {
        labeledPoints => XGBoost.removeMissingValues(labeledPoints.iterator, missing).toArray
      }
    }
    if (nUndefined == nTotal) {
      None
    } else if (nUndefined == 0) {
      Some(builder.result())
    } else {
      throw new IllegalArgumentException(
        s"Encountered a partition with $nUndefined NaN base margin values. " +
          s"If you want to specify base margin, ensure all values are non-NaN.")
      xgbLabelPointGroups
    }
  }

  private[spark] def buildDistributedBoosters(
      data: RDD[XGBLabeledPoint],
  private def getCacheDirName(useExternalMemory: Boolean): Option[String] = {
    val taskId = TaskContext.getPartitionId().toString
    if (useExternalMemory) {
      val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
      Some(dir.toAbsolutePath.toString)
    } else {
      None
    }
  }

  private def buildDistributedBooster(
      watches: Watches,
      params: Map[String, Any],
      rabitEnv: java.util.Map[String, String],
      round: Int,
      obj: ObjectiveTrait,
      eval: EvalTrait,
      useExternalMemory: Boolean,
      missing: Float,
      prevBooster: Booster
    ): RDD[(Booster, Map[String, Array[Float]])] = {
      prevBooster: Booster)
    : Iterator[(Booster, Map[String, Array[Float]])] = {

    val partitionedBaseMargin = data.map(_.baseMargin)
    // to workaround the empty partitions in training dataset,
    // this might not be the best efficient implementation, see
    // (https://github.com/dmlc/xgboost/issues/1277)
    data.zipPartitions(partitionedBaseMargin) { (labeledPoints, baseMargins) =>
      if (labeledPoints.isEmpty) {
        throw new XGBoostError(
          s"detected an empty partition in the training data, partition ID:" +
            s" ${TaskContext.getPartitionId()}")
      }
      val taskId = TaskContext.getPartitionId().toString
      val cacheDirName = if (useExternalMemory) {
        val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
        Some(dir.toAbsolutePath.toString)
      } else {
        None
      }
      rabitEnv.put("DMLC_TASK_ID", taskId)
      Rabit.init(rabitEnv)
      val watches = Watches(params,
        removeMissingValues(labeledPoints, missing),
        fromBaseMarginsToArray(baseMargins), cacheDirName)
    if (watches.train.rowNum == 0) {
      throw new XGBoostError(
        s"detected an empty partition in the training data, partition ID:" +
          s" ${TaskContext.getPartitionId()}")
    }
    val taskId = TaskContext.getPartitionId().toString
    rabitEnv.put("DMLC_TASK_ID", taskId)
    Rabit.init(rabitEnv)

    try {
      val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
        .map(_.toString.toInt).getOrElse(0)
      val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
      val booster = SXGBoost.train(watches.train, params, round,
        watches.toMap, metrics, obj, eval,
        earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
      Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
    } finally {
      Rabit.shutdown()
      watches.delete()
    }
    }.cache()
    try {
      val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
        .map(_.toString.toInt).getOrElse(0)
      val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
      val booster = SXGBoost.train(watches.train, params, round,
        watches.toMap, metrics, obj, eval,
        earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
      Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
    } finally {
      Rabit.shutdown()
      watches.delete()
    }
  }

  private def overrideParamsAccordingToTaskCPUs(
@@ -219,7 +214,8 @@ object XGBoost extends Serializable {
      obj: ObjectiveTrait = null,
      eval: EvalTrait = null,
      useExternalMemory: Boolean = false,
      missing: Float = Float.NaN): (Booster, Map[String, Array[Float]]) = {
      missing: Float = Float.NaN,
      hasGroup: Boolean = false): (Booster, Map[String, Array[Float]]) = {
    validateSparkSslConf(trainingData.context)
    if (params.contains("tree_method")) {
      require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
@@ -244,7 +240,6 @@ object XGBoost extends Serializable {
        " an instance of Long.")
    }
    val (checkpointPath, checkpointInterval) = CheckpointManager.extractParams(params)
    val partitionedData = repartitionForTraining(trainingData, nWorkers)

    val sc = trainingData.sparkContext
    val checkpointManager = new CheckpointManager(sc, checkpointPath)
@@ -258,9 +253,29 @@ object XGBoost extends Serializable {
    try {
      val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
      val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
      val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
        tracker.getWorkerEnvs, checkpointRound, obj, eval, useExternalMemory, missing,
        prevBooster)
      val rabitEnv = tracker.getWorkerEnvs
      val boostersAndMetrics = hasGroup match {
        case true => {
          val partitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
          partitionedData.mapPartitions(labeledPointGroups => {
            val watches = Watches.buildWatchesWithGroup(params,
              removeMissingValuesWithGroup(labeledPointGroups, missing),
              getCacheDirName(useExternalMemory))
            buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
              obj, eval, prevBooster)
          }).cache()
        }
        case false => {
          val partitionedData = repartitionForTraining(trainingData, nWorkers)
          partitionedData.mapPartitions(labeledPoints => {
            val watches = Watches.buildWatches(params,
              removeMissingValues(labeledPoints, missing),
              getCacheDirName(useExternalMemory))
            buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
              obj, eval, prevBooster)
          }).cache()
        }
      }
      val sparkJobThread = new Thread() {
        override def run() {
          // force the job
@@ -278,13 +293,12 @@ object XGBoost extends Serializable {
          checkpointManager.updateCheckpoint(prevBooster)
        }
        (booster, metrics)
      } finally {
        tracker.stop()
      }
      } finally {
        tracker.stop()
      }
    }.last
  }


  private[spark] def repartitionForTraining(trainingData: RDD[XGBLabeledPoint], nWorkers: Int) = {
    if (trainingData.getNumPartitions != nWorkers) {
      logger.info(s"repartitioning training set to $nWorkers partitions")
@@ -294,6 +308,31 @@ object XGBoost extends Serializable {
    }
  }

  private[spark] def repartitionForTrainingGroup(
      trainingData: RDD[XGBLabeledPoint], nWorkers: Int): RDD[Array[XGBLabeledPoint]] = {
    val normalGroups: RDD[Array[XGBLabeledPoint]] = trainingData.mapPartitions(
      // LabeledPointGroupIterator returns (Boolean, Array[XGBLabeledPoint])
      new LabeledPointGroupIterator(_)).filter(!_.isEdgeGroup).map(_.points)

    // edge groups with partition id.
    val edgeGroups: RDD[(Int, XGBLabeledPointGroup)] = trainingData.mapPartitions(
      new LabeledPointGroupIterator(_)).filter(_.isEdgeGroup).map(
      group => (TaskContext.getPartitionId(), group))

    // group chunks from different partitions together by group id in XGBLabeledPoint.
    // use groupBy instead of aggregateBy since all groups within a partition have unique group ids.
    val stitchedGroups: RDD[Array[XGBLabeledPoint]] = edgeGroups.groupBy(_._2.groupId).map(
      groups => {
        val it: Iterable[(Int, XGBLabeledPointGroup)] = groups._2
        // sort by partition id and merge the lists of Array[XGBLabeledPoint] into one array
        it.toArray.sortBy(_._1).map(_._2.points).flatten
      })

    var allGroups = normalGroups.union(stitchedGroups)
    logger.info(s"repartitioning training group set to $nWorkers partitions")
    allGroups.repartition(nWorkers)
  }

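A self-contained miniature (not part of this diff) of the stitching step above, using plain Scala collections and String stand-ins for points so it runs without Spark: group 2 straddles the boundary between partitions 0 and 1 and is reassembled by group id, with its fragments kept in partition order.

// (partitionId, groupId, point) triples; group 2 is split across partitions 0 and 1.
val edgeFragments = Seq(
  (0, 2, "p3"), (0, 2, "p4"), // tail of partition 0
  (1, 2, "p5")                // head of partition 1
)
val stitched: Map[Int, Seq[String]] = edgeFragments
  .groupBy(_._2) // by group id, mirroring edgeGroups.groupBy(_._2.groupId)
  .map { case (gid, frags) => gid -> frags.sortBy(_._1).map(_._3) } // order fragments by partition id
// stitched(2) == Seq("p3", "p4", "p5"): the complete group in its original point order.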
  private def postTrackerReturnProcessing(
      trackerReturnVal: Int,
      distributedBoostersAndMetrics: RDD[(Booster, Map[String, Array[Float]])],
@@ -321,9 +360,9 @@ object XGBoost extends Serializable {
}

private class Watches private(
    val train: DMatrix,
    val test: DMatrix,
    private val cacheDirName: Option[String]) {
    val train: DMatrix,
    val test: DMatrix,
    private val cacheDirName: Option[String]) {

  def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
    .filter { case (_, matrix) => matrix.rowNum > 0 }
@@ -342,59 +381,152 @@ private class Watches private(

private object Watches {

  def buildGroups(groups: Seq[Int]): Seq[Int] = {
    val output = mutable.ArrayBuffer.empty[Int]
    var count = 1
    var lastGroup = groups.head
    for (group <- groups.tail) {
      if (group != lastGroup) {
        lastGroup = group
        output += count
        count = 1
  private def fromBaseMarginsToArray(baseMargins: Iterator[Float]): Option[Array[Float]] = {
    val builder = new mutable.ArrayBuilder.ofFloat()
    var nTotal = 0
    var nUndefined = 0
    while (baseMargins.hasNext) {
      nTotal += 1
      val baseMargin = baseMargins.next()
      if (baseMargin.isNaN) {
        nUndefined += 1 // don't waste space for all-NaNs.
      } else {
        count += 1
        builder += baseMargin
      }
    }
    output += count
    output
    if (nUndefined == nTotal) {
      None
    } else if (nUndefined == 0) {
      Some(builder.result())
    } else {
      throw new IllegalArgumentException(
        s"Encountered a partition with $nUndefined NaN base margin values. " +
          s"If you want to specify base margin, ensure all values are non-NaN.")
    }
  }

  def apply(
  def buildWatches(
      params: Map[String, Any],
      labeledPoints: Iterator[XGBLabeledPoint],
      baseMarginsOpt: Option[Array[Float]],
      cacheDirName: Option[String]): Watches = {
    val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0)
    val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
    val r = new Random(seed)
    val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
    val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
    val testBaseMargins = new mutable.ArrayBuilder.ofFloat
    val trainPoints = labeledPoints.filter { labeledPoint =>
      val accepted = r.nextDouble() <= trainTestRatio
      if (!accepted) {
        testPoints += labeledPoint
        testBaseMargins += labeledPoint.baseMargin
      } else {
        trainBaseMargins += labeledPoint.baseMargin
      }
      accepted
    }
    val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
    val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)

    val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
    val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator)
    if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
    if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)

    new Watches(trainMatrix, testMatrix, cacheDirName)
  }

  def buildWatchesWithGroup(
      params: Map[String, Any],
      labeledPointGroups: Iterator[Array[XGBLabeledPoint]],
      cacheDirName: Option[String]): Watches = {
    val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0)
    val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
    val r = new Random(seed)
    val testPoints = mutable.ArrayBuilder.make[XGBLabeledPoint]
    val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
    val testBaseMargins = new mutable.ArrayBuilder.ofFloat
    val trainGroups = new mutable.ArrayBuilder.ofInt
    val testGroups = new mutable.ArrayBuilder.ofInt

    val trainLabelPointGroups = labeledPointGroups.filter { labeledPointGroup =>
      val accepted = r.nextDouble() <= trainTestRatio
      if (!accepted) {
        labeledPointGroup.foreach(labeledPoint => {
          testPoints += labeledPoint
          testBaseMargins += labeledPoint.baseMargin
        })
        testGroups += labeledPointGroup.length
      } else {
        labeledPointGroup.foreach(trainBaseMargins += _.baseMargin)
        trainGroups += labeledPointGroup.length
      }
      accepted
    }

    val (trainIter1, trainIter2) = trainPoints.duplicate
    val trainMatrix = new DMatrix(trainIter1, cacheDirName.map(_ + "/train").orNull)
    val trainGroups = buildGroups(trainIter2.map(_.group).toSeq).toArray
    trainMatrix.setGroup(trainGroups)
    val trainPoints = trainLabelPointGroups.flatMap(_.iterator)
    val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
    trainMatrix.setGroup(trainGroups.result())

    val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
    val testMatrix = new DMatrix(testPoints.result().iterator, cacheDirName.map(_ + "/test").orNull)
    if (trainTestRatio < 1.0) {
      val testGroups = buildGroups(testPoints.map(_.group)).toArray
      testMatrix.setGroup(testGroups)
      testMatrix.setGroup(testGroups.result())
    }

    r.setSeed(seed)
    for (baseMargins <- baseMarginsOpt) {
      val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)
      trainMatrix.setBaseMargin(trainMargin)
      testMatrix.setBaseMargin(testMargin)
    }
    val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
    val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator)
    if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
    if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)

    new Watches(trainMatrix, testMatrix, cacheDirName)
  }
}

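A standalone sketch (not part of this diff) of the bookkeeping buildWatchesWithGroup performs: whole groups are routed to either train or test, and the per-group sizes collected along the way are what the real code passes to DMatrix.setGroup. Plain Float labels stand in for XGBLabeledPoint here.

import scala.util.Random

// Three query groups, represented only by their labels for this sketch.
val groups: Seq[Array[Float]] = Seq(Array(1f, 0f), Array(0f, 0f, 1f), Array(1f))
val trainTestRatio = 0.8
val r = new Random(42L)

// Accept or reject whole groups, never individual points, so group boundaries stay intact.
val (trainGroups, testGroups) = groups.partition(_ => r.nextDouble() <= trainTestRatio)
// Group boundaries are encoded as an array of group sizes, in order;
// the real code hands these arrays to trainMatrix.setGroup / testMatrix.setGroup.
val trainGroupSizes: Array[Int] = trainGroups.map(_.length).toArray
val testGroupSizes: Array[Int] = testGroups.map(_.length).toArray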
/**
 * Within each RDD partition, group the <code>XGBLabeledPoint</code> by group id.</p>
 * The first and the last groups may not have all the items due to the data partition.
 * <code>LabeledPointGroupIterator</code> organizes data in a tuple format:
 * (isFirstGroup || isLastGroup, Array[XGBLabeledPoint]).</p>
 * The edge groups across partitions can be stitched together later.
 * @param base collection of <code>XGBLabeledPoint</code>
 */
private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
  extends AbstractIterator[XGBLabeledPointGroup] {

  private var firstPointOfNextGroup: XGBLabeledPoint = null
  private var isNewGroup = true

  override def hasNext: Boolean = {
    return base.hasNext || isNewGroup
  }

  override def next(): XGBLabeledPointGroup = {
    val builder = mutable.ArrayBuilder.make[XGBLabeledPoint]
    var isFirstGroup = true
    if (firstPointOfNextGroup != null) {
      builder += firstPointOfNextGroup
      isFirstGroup = false
    }

    isNewGroup = false
    while (!isNewGroup && base.hasNext) {
      val point = base.next()
      val groupId = if (firstPointOfNextGroup != null) firstPointOfNextGroup.group else point.group
      firstPointOfNextGroup = point
      if (point.group == groupId) {
        // add to current group
        builder += point
      } else {
        // start a new group
        isNewGroup = true
      }
    }

    val isLastGroup = !isNewGroup
    val result = builder.result()
    val group = XGBLabeledPointGroup(result(0).group, result, isFirstGroup || isLastGroup)

    group
  }
}

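A self-contained sketch (not part of this diff) of the grouping behavior described above, using (groupId, label) pairs instead of XGBLabeledPoint so it runs on its own: consecutive points with the same group id are collected into one group, and only the first and last groups of a partition are flagged as possible edge groups.

// Splits one partition's ordered stream into consecutive groups and flags the edges.
def toGroups(points: Seq[(Int, Float)]): Seq[(Int, Seq[(Int, Float)], Boolean)] = {
  val runs = points.foldLeft(Vector.empty[Vector[(Int, Float)]]) { (acc, p) =>
    if (acc.nonEmpty && acc.last.head._1 == p._1) acc.init :+ (acc.last :+ p)
    else acc :+ Vector(p)
  }
  runs.zipWithIndex.map { case (run, i) =>
    val isEdge = i == 0 || i == runs.size - 1 // may be incomplete at a partition boundary
    (run.head._1, run, isEdge)
  }
}

// One partition holding the tail of group 1, all of group 2, and the head of group 3.
val partition = Seq((1, 0.5f), (2, 0.1f), (2, 0.9f), (3, 0.3f))
toGroups(partition).foreach { case (gid, pts, edge) => println(s"group $gid, ${pts.size} points, edge=$edge") }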
@@ -196,7 +196,7 @@ class XGBoostClassifier (
    // All non-null param maps in XGBoostClassifier are in derivedXGBParamMap.
    val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
      $(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
      $(missing))
      $(missing), hasGroup = false)
    val model = new XGBoostClassificationModel(uid, _numClasses, _booster)
    val summary = XGBoostTrainingSummary(_metrics)
    model.setSummary(summary)
@@ -517,3 +517,4 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
    }
  }
}

@@ -191,7 +191,7 @@ class XGBoostRegressor (
    // All non-null param maps in XGBoostRegressor are in derivedXGBParamMap.
    val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
      $(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
      $(missing))
      $(missing), hasGroup = group != lit(-1))
    val model = new XGBoostRegressionModel(uid, _booster)
    val summary = XGBoostTrainingSummary(_metrics)
    model.setSummary(summary)

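A hedged usage sketch (not part of this diff) of how a ranking job reaches the hasGroup = true path above. It assumes a groupCol param (setGroupCol) on XGBoostRegressor, which this ranking support relies on but which is not shown in this excerpt, plus a DataFrame with features/label/group columns; verify the exact API against the version in use.

// Assumed API: XGBoostRegressor exposing setGroupCol; confirm against your xgboost4j-spark version.
val rankerParams = Map(
  "eta" -> "1", "max_depth" -> "6", "silent" -> "1",
  "objective" -> "rank:pairwise",
  "num_round" -> 5, "num_workers" -> 2)
val ranker = new XGBoostRegressor(rankerParams)
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setGroupCol("group") // a real group column is what makes hasGroup evaluate to true above
// val model = ranker.fit(rankingDF) // rankingDF: DataFrame with "features", "label" and "group" columns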
@@ -173,7 +173,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
    val training2 = training1.withColumn("margin", functions.rand())
    val test = buildDataFrame(Classification.test)
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "binary:logistic", "test_train_split" -> "0.5",
      "objective" -> "binary:logistic", "train_test_ratio" -> "1.0",
      "num_round" -> 5, "num_workers" -> numWorkers)

    val xgb = new XGBoostClassifier(paramMap)

@@ -19,6 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.Files
import java.util.concurrent.LinkedBlockingDeque
import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
@@ -71,18 +72,16 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
    assert(collectedAllReduceResults.poll().sameElements(maxVec))
  }

  test("build RDD containing boosters with the specified worker number") {
  test("distributed training with the specified worker number") {
    val trainingRDD = sc.parallelize(Classification.train)
    val partitionedRDD = XGBoost.repartitionForTraining(trainingRDD, 2)
    val boosterRDD = XGBoost.buildDistributedBoosters(
      partitionedRDD,
    val (booster, metrics) = XGBoost.trainDistributed(
      trainingRDD,
      List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
        "objective" -> "binary:logistic").toMap,
      new java.util.HashMap[String, String](),
      round = 5, eval = null, obj = null, useExternalMemory = true,
      missing = Float.NaN, prevBooster = null)
    val boosterCount = boosterRDD.count()
    assert(boosterCount === 2)
      round = 5, nWorkers = numWorkers, eval = null, obj = null, useExternalMemory = false,
      hasGroup = false, missing = Float.NaN)

    assert(booster != null)
  }

  test("training with external memory cache") {
@@ -235,4 +234,37 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
    assert(error(prevModel._booster) > error(nextModel._booster))
    assert(error(nextModel._booster) < 0.1)
  }

test("repartitionForTrainingGroup with group data") {
|
||||
// test different splits to cover the corner cases.
|
||||
for (split <- 1 to 20) {
|
||||
val trainingRDD = sc.parallelize(Ranking.train, split)
|
||||
val traingGroupsRDD = XGBoost.repartitionForTrainingGroup(trainingRDD, 4)
|
||||
val trainingGroups: Array[Array[XGBLabeledPoint]] = traingGroupsRDD.collect()
|
||||
// check the the order of the groups with group id.
|
||||
// Ranking.train has 20 groups
|
||||
assert(trainingGroups.length == 20)
|
||||
|
||||
// compare all points
|
||||
val allPoints = trainingGroups.sortBy(_(0).group).flatten
|
||||
assert(allPoints.length == Ranking.train.size)
|
||||
for (i <- 0 to Ranking.train.size - 1) {
|
||||
assert(allPoints(i).group == Ranking.train(i).group)
|
||||
assert(allPoints(i).label == Ranking.train(i).label)
|
||||
assert(allPoints(i).values.sameElements(Ranking.train(i).values))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("distributed training with group data") {
|
||||
val trainingRDD = sc.parallelize(Ranking.train, 2)
|
||||
val (booster, metrics) = XGBoost.trainDistributed(
|
||||
trainingRDD,
|
||||
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic").toMap,
|
||||
round = 5, nWorkers = numWorkers, eval = null, obj = null, useExternalMemory = false,
|
||||
hasGroup = true, missing = Float.NaN)
|
||||
|
||||
assert(booster != null)
|
||||
}
|
||||
}
|
||||
|
||||