[jvm-packages] Rework the train pipeline (#7401)
1. Add PreXGBoost to build RDD[Watches] from Dataset
2. Feed the RDD[Watches] built by PreXGBoost to XGBoost to train
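For orientation, a minimal sketch of the reworked flow as an estimator drives it (estimator, dataset and paramMap are placeholders; the two calls are the ones introduced in this diff):

// Step 1: PreXGBoost turns the Dataset[_] into a function that builds RDD[Watches]
val buildTrainingData = PreXGBoost.buildDatasetToRDD(estimator, dataset, paramMap)
// Step 2: XGBoost.trainDistributed consumes that function to run distributed training
val (booster, metrics) = XGBoost.trainDistributed(
  dataset.sparkSession.sparkContext, buildTrainingData, paramMap)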
parent 8df0a252b7
commit cb685607b2
@@ -16,6 +16,8 @@
package ml.dmlc.xgboost4j.scala.spark

import scala.collection.mutable

import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

import org.apache.spark.HashPartitioner
@@ -99,43 +101,129 @@ object DataUtils extends Serializable {
}
}

/** Packed parameters used by [[convertDataFrameToXGBLabeledPointRDDs]] */
private[spark] case class PackedParams(labelCol: Column,
featuresCol: Column,
weight: Column,
baseMargin: Column,
group: Option[Column],
numWorkers: Int,
deterministicPartition: Boolean)

/**
* convertDataFrameToXGBLabeledPointRDDs converts DataFrames to an array of RDD[XGBLabeledPoint]
*
* First, it converts each instance of the input into an XGBLabeledPoint.
* Second, it repartitions the RDD to the number of workers.
*
*/
private[spark] def convertDataFrameToXGBLabeledPointRDDs(
labelCol: Column,
featuresCol: Column,
weight: Column,
baseMargin: Column,
group: Option[Column],
numWorkers: Int,
deterministicPartition: Boolean,
dataFrames: DataFrame*): Array[RDD[XGBLabeledPoint]] = {
val selectedColumns = group.map(groupCol => Seq(labelCol.cast(FloatType),
featuresCol,
weight.cast(FloatType),
groupCol.cast(IntegerType),
baseMargin.cast(FloatType))).getOrElse(Seq(labelCol.cast(FloatType),
featuresCol,
weight.cast(FloatType),
baseMargin.cast(FloatType)))
val arrayOfRDDs = dataFrames.toArray.map {
df => df.select(selectedColumns: _*).rdd.map {
case row @ Row(label: Float, features: Vector, weight: Float, group: Int,
baseMargin: Float) =>
val (size, indices, values) = features match {
case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
packedParams: PackedParams,
dataFrames: DataFrame*): Array[RDD[XGBLabeledPoint]] = {

packedParams match {
case j @ PackedParams(labelCol, featuresCol, weight, baseMargin, group, numWorkers,
deterministicPartition) =>
val selectedColumns = group.map(groupCol => Seq(labelCol.cast(FloatType),
featuresCol,
weight.cast(FloatType),
groupCol.cast(IntegerType),
baseMargin.cast(FloatType))).getOrElse(Seq(labelCol.cast(FloatType),
featuresCol,
weight.cast(FloatType),
baseMargin.cast(FloatType)))
val arrayOfRDDs = dataFrames.toArray.map {
df => df.select(selectedColumns: _*).rdd.map {
case row @ Row(label: Float, features: Vector, weight: Float, group: Int,
baseMargin: Float) =>
val (size, indices, values) = features match {
case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
}
val xgbLp = XGBLabeledPoint(label, size, indices, values, weight, group, baseMargin)
attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
case row @ Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
val (size, indices, values) = features match {
case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
}
val xgbLp = XGBLabeledPoint(label, size, indices, values, weight,
baseMargin = baseMargin)
attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
}
val xgbLp = XGBLabeledPoint(label, size, indices, values, weight, group, baseMargin)
attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
case row @ Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
val (size, indices, values) = features match {
case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
}
val xgbLp = XGBLabeledPoint(label, size, indices, values, weight, baseMargin = baseMargin)
attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
}
}
repartitionRDDs(deterministicPartition, numWorkers, arrayOfRDDs)

case _ => throw new IllegalArgumentException("Wrong PackedParams") // never reach here
}
repartitionRDDs(deterministicPartition, numWorkers, arrayOfRDDs)

}

private[spark] def processMissingValues(
xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float,
allowNonZeroMissing: Boolean): Iterator[XGBLabeledPoint] = {
if (!missing.isNaN) {
removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissing),
missing, (v: Float) => v != missing)
} else {
removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissing),
missing, (v: Float) => !v.isNaN)
}
}
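A hedged illustration of what processMissingValues does (the point below is made up; the field order follows the XGBLabeledPoint constructor calls above, and the call assumes code in the same ml.dmlc.xgboost4j.scala.spark package, as the model predict methods later in this diff do):

// Dense vector (indices == null) with 0.0f declared as the missing value
val p = XGBLabeledPoint(1.0f, 3, null, Array(0.5f, 0.0f, 2.0f), 1.0f, -1, Float.NaN)
val cleaned = DataUtils.processMissingValues(Iterator(p), missing = 0.0f,
  allowNonZeroMissing = false).next()
// cleaned keeps only the non-missing entries: indices = Array(0, 2), values = Array(0.5f, 2.0f)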
private[spark] def processMissingValuesWithGroup(
xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
missing: Float,
allowNonZeroMissing: Boolean): Iterator[Array[XGBLabeledPoint]] = {
if (!missing.isNaN) {
xgbLabelPointGroups.map {
labeledPoints => processMissingValues(
labeledPoints.iterator,
missing,
allowNonZeroMissing
).toArray
}
} else {
xgbLabelPointGroups
}
}

private def removeMissingValues(
xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float,
keepCondition: Float => Boolean): Iterator[XGBLabeledPoint] = {
xgbLabelPoints.map { labeledPoint =>
val indicesBuilder = new mutable.ArrayBuilder.ofInt()
val valuesBuilder = new mutable.ArrayBuilder.ofFloat()
for ((value, i) <- labeledPoint.values.zipWithIndex if keepCondition(value)) {
indicesBuilder += (if (labeledPoint.indices == null) i else labeledPoint.indices(i))
valuesBuilder += value
}
labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result())
}
}

private def verifyMissingSetting(
xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float,
allowNonZeroMissing: Boolean): Iterator[XGBLabeledPoint] = {
if (missing != 0.0f && !allowNonZeroMissing) {
xgbLabelPoints.map(labeledPoint => {
if (labeledPoint.indices != null) {
throw new RuntimeException(s"you can only specify missing value as 0.0 (the currently" +
s" set value $missing) when you have SparseVector or Empty vector as your feature" +
s" format. If you didn't use Spark's VectorAssembler class to build your feature " +
s"vector but instead did so in a way that preserves zeros in your feature vector " +
s"you can avoid this check by using the 'allow_non_zero_for_missing parameter'" +
s" (only use if you know what you are doing)")
}
labeledPoint
})
} else {
xgbLabelPoints
}
}

}
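The DeterministicPartitioningSuite changes near the end of this diff exercise the new PackedParams-based signature; a condensed sketch (the column names, worker count and trainDF are assumptions):

val rdds: Array[RDD[XGBLabeledPoint]] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
  PackedParams(col("label"), col("features"), lit(1.0), lit(Float.NaN),
    None, numWorkers = 4, deterministicPartition = true),
  trainDF)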
@@ -0,0 +1,421 @@
/*
Copyright (c) 2021 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark

import java.nio.file.Files

import scala.collection.{AbstractIterator, mutable}

import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, lit}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.logging.LogFactory

import org.apache.spark.TaskContext
import org.apache.spark.ml.Estimator
import org.apache.spark.storage.StorageLevel

/**
* PreXGBoost converts Dataset[_] to RDD[[Watches]]
*/
object PreXGBoost {
private val logger = LogFactory.getLog("XGBoostSpark")

private lazy val defaultBaseMarginColumn = lit(Float.NaN)
private lazy val defaultWeightColumn = lit(1.0)
private lazy val defaultGroupColumn = lit(-1)

/**
* Convert the Dataset[_] to RDD[Watches] which will be fed to XGBoost
*
* @param estimator supports XGBoostClassifier and XGBoostRegressor
* @param dataset the training data
* @param params all user defined and defaulted params
* @return [[XGBoostExecutionParams]] => (RDD[[Watches]], Option[ RDD[_] ])
* RDD[Watches] will be used as the training input
* Option[RDD[_]\] is the optional cached RDD
*/
def buildDatasetToRDD(
estimator: Estimator[_],
dataset: Dataset[_],
params: Map[String, Any]): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]) = {

val (packedParams, evalSet) = estimator match {
case est: XGBoostEstimatorCommon =>
// get weight column, if weight is not defined, default to lit(1.0)
val weight = if (!est.isDefined(est.weightCol) || est.getWeightCol.isEmpty) {
defaultWeightColumn
} else col(est.getWeightCol)

// get base-margin column, if base-margin is not defined, default to lit(Float.NaN)
val baseMargin = if (!est.isDefined(est.baseMarginCol) || est.getBaseMarginCol.isEmpty) {
defaultBaseMarginColumn
} else col(est.getBaseMarginCol)

val group = est match {
case regressor: XGBoostRegressor =>
// get group column, if group is not defined, default to lit(-1)
Some(
if (!regressor.isDefined(regressor.groupCol) || regressor.getGroupCol.isEmpty) {
defaultGroupColumn
} else col(regressor.getGroupCol)
)
case _ => None

}

(PackedParams(col(est.getLabelCol), col(est.getFeaturesCol), weight, baseMargin, group,
est.getNumWorkers, est.needDeterministicRepartitioning), est.getEvalSets(params))

case _ => throw new RuntimeException("Unsupporting " + estimator)
}

// transform the training Dataset[_] to RDD[XGBLabeledPoint]
val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
packedParams, dataset.asInstanceOf[DataFrame]).head

// transform the eval Dataset[_] to RDD[XGBLabeledPoint]
val evalRDDMap = evalSet.map {
case (name, dataFrame) => (name,
DataUtils.convertDataFrameToXGBLabeledPointRDDs(packedParams, dataFrame).head)
}

val hasGroup = packedParams.group.map(_ != defaultGroupColumn).getOrElse(false)

xgbExecParams: XGBoostExecutionParams =>
composeInputData(trainingSet, hasGroup, packedParams.numWorkers) match {
case Left(trainingData) =>
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
} else None
(trainForRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
case Right(trainingData) =>
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
} else None
(trainForNonRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
}

}
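For context, the reworked XGBoost.trainDistributed (further down in this diff) applies the function returned here once the execution parameters are resolved:

// Inside XGBoost.trainDistributed, as added by this commit:
val (trainingRDD, optionalCachedRDD) = buildTrainingData(xgbExecParams)
// trainingRDD: RDD[Watches] drives the per-partition training;
// optionalCachedRDD, if defined, is unpersisted after training.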
/**
* Converting the RDD[XGBLabeledPoint] to the function to build RDD[Watches]
*
* @param trainingSet the input training RDD[XGBLabeledPoint]
* @param evalRDDMap the eval set
* @param hasGroup if has group
* @return function to build (RDD[Watches], the cached RDD)
*/
private[spark] def buildRDDLabeledPointToRDDWatches(
trainingSet: RDD[XGBLabeledPoint],
evalRDDMap: Map[String, RDD[XGBLabeledPoint]] = Map(),
hasGroup: Boolean = false): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]) = {

xgbExecParams: XGBoostExecutionParams =>
composeInputData(trainingSet, hasGroup, xgbExecParams.numWorkers) match {
case Left(trainingData) =>
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
} else None
(trainForRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
case Right(trainingData) =>
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
} else None
(trainForNonRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
}
}
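The XGBoostGeneralSuite changes at the end of this diff call this entry point directly; condensed (Classification.train and numWorkers are the existing test fixtures):

val trainingRDD = sc.parallelize(Classification.train)
val buildTrainingRDD = PreXGBoost.buildRDDLabeledPointToRDDWatches(trainingRDD)
val (booster, metrics) = XGBoost.trainDistributed(sc, buildTrainingRDD,
  Map("eta" -> "1", "objective" -> "binary:logistic",
    "num_round" -> 5, "num_workers" -> numWorkers, "missing" -> Float.NaN))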
/**
* Transform RDD according to group column
*
* @param trainingData the input XGBLabeledPoint RDD
* @param hasGroup if has group column
* @param nWorkers total xgboost number workers to run xgboost tasks
* @return Either: the left is RDD with group, and the right is RDD without group
*/
private def composeInputData(
trainingData: RDD[XGBLabeledPoint],
hasGroup: Boolean,
nWorkers: Int): Either[RDD[Array[XGBLabeledPoint]], RDD[XGBLabeledPoint]] = {
if (hasGroup) {
Left(repartitionForTrainingGroup(trainingData, nWorkers))
} else {
Right(trainingData)
}
}

/**
* Repartition trainingData with group directly may cause data chaos, since the same group data
* may be split into different partitions.
*
* The first step is to aggregate the same group into same partition
* The second step is to repartition to nWorkers
*
* TODO, Could we repartition trainingData on group?
*/
private[spark] def repartitionForTrainingGroup(trainingData: RDD[XGBLabeledPoint],
nWorkers: Int): RDD[Array[XGBLabeledPoint]] = {
val allGroups = aggByGroupInfo(trainingData)
logger.info(s"repartitioning training group set to $nWorkers partitions")
allGroups.repartition(nWorkers)
}
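The ranking tests below use this helper on its own, e.g. (Ranking.train is the existing test fixture):

val trainingRDD = sc.parallelize(Ranking.train, 5)
// every element of the result holds one complete group, spread across 4 partitions
val groups: RDD[Array[XGBLabeledPoint]] = PreXGBoost.repartitionForTrainingGroup(trainingRDD, 4)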
/**
* Build RDD[Watches] for Ranking
* @param trainingData the training data RDD
* @param xgbExecutionParams xgboost execution params
* @param evalSetsMap the eval RDD
* @return RDD[Watches]
*/
private def trainForRanking(
trainingData: RDD[Array[XGBLabeledPoint]],
xgbExecutionParam: XGBoostExecutionParams,
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[Watches] = {
if (evalSetsMap.isEmpty) {
trainingData.mapPartitions(labeledPointGroups => {
val watches = Watches.buildWatchesWithGroup(xgbExecutionParam,
DataUtils.processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing,
xgbExecutionParam.allowNonZeroForMissing),
getCacheDirName(xgbExecutionParam.useExternalMemory))
Iterator.single(watches)
}).cache()
} else {
coPartitionGroupSets(trainingData, evalSetsMap, xgbExecutionParam.numWorkers).mapPartitions(
labeledPointGroupSets => {
val watches = Watches.buildWatchesWithGroup(
labeledPointGroupSets.map {
case (name, iter) => (name, DataUtils.processMissingValuesWithGroup(iter,
xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing))
},
getCacheDirName(xgbExecutionParam.useExternalMemory))
Iterator.single(watches)
}).cache()
}
}
private def coPartitionGroupSets(
aggedTrainingSet: RDD[Array[XGBLabeledPoint]],
evalSets: Map[String, RDD[XGBLabeledPoint]],
nWorkers: Int): RDD[(String, Iterator[Array[XGBLabeledPoint]])] = {
val repartitionedDatasets = Map("train" -> aggedTrainingSet) ++ evalSets.map {
case (name, rdd) => {
val aggedRdd = aggByGroupInfo(rdd)
if (aggedRdd.getNumPartitions != nWorkers) {
name -> aggedRdd.repartition(nWorkers)
} else {
name -> aggedRdd
}
}
}
repartitionedDatasets.foldLeft(aggedTrainingSet.sparkContext.parallelize(
Array.fill[(String, Iterator[Array[XGBLabeledPoint]])](nWorkers)(null), nWorkers)) {
case (rddOfIterWrapper, (name, rddOfIter)) =>
rddOfIterWrapper.zipPartitions(rddOfIter) {
(itrWrapper, itr) =>
if (!itr.hasNext) {
logger.error("when specifying eval sets as dataframes, you have to ensure that " +
"the number of elements in each dataframe is larger than the number of workers")
throw new Exception("too few elements in evaluation sets")
}
val itrArray = itrWrapper.toArray
if (itrArray.head != null) {
new IteratorWrapper(itrArray :+ (name -> itr))
} else {
new IteratorWrapper(Array(name -> itr))
}
}
}
}

private def aggByGroupInfo(trainingData: RDD[XGBLabeledPoint]) = {
val normalGroups: RDD[Array[XGBLabeledPoint]] = trainingData.mapPartitions(
// LabeledPointGroupIterator returns (Boolean, Array[XGBLabeledPoint])
new LabeledPointGroupIterator(_)).filter(!_.isEdgeGroup).map(_.points)

// edge groups with partition id.
val edgeGroups: RDD[(Int, XGBLabeledPointGroup)] = trainingData.mapPartitions(
new LabeledPointGroupIterator(_)).filter(_.isEdgeGroup).map(
group => (TaskContext.getPartitionId(), group))

// group chunks from different partitions together by group id in XGBLabeledPoint.
// use groupBy instead of aggregateBy since all groups within a partition have unique group ids.
val stitchedGroups: RDD[Array[XGBLabeledPoint]] = edgeGroups.groupBy(_._2.groupId).map(
groups => {
val it: Iterable[(Int, XGBLabeledPointGroup)] = groups._2
// sorted by partition id and merge list of Array[XGBLabeledPoint] into one array
it.toArray.sortBy(_._1).flatMap(_._2.points)
})
normalGroups.union(stitchedGroups)
}
/**
* Build RDD[Watches] for Non-Ranking
* @param trainingData the training data RDD
* @param xgbExecutionParams xgboost execution params
* @param evalSetsMap the eval RDD
* @return RDD[Watches]
*/
private def trainForNonRanking(
trainingData: RDD[XGBLabeledPoint],
xgbExecutionParams: XGBoostExecutionParams,
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[Watches] = {
if (evalSetsMap.isEmpty) {
trainingData.mapPartitions { labeledPoints => {
val watches = Watches.buildWatches(xgbExecutionParams,
DataUtils.processMissingValues(labeledPoints, xgbExecutionParams.missing,
xgbExecutionParams.allowNonZeroForMissing),
getCacheDirName(xgbExecutionParams.useExternalMemory))
Iterator.single(watches)
}}.cache()
} else {
coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
mapPartitions {
nameAndLabeledPointSets =>
val watches = Watches.buildWatches(
nameAndLabeledPointSets.map {
case (name, iter) => (name, DataUtils.processMissingValues(iter,
xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing))
},
getCacheDirName(xgbExecutionParams.useExternalMemory))
Iterator.single(watches)
}.cache()
}
}

private def coPartitionNoGroupSets(
trainingData: RDD[XGBLabeledPoint],
evalSets: Map[String, RDD[XGBLabeledPoint]],
nWorkers: Int) = {
// eval_sets is supposed to be set by the caller of [[trainDistributed]]
val allDatasets = Map("train" -> trainingData) ++ evalSets
val repartitionedDatasets = allDatasets.map { case (name, rdd) =>
if (rdd.getNumPartitions != nWorkers) {
(name, rdd.repartition(nWorkers))
} else {
(name, rdd)
}
}
repartitionedDatasets.foldLeft(trainingData.sparkContext.parallelize(
Array.fill[(String, Iterator[XGBLabeledPoint])](nWorkers)(null), nWorkers)) {
case (rddOfIterWrapper, (name, rddOfIter)) =>
rddOfIterWrapper.zipPartitions(rddOfIter) {
(itrWrapper, itr) =>
if (!itr.hasNext) {
logger.error("when specifying eval sets as dataframes, you have to ensure that " +
"the number of elements in each dataframe is larger than the number of workers")
throw new Exception("too few elements in evaluation sets")
}
val itrArray = itrWrapper.toArray
if (itrArray.head != null) {
new IteratorWrapper(itrArray :+ (name -> itr))
} else {
new IteratorWrapper(Array(name -> itr))
}
}
}
}

private def getCacheDirName(useExternalMemory: Boolean): Option[String] = {
val taskId = TaskContext.getPartitionId().toString
if (useExternalMemory) {
val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
Some(dir.toAbsolutePath.toString)
} else {
None
}
}

}
class IteratorWrapper[T](arrayOfXGBLabeledPoints: Array[(String, Iterator[T])])
extends Iterator[(String, Iterator[T])] {

private var currentIndex = 0

override def hasNext: Boolean = currentIndex <= arrayOfXGBLabeledPoints.length - 1

override def next(): (String, Iterator[T]) = {
currentIndex += 1
arrayOfXGBLabeledPoints(currentIndex - 1)
}
}

/**
* Training data group in an RDD partition.
*
* @param groupId The group id
* @param points Array of XGBLabeledPoint within the same group.
* @param isEdgeGroup whether it is a first or last group in an RDD partition.
*/
private[spark] case class XGBLabeledPointGroup(
groupId: Int,
points: Array[XGBLabeledPoint],
isEdgeGroup: Boolean)

/**
* Within each RDD partition, group the <code>XGBLabeledPoint</code> by group id.</p>
* And the first and the last groups may not have all the items due to the data partition.
* <code>LabeledPointGroupIterator</code> organizes data in a tuple format:
* (isFirstGroup || isLastGroup, Array[XGBLabeledPoint]).</p>
* The edge groups across partitions can be stitched together later.
* @param base collection of <code>XGBLabeledPoint</code>
*/
private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
extends AbstractIterator[XGBLabeledPointGroup] {

private var firstPointOfNextGroup: XGBLabeledPoint = null
private var isNewGroup = false

override def hasNext: Boolean = {
base.hasNext || isNewGroup
}

override def next(): XGBLabeledPointGroup = {
val builder = mutable.ArrayBuilder.make[XGBLabeledPoint]
var isFirstGroup = true
if (firstPointOfNextGroup != null) {
builder += firstPointOfNextGroup
isFirstGroup = false
}

isNewGroup = false
while (!isNewGroup && base.hasNext) {
val point = base.next()
val groupId = if (firstPointOfNextGroup != null) firstPointOfNextGroup.group else point.group
firstPointOfNextGroup = point
if (point.group == groupId) {
// add to current group
builder += point
} else {
// start a new group
isNewGroup = true
}
}

val isLastGroup = !isNewGroup
val result = builder.result()
val group = XGBLabeledPointGroup(result(0).group, result, isFirstGroup || isLastGroup)

group
}
}
@@ -17,9 +17,8 @@
package ml.dmlc.xgboost4j.scala.spark

import java.io.File
import java.nio.file.Files

import scala.collection.{AbstractIterator, mutable}
import scala.collection.mutable
import scala.util.Random
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
@@ -34,8 +33,6 @@ import org.apache.hadoop.fs.FileSystem
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

/**
* Rabit tracker configurations.
@@ -50,7 +47,7 @@ import org.apache.spark.storage.StorageLevel
* in Scala without Python components, and with full support of timeouts.
* The Scala implementation is currently experimental, use at your own risk.
*/
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String )

object TrackerConf {
def apply(): TrackerConf = TrackerConf(0L, "python")
@@ -61,7 +58,7 @@ private[this] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRou

private[this] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)

private[this] case class XGBoostExecutionParams(
private[spark] case class XGBoostExecutionParams(
numWorkers: Int,
numRounds: Int,
useExternalMemory: Boolean,
@@ -257,96 +254,9 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
)
}
/**
* Traing data group in a RDD partition.
* @param groupId The group id
* @param points Array of XGBLabeledPoint within the same group.
* @param isEdgeGroup whether it is a frist or last group in a RDD partition.
*/
private[spark] case class XGBLabeledPointGroup(
groupId: Int,
points: Array[XGBLabeledPoint],
isEdgeGroup: Boolean)

object XGBoost extends Serializable {
private val logger = LogFactory.getLog("XGBoostSpark")

private def verifyMissingSetting(
xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float,
allowNonZeroMissing: Boolean): Iterator[XGBLabeledPoint] = {
if (missing != 0.0f && !allowNonZeroMissing) {
xgbLabelPoints.map(labeledPoint => {
if (labeledPoint.indices != null) {
throw new RuntimeException(s"you can only specify missing value as 0.0 (the currently" +
s" set value $missing) when you have SparseVector or Empty vector as your feature" +
s" format. If you didn't use Spark's VectorAssembler class to build your feature " +
s"vector but instead did so in a way that preserves zeros in your feature vector " +
s"you can avoid this check by using the 'allow_non_zero_for_missing parameter'" +
s" (only use if you know what you are doing)")
}
labeledPoint
})
} else {
xgbLabelPoints
}
}

private def removeMissingValues(
xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float,
keepCondition: Float => Boolean): Iterator[XGBLabeledPoint] = {
xgbLabelPoints.map { labeledPoint =>
val indicesBuilder = new mutable.ArrayBuilder.ofInt()
val valuesBuilder = new mutable.ArrayBuilder.ofFloat()
for ((value, i) <- labeledPoint.values.zipWithIndex if keepCondition(value)) {
indicesBuilder += (if (labeledPoint.indices == null) i else labeledPoint.indices(i))
valuesBuilder += value
}
labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result())
}
}

private[spark] def processMissingValues(
xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float,
allowNonZeroMissing: Boolean): Iterator[XGBLabeledPoint] = {
if (!missing.isNaN) {
removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissing),
missing, (v: Float) => v != missing)
} else {
removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissing),
missing, (v: Float) => !v.isNaN)
}
}

private def processMissingValuesWithGroup(
xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
missing: Float,
allowNonZeroMissing: Boolean): Iterator[Array[XGBLabeledPoint]] = {
if (!missing.isNaN) {
xgbLabelPointGroups.map {
labeledPoints => XGBoost.processMissingValues(
labeledPoints.iterator,
missing,
allowNonZeroMissing
).toArray
}
} else {
xgbLabelPointGroups
}
}

private def getCacheDirName(useExternalMemory: Boolean): Option[String] = {
val taskId = TaskContext.getPartitionId().toString
if (useExternalMemory) {
val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
Some(dir.toAbsolutePath.toString)
} else {
None
}
}

private def getGPUAddrFromResources: Int = {
val tc = TaskContext.get()
if (tc == null) {
@@ -437,150 +347,22 @@ object XGBoost extends Serializable {
tracker
}
class IteratorWrapper[T](arrayOfXGBLabeledPoints: Array[(String, Iterator[T])])
extends Iterator[(String, Iterator[T])] {

private var currentIndex = 0

override def hasNext: Boolean = currentIndex <= arrayOfXGBLabeledPoints.length - 1

override def next(): (String, Iterator[T]) = {
currentIndex += 1
arrayOfXGBLabeledPoints(currentIndex - 1)
}
}

private def coPartitionNoGroupSets(
trainingData: RDD[XGBLabeledPoint],
evalSets: Map[String, RDD[XGBLabeledPoint]],
nWorkers: Int) = {
// eval_sets is supposed to be set by the caller of [[trainDistributed]]
val allDatasets = Map("train" -> trainingData) ++ evalSets
val repartitionedDatasets = allDatasets.map{case (name, rdd) =>
if (rdd.getNumPartitions != nWorkers) {
(name, rdd.repartition(nWorkers))
} else {
(name, rdd)
}
}
repartitionedDatasets.foldLeft(trainingData.sparkContext.parallelize(
Array.fill[(String, Iterator[XGBLabeledPoint])](nWorkers)(null), nWorkers)){
case (rddOfIterWrapper, (name, rddOfIter)) =>
rddOfIterWrapper.zipPartitions(rddOfIter){
(itrWrapper, itr) =>
if (!itr.hasNext) {
logger.error("when specifying eval sets as dataframes, you have to ensure that " +
"the number of elements in each dataframe is larger than the number of workers")
throw new Exception("too few elements in evaluation sets")
}
val itrArray = itrWrapper.toArray
if (itrArray.head != null) {
new IteratorWrapper(itrArray :+ (name -> itr))
} else {
new IteratorWrapper(Array(name -> itr))
}
}
}
}

private def trainForNonRanking(
trainingData: RDD[XGBLabeledPoint],
xgbExecutionParams: XGBoostExecutionParams,
rabitEnv: java.util.Map[String, String],
prevBooster: Booster,
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
if (evalSetsMap.isEmpty) {
trainingData.mapPartitions(labeledPoints => {
val watches = Watches.buildWatches(xgbExecutionParams,
processMissingValues(labeledPoints, xgbExecutionParams.missing,
xgbExecutionParams.allowNonZeroForMissing),
getCacheDirName(xgbExecutionParams.useExternalMemory))
buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj,
xgbExecutionParams.eval, prevBooster)
}).cache()
} else {
coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
mapPartitions {
nameAndLabeledPointSets =>
val watches = Watches.buildWatches(
nameAndLabeledPointSets.map {
case (name, iter) => (name, processMissingValues(iter,
xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing))
},
getCacheDirName(xgbExecutionParams.useExternalMemory))
buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj,
xgbExecutionParams.eval, prevBooster)
}.cache()
}
}

private def trainForRanking(
trainingData: RDD[Array[XGBLabeledPoint]],
xgbExecutionParam: XGBoostExecutionParams,
rabitEnv: java.util.Map[String, String],
prevBooster: Booster,
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
if (evalSetsMap.isEmpty) {
trainingData.mapPartitions(labeledPointGroups => {
val watches = Watches.buildWatchesWithGroup(xgbExecutionParam,
processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing,
xgbExecutionParam.allowNonZeroForMissing),
getCacheDirName(xgbExecutionParam.useExternalMemory))
buildDistributedBooster(watches, xgbExecutionParam, rabitEnv,
xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster)
}).cache()
} else {
coPartitionGroupSets(trainingData, evalSetsMap, xgbExecutionParam.numWorkers).mapPartitions(
labeledPointGroupSets => {
val watches = Watches.buildWatchesWithGroup(
labeledPointGroupSets.map {
case (name, iter) => (name, processMissingValuesWithGroup(iter,
xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing))
},
getCacheDirName(xgbExecutionParam.useExternalMemory))
buildDistributedBooster(watches, xgbExecutionParam, rabitEnv,
xgbExecutionParam.obj,
xgbExecutionParam.eval,
prevBooster)
}).cache()
}
}

private def cacheData(ifCacheDataBoolean: Boolean, input: RDD[_]): RDD[_] = {
if (ifCacheDataBoolean) input.persist(StorageLevel.MEMORY_AND_DISK) else input
}

private def composeInputData(
trainingData: RDD[XGBLabeledPoint],
ifCacheDataBoolean: Boolean,
hasGroup: Boolean,
nWorkers: Int): Either[RDD[Array[XGBLabeledPoint]], RDD[XGBLabeledPoint]] = {
if (hasGroup) {
val repartitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
Left(cacheData(ifCacheDataBoolean, repartitionedData).
asInstanceOf[RDD[Array[XGBLabeledPoint]]])
} else {
Right(cacheData(ifCacheDataBoolean, trainingData).asInstanceOf[RDD[XGBLabeledPoint]])
}
}
/**
* @return A tuple of the booster and the metrics used to build training summary
*/
@throws(classOf[XGBoostError])
private[spark] def trainDistributed(
trainingData: RDD[XGBLabeledPoint],
params: Map[String, Any],
hasGroup: Boolean = false,
evalSetsMap: Map[String, RDD[XGBLabeledPoint]] = Map()):
sc: SparkContext,
buildTrainingData: XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]),
params: Map[String, Any]):
(Booster, Map[String, Array[Float]]) = {

logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)

val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, sc)
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
val sc = trainingData.sparkContext
val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet,
hasGroup, xgbExecParams.numWorkers)

val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam =>
val checkpointManager = new ExternalCheckpointManager(
checkpointParam.checkpointPath,
@@ -588,6 +370,10 @@ object XGBoost extends Serializable {
checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds)
checkpointManager.loadCheckpointAsScalaBooster()
}.orNull

// Get the training data RDD and the cachedRDD
val (trainingRDD, optionalCachedRDD) = buildTrainingData(xgbExecParams)

try {
// Train for every ${savingRound} rounds and save the partially completed booster
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
@@ -599,13 +385,21 @@ object XGBoost extends Serializable {

tracker.getWorkerEnvs().putAll(xgbRabitParams)
val rabitEnv = tracker.getWorkerEnvs
val boostersAndMetrics = if (hasGroup) {
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
evalSetsMap)
} else {
trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv,
prevBooster, evalSetsMap)
}

val boostersAndMetrics = trainingRDD.mapPartitions { iter => {
var optionWatches: Option[Watches] = None

// take the first Watches to train
if (iter.hasNext) {
optionWatches = Some(iter.next())
}

optionWatches.map { watches => buildDistributedBooster(watches, xgbExecParams, rabitEnv,
xgbExecParams.obj, xgbExecParams.eval, prevBooster)}
.getOrElse(throw new RuntimeException("No Watches to train"))

}}.cache()

val sparkJobThread = new Thread() {
override def run() {
// force the job
@@ -642,85 +436,11 @@ object XGBoost extends Serializable {
// if the job was aborted due to an exception
logger.error("the job was aborted due to ", t)
if (xgbExecParams.killSparkContextOnWorkerFailure) {
trainingData.sparkContext.stop()
sc.stop()
}
throw t
} finally {
uncacheTrainingData(xgbExecParams.cacheTrainingSet, transformedTrainingData)
}
}

private def uncacheTrainingData(
cacheTrainingSet: Boolean,
transformedTrainingData: Either[RDD[Array[XGBLabeledPoint]], RDD[XGBLabeledPoint]]): Unit = {
if (cacheTrainingSet) {
if (transformedTrainingData.isLeft) {
transformedTrainingData.left.get.unpersist()
} else {
transformedTrainingData.right.get.unpersist()
}
}
}

private def aggByGroupInfo(trainingData: RDD[XGBLabeledPoint]) = {
val normalGroups: RDD[Array[XGBLabeledPoint]] = trainingData.mapPartitions(
// LabeledPointGroupIterator returns (Boolean, Array[XGBLabeledPoint])
new LabeledPointGroupIterator(_)).filter(!_.isEdgeGroup).map(_.points)

// edge groups with partition id.
val edgeGroups: RDD[(Int, XGBLabeledPointGroup)] = trainingData.mapPartitions(
new LabeledPointGroupIterator(_)).filter(_.isEdgeGroup).map(
group => (TaskContext.getPartitionId(), group))

// group chunks from different partitions together by group id in XGBLabeledPoint.
// use groupBy instead of aggregateBy since all groups within a partition have unique group ids.
val stitchedGroups: RDD[Array[XGBLabeledPoint]] = edgeGroups.groupBy(_._2.groupId).map(
groups => {
val it: Iterable[(Int, XGBLabeledPointGroup)] = groups._2
// sorted by partition id and merge list of Array[XGBLabeledPoint] into one array
it.toArray.sortBy(_._1).flatMap(_._2.points)
})
normalGroups.union(stitchedGroups)
}

private[spark] def repartitionForTrainingGroup(
trainingData: RDD[XGBLabeledPoint], nWorkers: Int): RDD[Array[XGBLabeledPoint]] = {
val allGroups = aggByGroupInfo(trainingData)
logger.info(s"repartitioning training group set to $nWorkers partitions")
allGroups.repartition(nWorkers)
}

private def coPartitionGroupSets(
aggedTrainingSet: RDD[Array[XGBLabeledPoint]],
evalSets: Map[String, RDD[XGBLabeledPoint]],
nWorkers: Int): RDD[(String, Iterator[Array[XGBLabeledPoint]])] = {
val repartitionedDatasets = Map("train" -> aggedTrainingSet) ++ evalSets.map {
case (name, rdd) => {
val aggedRdd = aggByGroupInfo(rdd)
if (aggedRdd.getNumPartitions != nWorkers) {
name -> aggedRdd.repartition(nWorkers)
} else {
name -> aggedRdd
}
}
}
repartitionedDatasets.foldLeft(aggedTrainingSet.sparkContext.parallelize(
Array.fill[(String, Iterator[Array[XGBLabeledPoint]])](nWorkers)(null), nWorkers)){
case (rddOfIterWrapper, (name, rddOfIter)) =>
rddOfIterWrapper.zipPartitions(rddOfIter){
(itrWrapper, itr) =>
if (!itr.hasNext) {
logger.error("when specifying eval sets as dataframes, you have to ensure that " +
"the number of elements in each dataframe is larger than the number of workers")
throw new Exception("too few elements in evaluation sets")
}
val itrArray = itrWrapper.toArray
if (itrArray.head != null) {
new IteratorWrapper(itrArray :+ (name -> itr))
} else {
new IteratorWrapper(Array(name -> itr))
}
}
optionalCachedRDD.foreach(_.unpersist())
}
}
@@ -753,7 +473,7 @@ object XGBoost extends Serializable {

}

private class Watches private(
class Watches private(
val datasets: Array[DMatrix],
val names: Array[String],
val cacheDirName: Option[String]) {
@@ -964,50 +684,4 @@ private object Watches {
}
}

/**
* Within each RDD partition, group the <code>XGBLabeledPoint</code> by group id.</p>
* And the first and the last groups may not have all the items due to the data partition.
* <code>LabeledPointGroupIterator</code> orginaizes data in a tuple format:
* (isFistGroup || isLastGroup, Array[XGBLabeledPoint]).</p>
* The edge groups across partitions can be stitched together later.
* @param base collection of <code>XGBLabeledPoint</code>
*/
private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
extends AbstractIterator[XGBLabeledPointGroup] {

private var firstPointOfNextGroup: XGBLabeledPoint = null
private var isNewGroup = false

override def hasNext: Boolean = {
base.hasNext || isNewGroup
}

override def next(): XGBLabeledPointGroup = {
val builder = mutable.ArrayBuilder.make[XGBLabeledPoint]
var isFirstGroup = true
if (firstPointOfNextGroup != null) {
builder += firstPointOfNextGroup
isFirstGroup = false
}

isNewGroup = false
while (!isNewGroup && base.hasNext) {
val point = base.next()
val groupId = if (firstPointOfNextGroup != null) firstPointOfNextGroup.group else point.group
firstPointOfNextGroup = point
if (point.group == groupId) {
// add to current group
builder += point
} else {
// start a new group
isNewGroup = true
}
}

val isLastGroup = !isNewGroup
val result = builder.result()
val group = XGBLabeledPointGroup(result(0).group, result, isFirstGroup || isLastGroup)

group
}
}
@@ -38,7 +38,7 @@ import scala.collection.{AbstractIterator, Iterator, mutable}

class XGBoostClassifier (
override val uid: String,
private val xgboostParams: Map[String, Any])
private[spark] val xgboostParams: Map[String, Any])
extends ProbabilisticClassifier[Vector, XGBoostClassifier, XGBoostClassificationModel]
with XGBoostClassifierParams with DefaultParamsWritable {

@@ -176,26 +176,15 @@ class XGBoostClassifier (
"\'num_class\' in xgboost params.")
}

val weight = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val baseMargin = if (!isDefined(baseMarginCol) || $(baseMarginCol).isEmpty) {
lit(Float.NaN)
} else {
col($(baseMarginCol))
}

val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
col($(labelCol)), col($(featuresCol)), weight, baseMargin,
None, $(numWorkers), needDeterministicRepartitioning, dataset.asInstanceOf[DataFrame]).head
val evalRDDMap = getEvalSets(xgboostParams).map {
case (name, dataFrame) => (name,
DataUtils.convertDataFrameToXGBLabeledPointRDDs(col($(labelCol)), col($(featuresCol)),
weight, baseMargin, None, $(numWorkers), needDeterministicRepartitioning, dataFrame).head)
}
// Packing with all params plus params user defined
val derivedXGBParamMap = xgboostParams ++ MLlib2XGBoostParams
val buildTrainingData = PreXGBoost.buildDatasetToRDD(this, dataset, derivedXGBParamMap)
transformSchema(dataset.schema, logging = true)
val derivedXGBParamMap = MLlib2XGBoostParams

// All non-null param maps in XGBoostClassifier are in derivedXGBParamMap.
val (_booster, _metrics) = XGBoost.trainDistributed(trainingSet, derivedXGBParamMap,
hasGroup = false, evalRDDMap)
val (_booster, _metrics) = XGBoost.trainDistributed(dataset.sparkSession.sparkContext,
buildTrainingData, derivedXGBParamMap)

val model = new XGBoostClassificationModel(uid, _numClasses, _booster)
val summary = XGBoostTrainingSummary(_metrics)
model.setSummary(summary)
@@ -265,7 +254,7 @@ class XGBoostClassificationModel private[ml](
*/
override def predict(features: Vector): Double = {
import DataUtils._
val dm = new DMatrix(XGBoost.processMissingValues(
val dm = new DMatrix(processMissingValues(
Iterator(features.asXGB),
$(missing),
$(allowNonZeroForMissing)
@@ -324,7 +313,7 @@ class XGBoostClassificationModel private[ml](
}
val dm = new DMatrix(
XGBoost.processMissingValues(
processMissingValues(
features.map(_.asXGB),
$(missing),
$(allowNonZeroForMissing)

@@ -171,27 +171,16 @@ class XGBoostRegressor (
set(objectiveType, "regression")
}

val weight = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val baseMargin = if (!isDefined(baseMarginCol) || $(baseMarginCol).isEmpty) {
lit(Float.NaN)
} else {
col($(baseMarginCol))
}
val group = if (!isDefined(groupCol) || $(groupCol).isEmpty) lit(-1) else col($(groupCol))
val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
col($(labelCol)), col($(featuresCol)), weight, baseMargin, Some(group),
$(numWorkers), needDeterministicRepartitioning, dataset.asInstanceOf[DataFrame]).head
val evalRDDMap = getEvalSets(xgboostParams).map {
case (name, dataFrame) => (name,
DataUtils.convertDataFrameToXGBLabeledPointRDDs(col($(labelCol)), col($(featuresCol)),
weight, baseMargin, Some(group), $(numWorkers), needDeterministicRepartitioning,
dataFrame).head)
}
transformSchema(dataset.schema, logging = true)
val derivedXGBParamMap = MLlib2XGBoostParams

// Packing with all params plus params user defined
val derivedXGBParamMap = xgboostParams ++ MLlib2XGBoostParams
val buildTrainingData = PreXGBoost.buildDatasetToRDD(this, dataset, derivedXGBParamMap)

// All non-null param maps in XGBoostRegressor are in derivedXGBParamMap.
val (_booster, _metrics) = XGBoost.trainDistributed(trainingSet, derivedXGBParamMap,
hasGroup = group != lit(-1), evalRDDMap)
val (_booster, _metrics) = XGBoost.trainDistributed(dataset.sparkSession.sparkContext,
buildTrainingData, derivedXGBParamMap)

val model = new XGBoostRegressionModel(uid, _booster)
val summary = XGBoostTrainingSummary(_metrics)
model.setSummary(summary)
@@ -260,7 +249,7 @@ class XGBoostRegressionModel private[ml] (
*/
override def predict(features: Vector): Double = {
import DataUtils._
val dm = new DMatrix(XGBoost.processMissingValues(
val dm = new DMatrix(processMissingValues(
Iterator(features.asXGB),
$(missing),
$(allowNonZeroForMissing)
@@ -301,7 +290,7 @@ class XGBoostRegressionModel private[ml] (
}

val dm = new DMatrix(
XGBoost.processMissingValues(
processMissingValues(
features.map(_.asXGB),
$(missing),
$(allowNonZeroForMissing)
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014,2021 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -14,24 +14,20 @@
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark
package ml.dmlc.xgboost4j.scala.spark.params

import ml.dmlc.xgboost4j.scala.spark.params._

import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol}

private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables {
with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables with HasWeightCol
with HasBaseMarginCol with HasLeafPredictionCol with HasContribPredictionCol with HasFeaturesCol
with HasLabelCol {

def needDeterministicRepartitioning: Boolean = {
getCheckpointPath != null && getCheckpointPath.nonEmpty && getCheckpointInterval > 0
}
}

private[spark] trait XGBoostClassifierParams extends HasWeightCol with HasBaseMarginCol
with HasNumClass with HasLeafPredictionCol with HasContribPredictionCol
with XGBoostEstimatorCommon
private[spark] trait XGBoostClassifierParams extends XGBoostEstimatorCommon with HasNumClass

private[spark] trait XGBoostRegressorParams extends HasBaseMarginCol with HasWeightCol
with HasGroupCol with HasLeafPredictionCol with HasContribPredictionCol
with XGBoostEstimatorCommon
private[spark] trait XGBoostRegressorParams extends XGBoostEstimatorCommon with HasGroupCol
@@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark

import org.apache.spark.ml.linalg.Vectors
import org.scalatest.FunSuite
import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams

import org.apache.spark.sql.functions._

@@ -55,13 +56,13 @@ class DeterministicPartitioningSuite extends FunSuite with TmpFolderPerSuite wit
resultDF
})
val transformedRDDs = transformedDFs.map(df => DataUtils.convertDataFrameToXGBLabeledPointRDDs(
col("label"),
col("features"),
lit(1.0),
lit(Float.NaN),
None,
numWorkers,
deterministicPartition = true,
PackedParams(col("label"),
col("features"),
lit(1.0),
lit(Float.NaN),
None,
numWorkers,
deterministicPartition = true),
df
).head)
val resultsMaps = transformedRDDs.map(rdd => rdd.mapPartitionsWithIndex {
@@ -90,14 +91,13 @@ class DeterministicPartitioningSuite extends FunSuite with TmpFolderPerSuite wit
val df = ss.createDataFrame(sc.parallelize(dataset)).toDF("id", "label", "features")

val dfRepartitioned = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
col("label"),
col("features"),
lit(1.0),
lit(Float.NaN),
None,
10,
deterministicPartition = true,
df
PackedParams(col("label"),
col("features"),
lit(1.0),
lit(Float.NaN),
None,
10,
deterministicPartition = true), df
).head

val partitionsSizes = dfRepartitioned
@@ -17,12 +17,14 @@
package ml.dmlc.xgboost4j.scala.spark

import ml.dmlc.xgboost4j.java.XGBoostError

import scala.util.Random

import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.DMatrix

import org.apache.spark.TaskContext
import org.scalatest.FunSuite

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.functions.lit

@@ -30,13 +32,14 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {

test("distributed training with the specified worker number") {
val trainingRDD = sc.parallelize(Classification.train)
val buildTrainingRDD = PreXGBoost.buildRDDLabeledPointToRDDWatches(trainingRDD)
val (booster, metrics) = XGBoost.trainDistributed(
trainingRDD,
sc,
buildTrainingRDD,
List("eta" -> "1", "max_depth" -> "6",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
"missing" -> Float.NaN).toMap,
hasGroup = false)
"missing" -> Float.NaN).toMap)
assert(booster != null)
}

@@ -179,7 +182,7 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
// test different splits to cover the corner cases.
for (split <- 1 to 20) {
val trainingRDD = sc.parallelize(Ranking.train, split)
val traingGroupsRDD = XGBoost.repartitionForTrainingGroup(trainingRDD, 4)
val traingGroupsRDD = PreXGBoost.repartitionForTrainingGroup(trainingRDD, 4)
val trainingGroups: Array[Array[XGBLabeledPoint]] = traingGroupsRDD.collect()
// check the the order of the groups with group id.
// Ranking.train has 20 groups
@@ -201,18 +204,19 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
// make one partition empty for testing
it.filter(_ => TaskContext.getPartitionId() != 3)
})
XGBoost.repartitionForTrainingGroup(trainingRDD, 4)
PreXGBoost.repartitionForTrainingGroup(trainingRDD, 4)
}

test("distributed training with group data") {
val trainingRDD = sc.parallelize(Ranking.train, 5)
val buildTrainingRDD = PreXGBoost.buildRDDLabeledPointToRDDWatches(trainingRDD, hasGroup = true)
val (booster, _) = XGBoost.trainDistributed(
trainingRDD,
sc,
buildTrainingRDD,
List("eta" -> "1", "max_depth" -> "6",
"objective" -> "rank:pairwise", "num_round" -> 5, "num_workers" -> numWorkers,
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
"missing" -> Float.NaN).toMap,
hasGroup = true)
"missing" -> Float.NaN).toMap)

assert(booster != null)
}