[jvm-packages] Fix #3489: Spark repartitionForData can potentially shuffle all data and lose ordering required for ranking objectives (#3654)
commit efc4f85505 (parent d594b11f35)
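The problem being fixed: Spark's row-level `repartition` shuffles individual `XGBLabeledPoint`s, so a ranking query group can be split across partitions and its rows reordered, which invalidates ranking objectives that require each group to be contiguous. This commit makes whole groups the unit of shuffling: points are grouped within each partition (`LabeledPointGroupIterator`), groups cut by partition boundaries are stitched back together, and only then is the data repartitioned (`repartitionForTrainingGroup`). The sketch below is not part of the commit; it is a minimal standalone illustration of the difference, with a hypothetical `Point` type standing in for `XGBLabeledPoint`.

```scala
import org.apache.spark.sql.SparkSession

object GroupRepartitionSketch {
  // Stand-in for XGBLabeledPoint; only the group id matters here.
  case class Point(group: Int, value: Double)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    val sc = spark.sparkContext
    // 4 groups of 3 points each, spread over 2 partitions.
    val points = sc.parallelize(for (g <- 0 until 4; i <- 0 until 3) yield Point(g, i), 2)

    // Row-level repartition: rows of one query group may land on different
    // partitions, in arbitrary order -- the behavior reported in #3489.
    val byRow = points.repartition(2)

    // Group-level repartition (the approach of this commit): shuffle arrays of
    // points, so every group stays contiguous wherever it lands. groupBy here is
    // a simplification of repartitionForTrainingGroup, which avoids a full
    // row-level groupBy by only stitching the groups that straddle partition edges.
    val byGroup = points.groupBy(_.group).map(_._2.toArray).repartition(2)

    byGroup.collect().foreach(g => println(g.map(_.group).mkString("[", ",", "]")))
    spark.stop()
  }
}
```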
@@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.io.File
 import java.nio.file.Files

-import scala.collection.mutable
+import scala.collection.{AbstractIterator, mutable}
 import scala.util.Random

 import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
@@ -53,6 +53,17 @@ object TrackerConf {
   def apply(): TrackerConf = TrackerConf(0L, "python")
 }

+/**
+ * A training data group inside an RDD partition.
+ * @param groupId The group id
+ * @param points Array of XGBLabeledPoint within the same group.
+ * @param isEdgeGroup whether it is a first or last group in an RDD partition.
+ */
+private[spark] case class XGBLabeledPointGroup(
+    groupId: Int,
+    points: Array[XGBLabeledPoint],
+    isEdgeGroup: Boolean)
+
 object XGBoost extends Serializable {
   private val logger = LogFactory.getLog("XGBoostSpark")

@@ -74,78 +85,62 @@ object XGBoost extends Serializable {
     }
   }

-  private def fromBaseMarginsToArray(baseMargins: Iterator[Float]): Option[Array[Float]] = {
-    val builder = new mutable.ArrayBuilder.ofFloat()
-    var nTotal = 0
-    var nUndefined = 0
-    while (baseMargins.hasNext) {
-      nTotal += 1
-      val baseMargin = baseMargins.next()
-      if (baseMargin.isNaN) {
-        nUndefined += 1 // don't waste space for all-NaNs.
-      } else {
-        builder += baseMargin
-      }
-    }
-    if (nUndefined == nTotal) {
-      None
-    } else if (nUndefined == 0) {
-      Some(builder.result())
-    } else {
-      throw new IllegalArgumentException(
-        s"Encountered a partition with $nUndefined NaN base margin values. " +
-        s"If you want to specify base margin, ensure all values are non-NaN.")
-    }
-  }
-
-  private[spark] def buildDistributedBoosters(
-      data: RDD[XGBLabeledPoint],
+  private def removeMissingValuesWithGroup(
+      xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
+      missing: Float): Iterator[Array[XGBLabeledPoint]] = {
+    if (!missing.isNaN) {
+      xgbLabelPointGroups.map {
+        labeledPoints => XGBoost.removeMissingValues(labeledPoints.iterator, missing).toArray
+      }
+    } else {
+      xgbLabelPointGroups
+    }
+  }
+
+  private def getCacheDirName(useExternalMemory: Boolean): Option[String] = {
+    val taskId = TaskContext.getPartitionId().toString
+    if (useExternalMemory) {
+      val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
+      Some(dir.toAbsolutePath.toString)
+    } else {
+      None
+    }
+  }
+
+  private def buildDistributedBooster(
+      watches: Watches,
       params: Map[String, Any],
       rabitEnv: java.util.Map[String, String],
       round: Int,
       obj: ObjectiveTrait,
       eval: EvalTrait,
-      useExternalMemory: Boolean,
-      missing: Float,
-      prevBooster: Booster
-    ): RDD[(Booster, Map[String, Array[Float]])] = {
-
-    val partitionedBaseMargin = data.map(_.baseMargin)
+      prevBooster: Booster)
+    : Iterator[(Booster, Map[String, Array[Float]])] = {
     // to workaround the empty partitions in training dataset,
     // this might not be the most efficient implementation, see
     // (https://github.com/dmlc/xgboost/issues/1277)
-    data.zipPartitions(partitionedBaseMargin) { (labeledPoints, baseMargins) =>
-      if (labeledPoints.isEmpty) {
-        throw new XGBoostError(
-          s"detected an empty partition in the training data, partition ID:" +
-            s" ${TaskContext.getPartitionId()}")
-      }
-      val taskId = TaskContext.getPartitionId().toString
-      val cacheDirName = if (useExternalMemory) {
-        val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
-        Some(dir.toAbsolutePath.toString)
-      } else {
-        None
-      }
-      rabitEnv.put("DMLC_TASK_ID", taskId)
-      Rabit.init(rabitEnv)
-      val watches = Watches(params,
-        removeMissingValues(labeledPoints, missing),
-        fromBaseMarginsToArray(baseMargins), cacheDirName)
+    if (watches.train.rowNum == 0) {
+      throw new XGBoostError(
+        s"detected an empty partition in the training data, partition ID:" +
+          s" ${TaskContext.getPartitionId()}")
+    }
+    val taskId = TaskContext.getPartitionId().toString
+    rabitEnv.put("DMLC_TASK_ID", taskId)
+    Rabit.init(rabitEnv)

     try {
       val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
         .map(_.toString.toInt).getOrElse(0)
       val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
       val booster = SXGBoost.train(watches.train, params, round,
         watches.toMap, metrics, obj, eval,
         earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
       Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
     } finally {
       Rabit.shutdown()
       watches.delete()
     }
-    }.cache()
   }

   private def overrideParamsAccordingToTaskCPUs(
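A note on the refactoring above: `buildDistributedBooster` now receives ready-made `Watches`, and missing-value handling for grouped data moved into `removeMissingValuesWithGroup`, which cleans each group through the existing `removeMissingValues`, so no point ever migrates between groups. Below is a simplified, self-contained model of that contract; reducing a point to a plain `Array[Float]` of feature values is an assumption for illustration, the real code operates on sparse `XGBLabeledPoint`s.

```scala
// A point is modeled as an Array[Float] of feature values; a group is an
// Array of points. Cleaning maps over groups, so group boundaries survive.
def removeMissingValuesSketch(
    groups: Iterator[Array[Array[Float]]],
    missing: Float): Iterator[Array[Array[Float]]] = {
  if (missing.isNaN) groups // NaN entries are already treated as missing
  else groups.map(group => group.map(point => point.filter(_ != missing)))
}
```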
@@ -219,7 +214,8 @@ object XGBoost extends Serializable {
       obj: ObjectiveTrait = null,
       eval: EvalTrait = null,
       useExternalMemory: Boolean = false,
-      missing: Float = Float.NaN): (Booster, Map[String, Array[Float]]) = {
+      missing: Float = Float.NaN,
+      hasGroup: Boolean = false): (Booster, Map[String, Array[Float]]) = {
     validateSparkSslConf(trainingData.context)
     if (params.contains("tree_method")) {
       require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
@@ -244,7 +240,6 @@ object XGBoost extends Serializable {
         " an instance of Long.")
     }
     val (checkpointPath, checkpointInterval) = CheckpointManager.extractParams(params)
-    val partitionedData = repartitionForTraining(trainingData, nWorkers)

     val sc = trainingData.sparkContext
     val checkpointManager = new CheckpointManager(sc, checkpointPath)
@@ -258,9 +253,29 @@ object XGBoost extends Serializable {
     try {
       val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
       val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
-      val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
-        tracker.getWorkerEnvs, checkpointRound, obj, eval, useExternalMemory, missing,
-        prevBooster)
+      val rabitEnv = tracker.getWorkerEnvs
+      val boostersAndMetrics = hasGroup match {
+        case true => {
+          val partitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
+          partitionedData.mapPartitions(labeledPointGroups => {
+            val watches = Watches.buildWatchesWithGroup(params,
+              removeMissingValuesWithGroup(labeledPointGroups, missing),
+              getCacheDirName(useExternalMemory))
+            buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
+              obj, eval, prevBooster)
+          }).cache()
+        }
+        case false => {
+          val partitionedData = repartitionForTraining(trainingData, nWorkers)
+          partitionedData.mapPartitions(labeledPoints => {
+            val watches = Watches.buildWatches(params,
+              removeMissingValues(labeledPoints, missing),
+              getCacheDirName(useExternalMemory))
+            buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
+              obj, eval, prevBooster)
+          }).cache()
+        }
+      }
       val sparkJobThread = new Thread() {
         override def run() {
           // force the job
@@ -278,13 +293,12 @@ object XGBoost extends Serializable {
           checkpointManager.updateCheckpoint(prevBooster)
         }
         (booster, metrics)
       } finally {
         tracker.stop()
       }
     }.last
   }

-
   private[spark] def repartitionForTraining(trainingData: RDD[XGBLabeledPoint], nWorkers: Int) = {
     if (trainingData.getNumPartitions != nWorkers) {
       logger.info(s"repartitioning training set to $nWorkers partitions")
@@ -294,6 +308,31 @@ object XGBoost extends Serializable {
     }
   }

+  private[spark] def repartitionForTrainingGroup(
+      trainingData: RDD[XGBLabeledPoint], nWorkers: Int): RDD[Array[XGBLabeledPoint]] = {
+    val normalGroups: RDD[Array[XGBLabeledPoint]] = trainingData.mapPartitions(
+      // LabeledPointGroupIterator returns (Boolean, Array[XGBLabeledPoint])
+      new LabeledPointGroupIterator(_)).filter(!_.isEdgeGroup).map(_.points)
+
+    // edge groups with partition id.
+    val edgeGroups: RDD[(Int, XGBLabeledPointGroup)] = trainingData.mapPartitions(
+      new LabeledPointGroupIterator(_)).filter(_.isEdgeGroup).map(
+      group => (TaskContext.getPartitionId(), group))
+
+    // group chunks from different partitions together by group id in XGBLabeledPoint.
+    // use groupBy instead of aggregateBy since all groups within a partition have unique group ids.
+    val stitchedGroups: RDD[Array[XGBLabeledPoint]] = edgeGroups.groupBy(_._2.groupId).map(
+      groups => {
+        val it: Iterable[(Int, XGBLabeledPointGroup)] = groups._2
+        // sort by partition id and merge the chunks of Array[XGBLabeledPoint] into one array
+        it.toArray.sortBy(_._1).map(_._2.points).flatten
+      })
+
+    val allGroups = normalGroups.union(stitchedGroups)
+    logger.info(s"repartitioning training group set to $nWorkers partitions")
+    allGroups.repartition(nWorkers)
+  }
+
   private def postTrackerReturnProcessing(
       trackerReturnVal: Int,
       distributedBoostersAndMetrics: RDD[(Booster, Map[String, Array[Float]])],
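The method above makes two passes over the same RDD: complete (non-edge) groups pass through untouched, while the partial groups at partition edges are keyed by group id, grouped across partitions, and re-assembled in partition order. Below is a Spark-free model of just the stitching step, using plain tuples instead of `XGBLabeledPointGroup`; the data and names are illustrative only.

```scala
object StitchSketch {
  def main(args: Array[String]): Unit = {
    // (partitionId, groupId, points): group 7 was split across partitions 0 and 1.
    val edgeChunks = Seq(
      (1, 7, Array("p7c", "p7d")), // head of partition 1
      (0, 7, Array("p7a", "p7b"))  // tail of partition 0
    )
    // Same idea as edgeGroups.groupBy(_._2.groupId): collect a group's chunks,
    // sort them by partition id, and flatten back into one contiguous group.
    val stitched = edgeChunks
      .groupBy(_._2)
      .map { case (_, chunks) => chunks.sortBy(_._1).flatMap(_._3) }
    stitched.foreach(g => println(g.mkString(","))) // prints: p7a,p7b,p7c,p7d
  }
}
```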
@@ -321,9 +360,9 @@ object XGBoost extends Serializable {
   }

 private class Watches private(
     val train: DMatrix,
     val test: DMatrix,
     private val cacheDirName: Option[String]) {

   def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
     .filter { case (_, matrix) => matrix.rowNum > 0 }
@@ -342,59 +381,152 @@ private class Watches private(

 private object Watches {

-  def buildGroups(groups: Seq[Int]): Seq[Int] = {
-    val output = mutable.ArrayBuffer.empty[Int]
-    var count = 1
-    var lastGroup = groups.head
-    for (group <- groups.tail) {
-      if (group != lastGroup) {
-        lastGroup = group
-        output += count
-        count = 1
-      } else {
-        count += 1
-      }
-    }
-    output += count
-    output
+  private def fromBaseMarginsToArray(baseMargins: Iterator[Float]): Option[Array[Float]] = {
+    val builder = new mutable.ArrayBuilder.ofFloat()
+    var nTotal = 0
+    var nUndefined = 0
+    while (baseMargins.hasNext) {
+      nTotal += 1
+      val baseMargin = baseMargins.next()
+      if (baseMargin.isNaN) {
+        nUndefined += 1 // don't waste space for all-NaNs.
+      } else {
+        builder += baseMargin
+      }
+    }
+    if (nUndefined == nTotal) {
+      None
+    } else if (nUndefined == 0) {
+      Some(builder.result())
+    } else {
+      throw new IllegalArgumentException(
+        s"Encountered a partition with $nUndefined NaN base margin values. " +
+        s"If you want to specify base margin, ensure all values are non-NaN.")
+    }
   }

-  def apply(
+  def buildWatches(
       params: Map[String, Any],
       labeledPoints: Iterator[XGBLabeledPoint],
-      baseMarginsOpt: Option[Array[Float]],
       cacheDirName: Option[String]): Watches = {
     val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0)
     val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
     val r = new Random(seed)
     val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
+    val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
+    val testBaseMargins = new mutable.ArrayBuilder.ofFloat
     val trainPoints = labeledPoints.filter { labeledPoint =>
       val accepted = r.nextDouble() <= trainTestRatio
       if (!accepted) {
         testPoints += labeledPoint
+        testBaseMargins += labeledPoint.baseMargin
+      } else {
+        trainBaseMargins += labeledPoint.baseMargin
       }
+      accepted
+    }
+    val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
+    val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
+
+    val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
+    val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator)
+    if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
+    if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
+
+    new Watches(trainMatrix, testMatrix, cacheDirName)
+  }
+
+  def buildWatchesWithGroup(
+      params: Map[String, Any],
+      labeledPointGroups: Iterator[Array[XGBLabeledPoint]],
+      cacheDirName: Option[String]): Watches = {
+    val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0)
+    val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
+    val r = new Random(seed)
+    val testPoints = mutable.ArrayBuilder.make[XGBLabeledPoint]
+    val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
+    val testBaseMargins = new mutable.ArrayBuilder.ofFloat
+    val trainGroups = new mutable.ArrayBuilder.ofInt
+    val testGroups = new mutable.ArrayBuilder.ofInt
+
+    val trainLabelPointGroups = labeledPointGroups.filter { labeledPointGroup =>
+      val accepted = r.nextDouble() <= trainTestRatio
+      if (!accepted) {
+        labeledPointGroup.foreach(labeledPoint => {
+          testPoints += labeledPoint
+          testBaseMargins += labeledPoint.baseMargin
+        })
+        testGroups += labeledPointGroup.length
+      } else {
+        labeledPointGroup.foreach(trainBaseMargins += _.baseMargin)
+        trainGroups += labeledPointGroup.length
+      }
       accepted
     }

-    val (trainIter1, trainIter2) = trainPoints.duplicate
-    val trainMatrix = new DMatrix(trainIter1, cacheDirName.map(_ + "/train").orNull)
-    val trainGroups = buildGroups(trainIter2.map(_.group).toSeq).toArray
-    trainMatrix.setGroup(trainGroups)
+    val trainPoints = trainLabelPointGroups.flatMap(_.iterator)
+    val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
+    trainMatrix.setGroup(trainGroups.result())

-    val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
+    val testMatrix = new DMatrix(testPoints.result().iterator, cacheDirName.map(_ + "/test").orNull)
     if (trainTestRatio < 1.0) {
-      val testGroups = buildGroups(testPoints.map(_.group)).toArray
-      testMatrix.setGroup(testGroups)
+      testMatrix.setGroup(testGroups.result())
     }

-    r.setSeed(seed)
-    for (baseMargins <- baseMarginsOpt) {
-      val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)
-      trainMatrix.setBaseMargin(trainMargin)
-      testMatrix.setBaseMargin(testMargin)
-    }
+    val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
+    val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator)
+    if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
+    if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)

     new Watches(trainMatrix, testMatrix, cacheDirName)
   }
 }
+
+/**
+ * Within each RDD partition, group the <code>XGBLabeledPoint</code>s by group id.
+ * The first and the last group in a partition may be incomplete, because a group
+ * can be split across partition boundaries.
+ * <code>LabeledPointGroupIterator</code> organizes the data as
+ * (isFirstGroup || isLastGroup, Array[XGBLabeledPoint]), so that the edge groups
+ * across partitions can be stitched together later.
+ * @param base collection of <code>XGBLabeledPoint</code>
+ */
+private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
+  extends AbstractIterator[XGBLabeledPointGroup] {
+
+  private var firstPointOfNextGroup: XGBLabeledPoint = null
+  private var isNewGroup = true
+
+  override def hasNext: Boolean = {
+    base.hasNext || isNewGroup
+  }
+
+  override def next(): XGBLabeledPointGroup = {
+    val builder = mutable.ArrayBuilder.make[XGBLabeledPoint]
+    var isFirstGroup = true
+    if (firstPointOfNextGroup != null) {
+      builder += firstPointOfNextGroup
+      isFirstGroup = false
+    }
+
+    isNewGroup = false
+    while (!isNewGroup && base.hasNext) {
+      val point = base.next()
+      val groupId = if (firstPointOfNextGroup != null) firstPointOfNextGroup.group else point.group
+      firstPointOfNextGroup = point
+      if (point.group == groupId) {
+        // add the point to the current group
+        builder += point
+      } else {
+        // the point starts a new group
+        isNewGroup = true
+      }
+    }
+
+    val isLastGroup = !isNewGroup
+    val result = builder.result()
+    val group = XGBLabeledPointGroup(result(0).group, result, isFirstGroup || isLastGroup)
+
+    group
+  }
+}
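The heart of `LabeledPointGroupIterator` is chunking a stream wherever the group id changes, with one point of lookahead (`firstPointOfNextGroup`). A minimal Spark-free sketch of the same chunking over plain group ids is shown below; the real iterator additionally flags the first and last chunk of each partition as edge groups so they can be stitched later.

```scala
import scala.collection.AbstractIterator

// Chunks a stream of group ids into one Array per consecutive run.
class GroupChunkIterator(base: Iterator[Int]) extends AbstractIterator[Array[Int]] {
  private val buffered = base.buffered
  override def hasNext: Boolean = buffered.hasNext
  override def next(): Array[Int] = {
    val head = buffered.head
    val chunk = scala.collection.mutable.ArrayBuffer.empty[Int]
    while (buffered.hasNext && buffered.head == head) chunk += buffered.next()
    chunk.toArray
  }
}

// new GroupChunkIterator(Iterator(1, 1, 2, 2, 2, 3)).toList
//   => List(Array(1, 1), Array(2, 2, 2), Array(3))
```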
@@ -196,7 +196,7 @@ class XGBoostClassifier (
     // All non-null param maps in XGBoostClassifier are in derivedXGBParamMap.
     val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
       $(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
-      $(missing))
+      $(missing), hasGroup = false)
     val model = new XGBoostClassificationModel(uid, _numClasses, _booster)
     val summary = XGBoostTrainingSummary(_metrics)
     model.setSummary(summary)
@@ -517,3 +517,4 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] {
       }
     }
   }
+
@@ -191,7 +191,7 @@ class XGBoostRegressor (
     // All non-null param maps in XGBoostRegressor are in derivedXGBParamMap.
     val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
       $(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
-      $(missing))
+      $(missing), hasGroup = group != lit(-1))
     val model = new XGBoostRegressionModel(uid, _booster)
     val summary = XGBoostTrainingSummary(_metrics)
     model.setSummary(summary)
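Unlike the classifier, which always passes `hasGroup = false`, the regressor derives the flag from its group column: `group` defaults to `lit(-1)` when no column is configured, so grouped training is enabled exactly when one is set. A hedged caller-side sketch, assuming a DataFrame `df` with a `queryId` group column and a `setGroupCol` setter on the ranking-capable regressor (both assumptions for illustration, not part of this commit):

```scala
// df: DataFrame with features, label, and a "queryId" group column (assumed).
val ranker = new XGBoostRegressor(Map(
  "objective" -> "rank:pairwise",
  "num_round" -> 10,
  "num_workers" -> 2))
ranker.setGroupCol("queryId") // with a group column set, hasGroup becomes true
val model = ranker.fit(df)    // training now keeps each query group intact
```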
@@ -173,7 +173,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
     val training2 = training1.withColumn("margin", functions.rand())
     val test = buildDataFrame(Classification.test)
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "binary:logistic", "test_train_split" -> "0.5",
+      "objective" -> "binary:logistic", "train_test_ratio" -> "1.0",
       "num_round" -> 5, "num_workers" -> numWorkers)

     val xgb = new XGBoostClassifier(paramMap)
@@ -19,6 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.nio.file.Files
 import java.util.concurrent.LinkedBlockingDeque
 import ml.dmlc.xgboost4j.java.Rabit
+import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import ml.dmlc.xgboost4j.scala.DMatrix
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
@@ -71,18 +72,16 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     assert(collectedAllReduceResults.poll().sameElements(maxVec))
   }

-  test("build RDD containing boosters with the specified worker number") {
+  test("distributed training with the specified worker number") {
     val trainingRDD = sc.parallelize(Classification.train)
-    val partitionedRDD = XGBoost.repartitionForTraining(trainingRDD, 2)
-    val boosterRDD = XGBoost.buildDistributedBoosters(
-      partitionedRDD,
+    val (booster, metrics) = XGBoost.trainDistributed(
+      trainingRDD,
       List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
         "objective" -> "binary:logistic").toMap,
-      new java.util.HashMap[String, String](),
-      round = 5, eval = null, obj = null, useExternalMemory = true,
-      missing = Float.NaN, prevBooster = null)
-    val boosterCount = boosterRDD.count()
-    assert(boosterCount === 2)
+      round = 5, nWorkers = numWorkers, eval = null, obj = null, useExternalMemory = false,
+      hasGroup = false, missing = Float.NaN)
+    assert(booster != null)
   }

   test("training with external memory cache") {
|||||||
assert(error(prevModel._booster) > error(nextModel._booster))
|
assert(error(prevModel._booster) > error(nextModel._booster))
|
||||||
assert(error(nextModel._booster) < 0.1)
|
assert(error(nextModel._booster) < 0.1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("repartitionForTrainingGroup with group data") {
|
||||||
|
// test different splits to cover the corner cases.
|
||||||
|
for (split <- 1 to 20) {
|
||||||
|
val trainingRDD = sc.parallelize(Ranking.train, split)
|
||||||
|
val traingGroupsRDD = XGBoost.repartitionForTrainingGroup(trainingRDD, 4)
|
||||||
|
val trainingGroups: Array[Array[XGBLabeledPoint]] = traingGroupsRDD.collect()
|
||||||
|
// check the the order of the groups with group id.
|
||||||
|
// Ranking.train has 20 groups
|
||||||
|
assert(trainingGroups.length == 20)
|
||||||
|
|
||||||
|
// compare all points
|
||||||
|
val allPoints = trainingGroups.sortBy(_(0).group).flatten
|
||||||
|
assert(allPoints.length == Ranking.train.size)
|
||||||
|
for (i <- 0 to Ranking.train.size - 1) {
|
||||||
|
assert(allPoints(i).group == Ranking.train(i).group)
|
||||||
|
assert(allPoints(i).label == Ranking.train(i).label)
|
||||||
|
assert(allPoints(i).values.sameElements(Ranking.train(i).values))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("distributed training with group data") {
|
||||||
|
val trainingRDD = sc.parallelize(Ranking.train, 2)
|
||||||
|
val (booster, metrics) = XGBoost.trainDistributed(
|
||||||
|
trainingRDD,
|
||||||
|
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic").toMap,
|
||||||
|
round = 5, nWorkers = numWorkers, eval = null, obj = null, useExternalMemory = false,
|
||||||
|
hasGroup = true, missing = Float.NaN)
|
||||||
|
|
||||||
|
assert(booster != null)
|
||||||
|
}
|
||||||
}
|
}
|
||||||