|
|
|
|
@@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
|
|
|
|
|
import java.io.File
|
|
|
|
|
import java.nio.file.Files
|
|
|
|
|
|
|
|
|
|
import scala.collection.mutable
|
|
|
|
|
import scala.collection.{AbstractIterator, mutable}
|
|
|
|
|
import scala.util.Random
|
|
|
|
|
|
|
|
|
|
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
|
|
|
|
@@ -53,6 +53,17 @@ object TrackerConf {
|
|
|
|
|
def apply(): TrackerConf = TrackerConf(0L, "python")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Traing data group in a RDD partition.
|
|
|
|
|
* @param groupId The group id
|
|
|
|
|
* @param points Array of XGBLabeledPoint within the same group.
|
|
|
|
|
* @param isEdgeGroup whether it is a frist or last group in a RDD partition.
|
|
|
|
|
*/
|
|
|
|
|
private[spark] case class XGBLabeledPointGroup(
|
|
|
|
|
groupId: Int,
|
|
|
|
|
points: Array[XGBLabeledPoint],
|
|
|
|
|
isEdgeGroup: Boolean)
|
|
|
|
|
|
|
|
|
|
object XGBoost extends Serializable {
|
|
|
|
|
private val logger = LogFactory.getLog("XGBoostSpark")
|
|
|
|
|
|
|
|
|
|
@@ -74,78 +85,62 @@ object XGBoost extends Serializable {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private def fromBaseMarginsToArray(baseMargins: Iterator[Float]): Option[Array[Float]] = {
|
|
|
|
|
val builder = new mutable.ArrayBuilder.ofFloat()
|
|
|
|
|
var nTotal = 0
|
|
|
|
|
var nUndefined = 0
|
|
|
|
|
while (baseMargins.hasNext) {
|
|
|
|
|
nTotal += 1
|
|
|
|
|
val baseMargin = baseMargins.next()
|
|
|
|
|
if (baseMargin.isNaN) {
|
|
|
|
|
nUndefined += 1 // don't waste space for all-NaNs.
|
|
|
|
|
} else {
|
|
|
|
|
builder += baseMargin
|
|
|
|
|
private def removeMissingValuesWithGroup(
|
|
|
|
|
xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
|
|
|
|
|
missing: Float): Iterator[Array[XGBLabeledPoint]] = {
|
|
|
|
|
if (!missing.isNaN) {
|
|
|
|
|
xgbLabelPointGroups.map {
|
|
|
|
|
labeledPoints => XGBoost.removeMissingValues(labeledPoints.iterator, missing).toArray
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (nUndefined == nTotal) {
|
|
|
|
|
None
|
|
|
|
|
} else if (nUndefined == 0) {
|
|
|
|
|
Some(builder.result())
|
|
|
|
|
} else {
|
|
|
|
|
throw new IllegalArgumentException(
|
|
|
|
|
s"Encountered a partition with $nUndefined NaN base margin values. " +
|
|
|
|
|
s"If you want to specify base margin, ensure all values are non-NaN.")
|
|
|
|
|
xgbLabelPointGroups
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private[spark] def buildDistributedBoosters(
|
|
|
|
|
data: RDD[XGBLabeledPoint],
|
|
|
|
|
private def getCacheDirName(useExternalMemory: Boolean): Option[String] = {
|
|
|
|
|
val taskId = TaskContext.getPartitionId().toString
|
|
|
|
|
if (useExternalMemory) {
|
|
|
|
|
val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
|
|
|
|
|
Some(dir.toAbsolutePath.toString)
|
|
|
|
|
} else {
|
|
|
|
|
None
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private def buildDistributedBooster(
|
|
|
|
|
watches: Watches,
|
|
|
|
|
params: Map[String, Any],
|
|
|
|
|
rabitEnv: java.util.Map[String, String],
|
|
|
|
|
round: Int,
|
|
|
|
|
obj: ObjectiveTrait,
|
|
|
|
|
eval: EvalTrait,
|
|
|
|
|
useExternalMemory: Boolean,
|
|
|
|
|
missing: Float,
|
|
|
|
|
prevBooster: Booster
|
|
|
|
|
): RDD[(Booster, Map[String, Array[Float]])] = {
|
|
|
|
|
prevBooster: Booster)
|
|
|
|
|
: Iterator[(Booster, Map[String, Array[Float]])] = {
|
|
|
|
|
|
|
|
|
|
val partitionedBaseMargin = data.map(_.baseMargin)
|
|
|
|
|
// to workaround the empty partitions in training dataset,
|
|
|
|
|
// this might not be the best efficient implementation, see
|
|
|
|
|
// (https://github.com/dmlc/xgboost/issues/1277)
|
|
|
|
|
data.zipPartitions(partitionedBaseMargin) { (labeledPoints, baseMargins) =>
|
|
|
|
|
if (labeledPoints.isEmpty) {
|
|
|
|
|
throw new XGBoostError(
|
|
|
|
|
s"detected an empty partition in the training data, partition ID:" +
|
|
|
|
|
s" ${TaskContext.getPartitionId()}")
|
|
|
|
|
}
|
|
|
|
|
val taskId = TaskContext.getPartitionId().toString
|
|
|
|
|
val cacheDirName = if (useExternalMemory) {
|
|
|
|
|
val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
|
|
|
|
|
Some(dir.toAbsolutePath.toString)
|
|
|
|
|
} else {
|
|
|
|
|
None
|
|
|
|
|
}
|
|
|
|
|
rabitEnv.put("DMLC_TASK_ID", taskId)
|
|
|
|
|
Rabit.init(rabitEnv)
|
|
|
|
|
val watches = Watches(params,
|
|
|
|
|
removeMissingValues(labeledPoints, missing),
|
|
|
|
|
fromBaseMarginsToArray(baseMargins), cacheDirName)
|
|
|
|
|
if (watches.train.rowNum == 0) {
|
|
|
|
|
throw new XGBoostError(
|
|
|
|
|
s"detected an empty partition in the training data, partition ID:" +
|
|
|
|
|
s" ${TaskContext.getPartitionId()}")
|
|
|
|
|
}
|
|
|
|
|
val taskId = TaskContext.getPartitionId().toString
|
|
|
|
|
rabitEnv.put("DMLC_TASK_ID", taskId)
|
|
|
|
|
Rabit.init(rabitEnv)
|
|
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
|
|
|
|
|
.map(_.toString.toInt).getOrElse(0)
|
|
|
|
|
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
|
|
|
|
|
val booster = SXGBoost.train(watches.train, params, round,
|
|
|
|
|
watches.toMap, metrics, obj, eval,
|
|
|
|
|
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
|
|
|
|
|
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
|
|
|
|
|
} finally {
|
|
|
|
|
Rabit.shutdown()
|
|
|
|
|
watches.delete()
|
|
|
|
|
}
|
|
|
|
|
}.cache()
|
|
|
|
|
try {
|
|
|
|
|
val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
|
|
|
|
|
.map(_.toString.toInt).getOrElse(0)
|
|
|
|
|
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
|
|
|
|
|
val booster = SXGBoost.train(watches.train, params, round,
|
|
|
|
|
watches.toMap, metrics, obj, eval,
|
|
|
|
|
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
|
|
|
|
|
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
|
|
|
|
|
} finally {
|
|
|
|
|
Rabit.shutdown()
|
|
|
|
|
watches.delete()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private def overrideParamsAccordingToTaskCPUs(
|
|
|
|
|
@@ -219,7 +214,8 @@ object XGBoost extends Serializable {
|
|
|
|
|
obj: ObjectiveTrait = null,
|
|
|
|
|
eval: EvalTrait = null,
|
|
|
|
|
useExternalMemory: Boolean = false,
|
|
|
|
|
missing: Float = Float.NaN): (Booster, Map[String, Array[Float]]) = {
|
|
|
|
|
missing: Float = Float.NaN,
|
|
|
|
|
hasGroup: Boolean = false): (Booster, Map[String, Array[Float]]) = {
|
|
|
|
|
validateSparkSslConf(trainingData.context)
|
|
|
|
|
if (params.contains("tree_method")) {
|
|
|
|
|
require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
|
|
|
|
|
@@ -244,7 +240,6 @@ object XGBoost extends Serializable {
|
|
|
|
|
" an instance of Long.")
|
|
|
|
|
}
|
|
|
|
|
val (checkpointPath, checkpointInterval) = CheckpointManager.extractParams(params)
|
|
|
|
|
val partitionedData = repartitionForTraining(trainingData, nWorkers)
|
|
|
|
|
|
|
|
|
|
val sc = trainingData.sparkContext
|
|
|
|
|
val checkpointManager = new CheckpointManager(sc, checkpointPath)
|
|
|
|
|
@@ -258,9 +253,29 @@ object XGBoost extends Serializable {
|
|
|
|
|
try {
|
|
|
|
|
val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
|
|
|
|
|
val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
|
|
|
|
|
val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
|
|
|
|
|
tracker.getWorkerEnvs, checkpointRound, obj, eval, useExternalMemory, missing,
|
|
|
|
|
prevBooster)
|
|
|
|
|
val rabitEnv = tracker.getWorkerEnvs
|
|
|
|
|
val boostersAndMetrics = hasGroup match {
|
|
|
|
|
case true => {
|
|
|
|
|
val partitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
|
|
|
|
|
partitionedData.mapPartitions(labeledPointGroups => {
|
|
|
|
|
val watches = Watches.buildWatchesWithGroup(params,
|
|
|
|
|
removeMissingValuesWithGroup(labeledPointGroups, missing),
|
|
|
|
|
getCacheDirName(useExternalMemory))
|
|
|
|
|
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
|
|
|
|
|
obj, eval, prevBooster)
|
|
|
|
|
}).cache()
|
|
|
|
|
}
|
|
|
|
|
case false => {
|
|
|
|
|
val partitionedData = repartitionForTraining(trainingData, nWorkers)
|
|
|
|
|
partitionedData.mapPartitions(labeledPoints => {
|
|
|
|
|
val watches = Watches.buildWatches(params,
|
|
|
|
|
removeMissingValues(labeledPoints, missing),
|
|
|
|
|
getCacheDirName(useExternalMemory))
|
|
|
|
|
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
|
|
|
|
|
obj, eval, prevBooster)
|
|
|
|
|
}).cache()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
val sparkJobThread = new Thread() {
|
|
|
|
|
override def run() {
|
|
|
|
|
// force the job
|
|
|
|
|
@@ -278,13 +293,12 @@ object XGBoost extends Serializable {
|
|
|
|
|
checkpointManager.updateCheckpoint(prevBooster)
|
|
|
|
|
}
|
|
|
|
|
(booster, metrics)
|
|
|
|
|
} finally {
|
|
|
|
|
tracker.stop()
|
|
|
|
|
}
|
|
|
|
|
} finally {
|
|
|
|
|
tracker.stop()
|
|
|
|
|
}
|
|
|
|
|
}.last
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private[spark] def repartitionForTraining(trainingData: RDD[XGBLabeledPoint], nWorkers: Int) = {
|
|
|
|
|
if (trainingData.getNumPartitions != nWorkers) {
|
|
|
|
|
logger.info(s"repartitioning training set to $nWorkers partitions")
|
|
|
|
|
@@ -294,6 +308,31 @@ object XGBoost extends Serializable {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private[spark] def repartitionForTrainingGroup(
|
|
|
|
|
trainingData: RDD[XGBLabeledPoint], nWorkers: Int): RDD[Array[XGBLabeledPoint]] = {
|
|
|
|
|
val normalGroups: RDD[Array[XGBLabeledPoint]] = trainingData.mapPartitions(
|
|
|
|
|
// LabeledPointGroupIterator returns (Boolean, Array[XGBLabeledPoint])
|
|
|
|
|
new LabeledPointGroupIterator(_)).filter(!_.isEdgeGroup).map(_.points)
|
|
|
|
|
|
|
|
|
|
// edge groups with partition id.
|
|
|
|
|
val edgeGroups: RDD[(Int, XGBLabeledPointGroup)] = trainingData.mapPartitions(
|
|
|
|
|
new LabeledPointGroupIterator(_)).filter(_.isEdgeGroup).map(
|
|
|
|
|
group => (TaskContext.getPartitionId(), group))
|
|
|
|
|
|
|
|
|
|
// group chunks from different partitions together by group id in XGBLabeledPoint.
|
|
|
|
|
// use groupBy instead of aggregateBy since all groups within a partition have unique groud ids.
|
|
|
|
|
val stitchedGroups: RDD[Array[XGBLabeledPoint]] = edgeGroups.groupBy(_._2.groupId).map(
|
|
|
|
|
groups => {
|
|
|
|
|
val it: Iterable[(Int, XGBLabeledPointGroup)] = groups._2
|
|
|
|
|
// sorted by partition id and merge list of Array[XGBLabeledPoint] into one array
|
|
|
|
|
it.toArray.sortBy(_._1).map(_._2.points).flatten
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
var allGroups = normalGroups.union(stitchedGroups)
|
|
|
|
|
logger.info(s"repartitioning training group set to $nWorkers partitions")
|
|
|
|
|
allGroups.repartition(nWorkers)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private def postTrackerReturnProcessing(
|
|
|
|
|
trackerReturnVal: Int,
|
|
|
|
|
distributedBoostersAndMetrics: RDD[(Booster, Map[String, Array[Float]])],
|
|
|
|
|
@@ -321,9 +360,9 @@ object XGBoost extends Serializable {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private class Watches private(
|
|
|
|
|
val train: DMatrix,
|
|
|
|
|
val test: DMatrix,
|
|
|
|
|
private val cacheDirName: Option[String]) {
|
|
|
|
|
val train: DMatrix,
|
|
|
|
|
val test: DMatrix,
|
|
|
|
|
private val cacheDirName: Option[String]) {
|
|
|
|
|
|
|
|
|
|
def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
|
|
|
|
|
.filter { case (_, matrix) => matrix.rowNum > 0 }
|
|
|
|
|
@@ -342,59 +381,152 @@ private class Watches private(
|
|
|
|
|
|
|
|
|
|
private object Watches {
|
|
|
|
|
|
|
|
|
|
def buildGroups(groups: Seq[Int]): Seq[Int] = {
|
|
|
|
|
val output = mutable.ArrayBuffer.empty[Int]
|
|
|
|
|
var count = 1
|
|
|
|
|
var lastGroup = groups.head
|
|
|
|
|
for (group <- groups.tail) {
|
|
|
|
|
if (group != lastGroup) {
|
|
|
|
|
lastGroup = group
|
|
|
|
|
output += count
|
|
|
|
|
count = 1
|
|
|
|
|
private def fromBaseMarginsToArray(baseMargins: Iterator[Float]): Option[Array[Float]] = {
|
|
|
|
|
val builder = new mutable.ArrayBuilder.ofFloat()
|
|
|
|
|
var nTotal = 0
|
|
|
|
|
var nUndefined = 0
|
|
|
|
|
while (baseMargins.hasNext) {
|
|
|
|
|
nTotal += 1
|
|
|
|
|
val baseMargin = baseMargins.next()
|
|
|
|
|
if (baseMargin.isNaN) {
|
|
|
|
|
nUndefined += 1 // don't waste space for all-NaNs.
|
|
|
|
|
} else {
|
|
|
|
|
count += 1
|
|
|
|
|
builder += baseMargin
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
output += count
|
|
|
|
|
output
|
|
|
|
|
if (nUndefined == nTotal) {
|
|
|
|
|
None
|
|
|
|
|
} else if (nUndefined == 0) {
|
|
|
|
|
Some(builder.result())
|
|
|
|
|
} else {
|
|
|
|
|
throw new IllegalArgumentException(
|
|
|
|
|
s"Encountered a partition with $nUndefined NaN base margin values. " +
|
|
|
|
|
s"If you want to specify base margin, ensure all values are non-NaN.")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def apply(
|
|
|
|
|
def buildWatches(
|
|
|
|
|
params: Map[String, Any],
|
|
|
|
|
labeledPoints: Iterator[XGBLabeledPoint],
|
|
|
|
|
baseMarginsOpt: Option[Array[Float]],
|
|
|
|
|
cacheDirName: Option[String]): Watches = {
|
|
|
|
|
val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0)
|
|
|
|
|
val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
|
|
|
|
|
val r = new Random(seed)
|
|
|
|
|
val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
|
|
|
|
|
val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
|
|
|
|
|
val testBaseMargins = new mutable.ArrayBuilder.ofFloat
|
|
|
|
|
val trainPoints = labeledPoints.filter { labeledPoint =>
|
|
|
|
|
val accepted = r.nextDouble() <= trainTestRatio
|
|
|
|
|
if (!accepted) {
|
|
|
|
|
testPoints += labeledPoint
|
|
|
|
|
testBaseMargins += labeledPoint.baseMargin
|
|
|
|
|
} else {
|
|
|
|
|
trainBaseMargins += labeledPoint.baseMargin
|
|
|
|
|
}
|
|
|
|
|
accepted
|
|
|
|
|
}
|
|
|
|
|
val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
|
|
|
|
|
val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
|
|
|
|
|
|
|
|
|
|
val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
|
|
|
|
|
val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator)
|
|
|
|
|
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
|
|
|
|
|
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
|
|
|
|
|
|
|
|
|
|
new Watches(trainMatrix, testMatrix, cacheDirName)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def buildWatchesWithGroup(
|
|
|
|
|
params: Map[String, Any],
|
|
|
|
|
labeledPointGroups: Iterator[Array[XGBLabeledPoint]],
|
|
|
|
|
cacheDirName: Option[String]): Watches = {
|
|
|
|
|
val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0)
|
|
|
|
|
val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
|
|
|
|
|
val r = new Random(seed)
|
|
|
|
|
val testPoints = mutable.ArrayBuilder.make[XGBLabeledPoint]
|
|
|
|
|
val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
|
|
|
|
|
val testBaseMargins = new mutable.ArrayBuilder.ofFloat
|
|
|
|
|
val trainGroups = new mutable.ArrayBuilder.ofInt
|
|
|
|
|
val testGroups = new mutable.ArrayBuilder.ofInt
|
|
|
|
|
|
|
|
|
|
val trainLabelPointGroups = labeledPointGroups.filter { labeledPointGroup =>
|
|
|
|
|
val accepted = r.nextDouble() <= trainTestRatio
|
|
|
|
|
if (!accepted) {
|
|
|
|
|
labeledPointGroup.foreach(labeledPoint => {
|
|
|
|
|
testPoints += labeledPoint
|
|
|
|
|
testBaseMargins += labeledPoint.baseMargin
|
|
|
|
|
})
|
|
|
|
|
testGroups += labeledPointGroup.length
|
|
|
|
|
} else {
|
|
|
|
|
labeledPointGroup.foreach(trainBaseMargins += _.baseMargin)
|
|
|
|
|
trainGroups += labeledPointGroup.length
|
|
|
|
|
}
|
|
|
|
|
accepted
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
val (trainIter1, trainIter2) = trainPoints.duplicate
|
|
|
|
|
val trainMatrix = new DMatrix(trainIter1, cacheDirName.map(_ + "/train").orNull)
|
|
|
|
|
val trainGroups = buildGroups(trainIter2.map(_.group).toSeq).toArray
|
|
|
|
|
trainMatrix.setGroup(trainGroups)
|
|
|
|
|
val trainPoints = trainLabelPointGroups.flatMap(_.iterator)
|
|
|
|
|
val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
|
|
|
|
|
trainMatrix.setGroup(trainGroups.result())
|
|
|
|
|
|
|
|
|
|
val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
|
|
|
|
|
val testMatrix = new DMatrix(testPoints.result().iterator, cacheDirName.map(_ + "/test").orNull)
|
|
|
|
|
if (trainTestRatio < 1.0) {
|
|
|
|
|
val testGroups = buildGroups(testPoints.map(_.group)).toArray
|
|
|
|
|
testMatrix.setGroup(testGroups)
|
|
|
|
|
testMatrix.setGroup(testGroups.result())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
r.setSeed(seed)
|
|
|
|
|
for (baseMargins <- baseMarginsOpt) {
|
|
|
|
|
val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)
|
|
|
|
|
trainMatrix.setBaseMargin(trainMargin)
|
|
|
|
|
testMatrix.setBaseMargin(testMargin)
|
|
|
|
|
}
|
|
|
|
|
val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
|
|
|
|
|
val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator)
|
|
|
|
|
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
|
|
|
|
|
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
|
|
|
|
|
|
|
|
|
|
new Watches(trainMatrix, testMatrix, cacheDirName)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Within each RDD partition, group the <code>XGBLabeledPoint</code> by group id.</p>
|
|
|
|
|
* And the first and the last groups may not have all the items due to the data partition.
|
|
|
|
|
* <code>LabeledPointGroupIterator</code> orginaizes data in a tuple format:
|
|
|
|
|
* (isFistGroup || isLastGroup, Array[XGBLabeledPoint]).</p>
|
|
|
|
|
* The edge groups across partitions can be stitched together later.
|
|
|
|
|
* @param base collection of <code>XGBLabeledPoint</code>
|
|
|
|
|
*/
|
|
|
|
|
private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
|
|
|
|
|
extends AbstractIterator[XGBLabeledPointGroup] {
|
|
|
|
|
|
|
|
|
|
private var firstPointOfNextGroup: XGBLabeledPoint = null
|
|
|
|
|
private var isNewGroup = true
|
|
|
|
|
|
|
|
|
|
override def hasNext: Boolean = {
|
|
|
|
|
return base.hasNext || isNewGroup
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
override def next(): XGBLabeledPointGroup = {
|
|
|
|
|
val builder = mutable.ArrayBuilder.make[XGBLabeledPoint]
|
|
|
|
|
var isFirstGroup = true
|
|
|
|
|
if (firstPointOfNextGroup != null) {
|
|
|
|
|
builder += firstPointOfNextGroup
|
|
|
|
|
isFirstGroup = false
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
isNewGroup = false
|
|
|
|
|
while (!isNewGroup && base.hasNext) {
|
|
|
|
|
val point = base.next()
|
|
|
|
|
val groupId = if (firstPointOfNextGroup != null) firstPointOfNextGroup.group else point.group
|
|
|
|
|
firstPointOfNextGroup = point
|
|
|
|
|
if (point.group == groupId) {
|
|
|
|
|
// add to current group
|
|
|
|
|
builder += point
|
|
|
|
|
} else {
|
|
|
|
|
// start a new group
|
|
|
|
|
isNewGroup = true
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
val isLastGroup = !isNewGroup
|
|
|
|
|
val result = builder.result()
|
|
|
|
|
val group = XGBLabeledPointGroup(result(0).group, result, isFirstGroup || isLastGroup)
|
|
|
|
|
|
|
|
|
|
group
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|