[jvm-packages] Rework the train pipeline (#7401)

1. Add PreXGBoost to build RDD[Watches] from Dataset
2. Feed RDD[Watches] built from PreXGBoost to XGBoost to train
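The net effect is a two-step hand-off. A minimal sketch of the new call flow (`estimator`, `dataset` and `paramMap` are stand-ins for whatever the caller has in scope; the two calls mirror the updated XGBoostClassifier below):

    // 1. Dataset[_] -> function producing (RDD[Watches], Option[cached RDD])
    val buildTrainingData = PreXGBoost.buildDatasetToRDD(estimator, dataset, paramMap)
    // 2. Train distributed over the RDD[Watches] built in step 1
    val (booster, metrics) = XGBoost.trainDistributed(
      dataset.sparkSession.sparkContext, buildTrainingData, paramMap)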
Bobby Wang 2021-11-10 17:51:38 +08:00 committed by GitHub
parent 8df0a252b7
commit cb685607b2
8 changed files with 631 additions and 470 deletions


@ -16,6 +16,8 @@
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.mutable
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.HashPartitioner
@ -99,43 +101,129 @@ object DataUtils extends Serializable {
}
}
/** Packed parameters used by [[convertDataFrameToXGBLabeledPointRDDs]] */
private[spark] case class PackedParams(labelCol: Column,
featuresCol: Column,
weight: Column,
baseMargin: Column,
group: Option[Column],
numWorkers: Int,
deterministicPartition: Boolean)
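// PackedParams bundles the seven column/partitioning arguments that
// convertDataFrameToXGBLabeledPointRDDs used to take individually, so the
// estimators and PreXGBoost can hand them around as a single value.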
/**
* convertDataFrameToXGBLabeledPointRDDs converts DataFrames to an array of RDD[XGBLabeledPoint].
*
* First, it converts each row of the input into an XGBLabeledPoint.
* Second, it repartitions the RDD to the number of workers.
*/
private[spark] def convertDataFrameToXGBLabeledPointRDDs(
packedParams: PackedParams,
dataFrames: DataFrame*): Array[RDD[XGBLabeledPoint]] = {
packedParams match {
case PackedParams(labelCol, featuresCol, weight, baseMargin, group, numWorkers,
deterministicPartition) =>
val selectedColumns = group.map(groupCol => Seq(labelCol.cast(FloatType),
featuresCol,
weight.cast(FloatType),
groupCol.cast(IntegerType),
baseMargin.cast(FloatType))).getOrElse(Seq(labelCol.cast(FloatType),
featuresCol,
weight.cast(FloatType),
baseMargin.cast(FloatType)))
val arrayOfRDDs = dataFrames.toArray.map {
df => df.select(selectedColumns: _*).rdd.map {
case row @ Row(label: Float, features: Vector, weight: Float, group: Int,
baseMargin: Float) =>
val (size, indices, values) = features match {
case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
}
val xgbLp = XGBLabeledPoint(label, size, indices, values, weight, group, baseMargin)
attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
case row @ Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
val (size, indices, values) = features match {
case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
}
val xgbLp = XGBLabeledPoint(label, size, indices, values, weight,
baseMargin = baseMargin)
attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
}
}
repartitionRDDs(deterministicPartition, numWorkers, arrayOfRDDs)
case _ => throw new IllegalArgumentException("Wrong PackedParams") // never reach here
}
}
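// A usage sketch (mirroring DeterministicPartitioningSuite below; df is assumed
// to have "label" and "features" columns):
//   val rdd = convertDataFrameToXGBLabeledPointRDDs(
//     PackedParams(col("label"), col("features"), lit(1.0), lit(Float.NaN),
//       None, numWorkers, deterministicPartition = true), df).head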
private[spark] def processMissingValues(
xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float,
allowNonZeroMissing: Boolean): Iterator[XGBLabeledPoint] = {
if (!missing.isNaN) {
removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissing),
missing, (v: Float) => v != missing)
} else {
removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissing),
missing, (v: Float) => !v.isNaN)
}
}
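// Worked example (hypothetical values): with missing = 0.0f, a dense point with
// values [0.0, 2.0, 0.0] is compacted to indices [1] and values [2.0]; with
// missing = Float.NaN only NaN entries are dropped and real zeros are kept.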
private[spark] def processMissingValuesWithGroup(
xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
missing: Float,
allowNonZeroMissing: Boolean): Iterator[Array[XGBLabeledPoint]] = {
if (!missing.isNaN) {
xgbLabelPointGroups.map {
labeledPoints => processMissingValues(
labeledPoints.iterator,
missing,
allowNonZeroMissing
).toArray
}
} else {
xgbLabelPointGroups
}
}
private def removeMissingValues(
xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float,
keepCondition: Float => Boolean): Iterator[XGBLabeledPoint] = {
xgbLabelPoints.map { labeledPoint =>
val indicesBuilder = new mutable.ArrayBuilder.ofInt()
val valuesBuilder = new mutable.ArrayBuilder.ofFloat()
for ((value, i) <- labeledPoint.values.zipWithIndex if keepCondition(value)) {
indicesBuilder += (if (labeledPoint.indices == null) i else labeledPoint.indices(i))
valuesBuilder += value
}
labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result())
}
}
private def verifyMissingSetting(
xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float,
allowNonZeroMissing: Boolean): Iterator[XGBLabeledPoint] = {
if (missing != 0.0f && !allowNonZeroMissing) {
xgbLabelPoints.map(labeledPoint => {
if (labeledPoint.indices != null) {
throw new RuntimeException(s"you can only specify missing value as 0.0 (the currently" +
s" set value $missing) when you have SparseVector or Empty vector as your feature" +
s" format. If you didn't use Spark's VectorAssembler class to build your feature " +
s"vector but instead did so in a way that preserves zeros in your feature vector " +
s"you can avoid this check by using the 'allow_non_zero_for_missing parameter'" +
s" (only use if you know what you are doing)")
}
labeledPoint
})
} else {
xgbLabelPoints
}
}
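// Note: the check above only rejects a non-zero `missing` for sparse vectors
// (indices != null): a SparseVector has already dropped its zeros, so removing
// a different sentinel value on top of that would silently corrupt the data.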
}


@ -0,0 +1,421 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.Files
import scala.collection.{AbstractIterator, mutable}
import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, lit}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.logging.LogFactory
import org.apache.spark.TaskContext
import org.apache.spark.ml.Estimator
import org.apache.spark.storage.StorageLevel
/**
* PreXGBoost converts Dataset[_] to RDD[[Watches]]
*/
object PreXGBoost {
private val logger = LogFactory.getLog("XGBoostSpark")
private lazy val defaultBaseMarginColumn = lit(Float.NaN)
private lazy val defaultWeightColumn = lit(1.0)
private lazy val defaultGroupColumn = lit(-1)
/**
* Convert the Dataset[_] to RDD[Watches] which will be fed to XGBoost
*
* @param estimator supports XGBoostClassifier and XGBoostRegressor
* @param dataset the training data
* @param params all user defined and defaulted params
* @return a function [[XGBoostExecutionParams]] => (RDD[[Watches]], Option[RDD[_]]):
*         RDD[Watches] will be used as the training input,
*         Option[RDD[_]] is the optionally cached RDD
*/
def buildDatasetToRDD(
estimator: Estimator[_],
dataset: Dataset[_],
params: Map[String, Any]): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]) = {
val (packedParams, evalSet) = estimator match {
case est: XGBoostEstimatorCommon =>
// get weight column, if weight is not defined, default to lit(1.0)
val weight = if (!est.isDefined(est.weightCol) || est.getWeightCol.isEmpty) {
defaultWeightColumn
} else col(est.getWeightCol)
// get base-margin column, if base-margin is not defined, default to lit(Float.NaN)
val baseMargin = if (!est.isDefined(est.baseMarginCol) || est.getBaseMarginCol.isEmpty) {
defaultBaseMarginColumn
} else col(est.getBaseMarginCol)
val group = est match {
case regressor: XGBoostRegressor =>
// get group column, if group is not defined, default to lit(-1)
Some(
if (!regressor.isDefined(regressor.groupCol) || regressor.getGroupCol.isEmpty) {
defaultGroupColumn
} else col(regressor.getGroupCol)
)
case _ => None
}
(PackedParams(col(est.getLabelCol), col(est.getFeaturesCol), weight, baseMargin, group,
est.getNumWorkers, est.needDeterministicRepartitioning), est.getEvalSets(params))
case _ => throw new RuntimeException("Unsupporting " + estimator)
}
// transform the training Dataset[_] to RDD[XGBLabeledPoint]
val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
packedParams, dataset.asInstanceOf[DataFrame]).head
// transform the eval Dataset[_] to RDD[XGBLabeledPoint]
val evalRDDMap = evalSet.map {
case (name, dataFrame) => (name,
DataUtils.convertDataFrameToXGBLabeledPointRDDs(packedParams, dataFrame).head)
}
val hasGroup = packedParams.group.map(_ != defaultGroupColumn).getOrElse(false)
xgbExecParams: XGBoostExecutionParams =>
composeInputData(trainingSet, hasGroup, packedParams.numWorkers) match {
case Left(trainingData) =>
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
} else None
(trainForRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
case Right(trainingData) =>
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
} else None
(trainForNonRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
}
}
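// A usage sketch (hypothetical caller): the returned closure defers building
// the Watches until XGBoost has resolved the execution params:
//   val toRDDWatches = PreXGBoost.buildDatasetToRDD(estimator, dataset, params)
//   val (watchesRDD, cachedRDD) = toRDDWatches(xgbExecParams)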
/**
* Convert an RDD[XGBLabeledPoint] to a function that builds RDD[Watches]
*
* @param trainingSet the input training RDD[XGBLabeledPoint]
* @param evalRDDMap the eval sets
* @param hasGroup whether the training data has group information (ranking)
* @return function to build (RDD[Watches], the cached RDD)
*/
private[spark] def buildRDDLabeledPointToRDDWatches(
trainingSet: RDD[XGBLabeledPoint],
evalRDDMap: Map[String, RDD[XGBLabeledPoint]] = Map(),
hasGroup: Boolean = false): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]) = {
xgbExecParams: XGBoostExecutionParams =>
composeInputData(trainingSet, hasGroup, xgbExecParams.numWorkers) match {
case Left(trainingData) =>
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
} else None
(trainForRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
case Right(trainingData) =>
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
} else None
(trainForNonRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
}
}
/**
* Transform the training RDD according to the group column
*
* @param trainingData the input XGBLabeledPoint RDD
* @param hasGroup whether the data has a group column
* @param nWorkers the number of workers that will run XGBoost tasks
* @return Either: Left is the RDD with group, Right is the RDD without group
*/
private def composeInputData(
trainingData: RDD[XGBLabeledPoint],
hasGroup: Boolean,
nWorkers: Int): Either[RDD[Array[XGBLabeledPoint]], RDD[XGBLabeledPoint]] = {
if (hasGroup) {
Left(repartitionForTrainingGroup(trainingData, nWorkers))
} else {
Right(trainingData)
}
}
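// The Either return lets the caller dispatch with a single pattern match:
// Left carries group-aggregated data for the ranking path (trainForRanking),
// Right carries plain points for the non-ranking path (trainForNonRanking).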
/**
* Repartitioning the training data directly may scatter records of the same
* group across partitions and corrupt the group information.
*
* The first step therefore aggregates each group into a single partition element,
* and the second step repartitions those groups to nWorkers.
*
* TODO, Could we repartition trainingData on group?
*/
private[spark] def repartitionForTrainingGroup(trainingData: RDD[XGBLabeledPoint],
nWorkers: Int): RDD[Array[XGBLabeledPoint]] = {
val allGroups = aggByGroupInfo(trainingData)
logger.info(s"repartitioning training group set to $nWorkers partitions")
allGroups.repartition(nWorkers)
}
/**
* Build RDD[Watches] for Ranking
* @param trainingData the training data RDD
* @param xgbExecutionParam xgboost execution params
* @param evalSetsMap the eval RDDs
* @return RDD[Watches]
*/
private def trainForRanking(
trainingData: RDD[Array[XGBLabeledPoint]],
xgbExecutionParam: XGBoostExecutionParams,
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[Watches] = {
if (evalSetsMap.isEmpty) {
trainingData.mapPartitions(labeledPointGroups => {
val watches = Watches.buildWatchesWithGroup(xgbExecutionParam,
DataUtils.processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing,
xgbExecutionParam.allowNonZeroForMissing),
getCacheDirName(xgbExecutionParam.useExternalMemory))
Iterator.single(watches)
}).cache()
} else {
coPartitionGroupSets(trainingData, evalSetsMap, xgbExecutionParam.numWorkers).mapPartitions(
labeledPointGroupSets => {
val watches = Watches.buildWatchesWithGroup(
labeledPointGroupSets.map {
case (name, iter) => (name, DataUtils.processMissingValuesWithGroup(iter,
xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing))
},
getCacheDirName(xgbExecutionParam.useExternalMemory))
Iterator.single(watches)
}).cache()
}
}
private def coPartitionGroupSets(
aggedTrainingSet: RDD[Array[XGBLabeledPoint]],
evalSets: Map[String, RDD[XGBLabeledPoint]],
nWorkers: Int): RDD[(String, Iterator[Array[XGBLabeledPoint]])] = {
val repartitionedDatasets = Map("train" -> aggedTrainingSet) ++ evalSets.map {
case (name, rdd) => {
val aggedRdd = aggByGroupInfo(rdd)
if (aggedRdd.getNumPartitions != nWorkers) {
name -> aggedRdd.repartition(nWorkers)
} else {
name -> aggedRdd
}
}
}
repartitionedDatasets.foldLeft(aggedTrainingSet.sparkContext.parallelize(
Array.fill[(String, Iterator[Array[XGBLabeledPoint]])](nWorkers)(null), nWorkers)) {
case (rddOfIterWrapper, (name, rddOfIter)) =>
rddOfIterWrapper.zipPartitions(rddOfIter) {
(itrWrapper, itr) =>
if (!itr.hasNext) {
logger.error("when specifying eval sets as dataframes, you have to ensure that " +
"the number of elements in each dataframe is larger than the number of workers")
throw new Exception("too few elements in evaluation sets")
}
val itrArray = itrWrapper.toArray
if (itrArray.head != null) {
new IteratorWrapper(itrArray :+ (name -> itr))
} else {
new IteratorWrapper(Array(name -> itr))
}
}
}
}
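// The foldLeft above starts from nWorkers placeholder slots and zips in one
// dataset per iteration, so each partition accumulates a (name, iterator) pair
// for "train" plus every eval set, co-located for Watches construction.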
private def aggByGroupInfo(trainingData: RDD[XGBLabeledPoint]) = {
val normalGroups: RDD[Array[XGBLabeledPoint]] = trainingData.mapPartitions(
// LabeledPointGroupIterator returns (Boolean, Array[XGBLabeledPoint])
new LabeledPointGroupIterator(_)).filter(!_.isEdgeGroup).map(_.points)
// edge groups with partition id.
val edgeGroups: RDD[(Int, XGBLabeledPointGroup)] = trainingData.mapPartitions(
new LabeledPointGroupIterator(_)).filter(_.isEdgeGroup).map(
group => (TaskContext.getPartitionId(), group))
// group chunks from different partitions together by group id in XGBLabeledPoint.
// use groupBy instead of aggregateBy since all groups within a partition have unique group ids.
val stitchedGroups: RDD[Array[XGBLabeledPoint]] = edgeGroups.groupBy(_._2.groupId).map(
groups => {
val it: Iterable[(Int, XGBLabeledPointGroup)] = groups._2
// sorted by partition id and merge list of Array[XGBLabeledPoint] into one array
it.toArray.sortBy(_._1).flatMap(_._2.points)
})
normalGroups.union(stitchedGroups)
}
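// Example: if partition 0 ends with a fragment of group 7 and partition 1
// starts with the rest of it, both fragments are emitted as edge groups,
// grouped here by groupId, sorted by partition id, and stitched into one array.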
/**
* Build RDD[Watches] for Non-Ranking
* @param trainingData the training data RDD
* @param xgbExecutionParams xgboost execution params
* @param evalSetsMap the eval RDD
* @return RDD[Watches]
*/
private def trainForNonRanking(
trainingData: RDD[XGBLabeledPoint],
xgbExecutionParams: XGBoostExecutionParams,
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[Watches] = {
if (evalSetsMap.isEmpty) {
trainingData.mapPartitions { labeledPoints => {
val watches = Watches.buildWatches(xgbExecutionParams,
DataUtils.processMissingValues(labeledPoints, xgbExecutionParams.missing,
xgbExecutionParams.allowNonZeroForMissing),
getCacheDirName(xgbExecutionParams.useExternalMemory))
Iterator.single(watches)
}}.cache()
} else {
coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
mapPartitions {
nameAndLabeledPointSets =>
val watches = Watches.buildWatches(
nameAndLabeledPointSets.map {
case (name, iter) => (name, DataUtils.processMissingValues(iter,
xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing))
},
getCacheDirName(xgbExecutionParams.useExternalMemory))
Iterator.single(watches)
}.cache()
}
}
private def coPartitionNoGroupSets(
trainingData: RDD[XGBLabeledPoint],
evalSets: Map[String, RDD[XGBLabeledPoint]],
nWorkers: Int) = {
// eval_sets is supposed to be set by the caller of [[trainDistributed]]
val allDatasets = Map("train" -> trainingData) ++ evalSets
val repartitionedDatasets = allDatasets.map { case (name, rdd) =>
if (rdd.getNumPartitions != nWorkers) {
(name, rdd.repartition(nWorkers))
} else {
(name, rdd)
}
}
repartitionedDatasets.foldLeft(trainingData.sparkContext.parallelize(
Array.fill[(String, Iterator[XGBLabeledPoint])](nWorkers)(null), nWorkers)) {
case (rddOfIterWrapper, (name, rddOfIter)) =>
rddOfIterWrapper.zipPartitions(rddOfIter) {
(itrWrapper, itr) =>
if (!itr.hasNext) {
logger.error("when specifying eval sets as dataframes, you have to ensure that " +
"the number of elements in each dataframe is larger than the number of workers")
throw new Exception("too few elements in evaluation sets")
}
val itrArray = itrWrapper.toArray
if (itrArray.head != null) {
new IteratorWrapper(itrArray :+ (name -> itr))
} else {
new IteratorWrapper(Array(name -> itr))
}
}
}
}
private def getCacheDirName(useExternalMemory: Boolean): Option[String] = {
val taskId = TaskContext.getPartitionId().toString
if (useExternalMemory) {
val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
Some(dir.toAbsolutePath.toString)
} else {
None
}
}
}
class IteratorWrapper[T](arrayOfXGBLabeledPoints: Array[(String, Iterator[T])])
extends Iterator[(String, Iterator[T])] {
private var currentIndex = 0
override def hasNext: Boolean = currentIndex <= arrayOfXGBLabeledPoints.length - 1
override def next(): (String, Iterator[T]) = {
currentIndex += 1
arrayOfXGBLabeledPoints(currentIndex - 1)
}
}
/**
* A group of training data in an RDD partition.
*
* @param groupId The group id
* @param points Array of XGBLabeledPoint within the same group.
* @param isEdgeGroup whether it is the first or last group in an RDD partition.
*/
private[spark] case class XGBLabeledPointGroup(
groupId: Int,
points: Array[XGBLabeledPoint],
isEdgeGroup: Boolean)
/**
* Within each RDD partition, group the <code>XGBLabeledPoint</code>s by group id.
* The first and the last groups may be incomplete because of the data partitioning,
* so <code>LabeledPointGroupIterator</code> flags them via
* (isFirstGroup || isLastGroup, Array[XGBLabeledPoint]).
* The edge groups across partitions can be stitched together later.
*
* @param base collection of <code>XGBLabeledPoint</code>
*/
private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
extends AbstractIterator[XGBLabeledPointGroup] {
private var firstPointOfNextGroup: XGBLabeledPoint = null
private var isNewGroup = false
override def hasNext: Boolean = {
base.hasNext || isNewGroup
}
override def next(): XGBLabeledPointGroup = {
val builder = mutable.ArrayBuilder.make[XGBLabeledPoint]
var isFirstGroup = true
if (firstPointOfNextGroup != null) {
builder += firstPointOfNextGroup
isFirstGroup = false
}
isNewGroup = false
while (!isNewGroup && base.hasNext) {
val point = base.next()
val groupId = if (firstPointOfNextGroup != null) firstPointOfNextGroup.group else point.group
firstPointOfNextGroup = point
if (point.group == groupId) {
// add to current group
builder += point
} else {
// start a new group
isNewGroup = true
}
}
val isLastGroup = !isNewGroup
val result = builder.result()
val group = XGBLabeledPointGroup(result(0).group, result, isFirstGroup || isLastGroup)
group
}
}
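// Example: a partition holding group ids [7, 7, 8, 8, 9] yields three
// XGBLabeledPointGroup instances; the groups for 7 and 9 are flagged as edge
// groups (isFirstGroup || isLastGroup) since they may continue in a neighbour.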


@ -17,9 +17,8 @@
package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import java.nio.file.Files
import scala.collection.mutable
import scala.util.Random
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
@ -34,8 +33,6 @@ import org.apache.hadoop.fs.FileSystem
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
import org.apache.spark.sql.SparkSession
/**
* Rabit tracker configurations.
@ -50,7 +47,7 @@ import org.apache.spark.storage.StorageLevel
* in Scala without Python components, and with full support of timeouts.
* The Scala implementation is currently experimental, use at your own risk.
*/
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)
object TrackerConf {
def apply(): TrackerConf = TrackerConf(0L, "python")
@ -61,7 +58,7 @@ private[this] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRou
private[this] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
private[spark] case class XGBoostExecutionParams(
numWorkers: Int,
numRounds: Int,
useExternalMemory: Boolean,
@ -257,96 +254,9 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
)
}
object XGBoost extends Serializable {
private val logger = LogFactory.getLog("XGBoostSpark")
private def getGPUAddrFromResources: Int = {
val tc = TaskContext.get()
if (tc == null) {
@ -437,150 +347,22 @@ object XGBoost extends Serializable {
tracker
}
/**
* @return A tuple of the booster and the metrics used to build training summary
*/
@throws(classOf[XGBoostError])
private[spark] def trainDistributed(
sc: SparkContext,
buildTrainingData: XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]),
params: Map[String, Any]):
(Booster, Map[String, Array[Float]]) = {
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, sc)
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam =>
val checkpointManager = new ExternalCheckpointManager(
checkpointParam.checkpointPath,
@ -588,6 +370,10 @@ object XGBoost extends Serializable {
checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds)
checkpointManager.loadCheckpointAsScalaBooster()
}.orNull
// Get the training data RDD and the cachedRDD
val (trainingRDD, optionalCachedRDD) = buildTrainingData(xgbExecParams)
try {
// Train for every ${savingRound} rounds and save the partially completed booster
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
@ -599,13 +385,21 @@ object XGBoost extends Serializable {
tracker.getWorkerEnvs().putAll(xgbRabitParams)
val rabitEnv = tracker.getWorkerEnvs
val boostersAndMetrics = trainingRDD.mapPartitions { iter =>
// take the first Watches in the partition to train
val optionWatches: Option[Watches] = if (iter.hasNext) Some(iter.next()) else None
optionWatches.map { watches => buildDistributedBooster(watches, xgbExecParams, rabitEnv,
xgbExecParams.obj, xgbExecParams.eval, prevBooster)
}.getOrElse(throw new RuntimeException("No Watches to train"))
}.cache()
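// Each partition of trainingRDD carries exactly one Watches instance (built by
// PreXGBoost via Iterator.single), so every Spark task drives one distributed
// booster worker; an empty partition would mean a broken pipeline, hence the
// RuntimeException above.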
val sparkJobThread = new Thread() {
override def run() {
// force the job
@ -642,85 +436,11 @@ object XGBoost extends Serializable {
// if the job was aborted due to an exception
logger.error("the job was aborted due to ", t)
if (xgbExecParams.killSparkContextOnWorkerFailure) {
sc.stop()
}
throw t
} finally {
optionalCachedRDD.foreach(_.unpersist())
}
}
@ -753,7 +473,7 @@ object XGBoost extends Serializable {
}
class Watches private(
val datasets: Array[DMatrix],
val names: Array[String],
val cacheDirName: Option[String]) {
@ -964,50 +684,4 @@ private object Watches {
}
}


@ -38,7 +38,7 @@ import scala.collection.{AbstractIterator, Iterator, mutable}
class XGBoostClassifier (
override val uid: String,
private[spark] val xgboostParams: Map[String, Any])
extends ProbabilisticClassifier[Vector, XGBoostClassifier, XGBoostClassificationModel]
with XGBoostClassifierParams with DefaultParamsWritable {
@ -176,26 +176,15 @@ class XGBoostClassifier (
"\'num_class\' in xgboost params.")
}
// Pack all params: the defaults plus those defined by the user
val derivedXGBParamMap = xgboostParams ++ MLlib2XGBoostParams
val buildTrainingData = PreXGBoost.buildDatasetToRDD(this, dataset, derivedXGBParamMap)
transformSchema(dataset.schema, logging = true)
// All non-null param maps in XGBoostClassifier are in derivedXGBParamMap.
val (_booster, _metrics) = XGBoost.trainDistributed(dataset.sparkSession.sparkContext,
buildTrainingData, derivedXGBParamMap)
val model = new XGBoostClassificationModel(uid, _numClasses, _booster)
val summary = XGBoostTrainingSummary(_metrics)
model.setSummary(summary)
@ -265,7 +254,7 @@ class XGBoostClassificationModel private[ml](
*/
override def predict(features: Vector): Double = {
import DataUtils._
val dm = new DMatrix(processMissingValues(
Iterator(features.asXGB),
$(missing),
$(allowNonZeroForMissing)
@ -324,7 +313,7 @@ class XGBoostClassificationModel private[ml](
}
val dm = new DMatrix(
processMissingValues(
features.map(_.asXGB),
$(missing),
$(allowNonZeroForMissing)


@ -171,27 +171,16 @@ class XGBoostRegressor (
set(objectiveType, "regression")
}
transformSchema(dataset.schema, logging = true)
// Pack all params: the defaults plus those defined by the user
val derivedXGBParamMap = xgboostParams ++ MLlib2XGBoostParams
val buildTrainingData = PreXGBoost.buildDatasetToRDD(this, dataset, derivedXGBParamMap)
// All non-null param maps in XGBoostRegressor are in derivedXGBParamMap.
val (_booster, _metrics) = XGBoost.trainDistributed(dataset.sparkSession.sparkContext,
buildTrainingData, derivedXGBParamMap)
val model = new XGBoostRegressionModel(uid, _booster)
val summary = XGBoostTrainingSummary(_metrics)
model.setSummary(summary)
@ -260,7 +249,7 @@ class XGBoostRegressionModel private[ml] (
*/
override def predict(features: Vector): Double = {
import DataUtils._
val dm = new DMatrix(processMissingValues(
Iterator(features.asXGB),
$(missing),
$(allowNonZeroForMissing)
@ -301,7 +290,7 @@ class XGBoostRegressionModel private[ml] (
}
val dm = new DMatrix(
processMissingValues(
features.map(_.asXGB),
$(missing),
$(allowNonZeroForMissing)


@ -1,5 +1,5 @@
/*
Copyright (c) 2014,2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -14,24 +14,20 @@
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol}
private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables with HasWeightCol
with HasBaseMarginCol with HasLeafPredictionCol with HasContribPredictionCol with HasFeaturesCol
with HasLabelCol {
def needDeterministicRepartitioning: Boolean = {
getCheckpointPath != null && getCheckpointPath.nonEmpty && getCheckpointInterval > 0
}
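// i.e. deterministic repartitioning is only forced when checkpointing is
// enabled, presumably so that a job resumed from a checkpoint sees the same
// partition layout as the run that wrote it.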
}
private[spark] trait XGBoostClassifierParams extends XGBoostEstimatorCommon with HasNumClass
private[spark] trait XGBoostRegressorParams extends XGBoostEstimatorCommon with HasGroupCol


@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.ml.linalg.Vectors
import org.scalatest.FunSuite
import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams
import org.apache.spark.sql.functions._
@ -55,13 +56,13 @@ class DeterministicPartitioningSuite extends FunSuite with TmpFolderPerSuite wit
resultDF
})
val transformedRDDs = transformedDFs.map(df => DataUtils.convertDataFrameToXGBLabeledPointRDDs(
PackedParams(col("label"),
col("features"),
lit(1.0),
lit(Float.NaN),
None,
numWorkers,
deterministicPartition = true),
df
).head)
val resultsMaps = transformedRDDs.map(rdd => rdd.mapPartitionsWithIndex {
@ -90,14 +91,13 @@ class DeterministicPartitioningSuite extends FunSuite with TmpFolderPerSuite wit
val df = ss.createDataFrame(sc.parallelize(dataset)).toDF("id", "label", "features")
val dfRepartitioned = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
PackedParams(col("label"),
col("features"),
lit(1.0),
lit(Float.NaN),
None,
10,
deterministicPartition = true), df
).head
val partitionsSizes = dfRepartitioned


@ -17,12 +17,14 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.XGBoostError
import scala.util.Random
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.DMatrix
import org.apache.spark.TaskContext
import org.scalatest.FunSuite
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.functions.lit
@ -30,13 +32,14 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
test("distributed training with the specified worker number") {
val trainingRDD = sc.parallelize(Classification.train)
val buildTrainingRDD = PreXGBoost.buildRDDLabeledPointToRDDWatches(trainingRDD)
val (booster, metrics) = XGBoost.trainDistributed(
sc,
buildTrainingRDD,
List("eta" -> "1", "max_depth" -> "6",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
"missing" -> Float.NaN).toMap)
assert(booster != null)
}
@ -179,7 +182,7 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
// test different splits to cover the corner cases.
for (split <- 1 to 20) {
val trainingRDD = sc.parallelize(Ranking.train, split)
val trainingGroupsRDD = PreXGBoost.repartitionForTrainingGroup(trainingRDD, 4)
val trainingGroups: Array[Array[XGBLabeledPoint]] = trainingGroupsRDD.collect()
// check the order of the groups by group id.
// Ranking.train has 20 groups
@ -201,18 +204,19 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
// make one partition empty for testing
it.filter(_ => TaskContext.getPartitionId() != 3)
})
PreXGBoost.repartitionForTrainingGroup(trainingRDD, 4)
}
test("distributed training with group data") {
val trainingRDD = sc.parallelize(Ranking.train, 5)
val buildTrainingRDD = PreXGBoost.buildRDDLabeledPointToRDDWatches(trainingRDD, hasGroup = true)
val (booster, _) = XGBoost.trainDistributed(
sc,
buildTrainingRDD,
List("eta" -> "1", "max_depth" -> "6",
"objective" -> "rank:pairwise", "num_round" -> 5, "num_workers" -> numWorkers,
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
"missing" -> Float.NaN).toMap)
assert(booster != null)
}