[jvm-packages] enable deterministic repartitioning when checkpoint is enabled (#4807)

* do repartitioning in DataUtils
* keep previous behavior of partitioning without checkpoint
* deterministic repartitioning
* change

commit fc8c9b0521 (parent 277e25797b)
DataUtils.scala
@@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 
+import org.apache.spark.HashPartitioner
 import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.Param
@@ -73,12 +74,81 @@ object DataUtils extends Serializable {
     }
   }
 
+  private def featureValueOfDenseVector(rowHashCode: Int, features: DenseVector): Float = {
+    val featureId = {
+      if (rowHashCode > 0) {
+        rowHashCode % features.size
+      } else {
+        // prevent overflow
+        math.abs(rowHashCode + 1) % features.size
+      }
+    }
+    features.values(featureId).toFloat
+  }
+
+  private def featureValueOfSparseVector(rowHashCode: Int, features: SparseVector): Float = {
+    val featureId = {
+      if (rowHashCode > 0) {
+        rowHashCode % features.indices.length
+      } else {
+        // prevent overflow
+        math.abs(rowHashCode + 1) % features.indices.length
+      }
+    }
+    features.values(featureId).toFloat
+  }
+
+  private def calculatePartitionKey(row: Row, numPartitions: Int): Int = {
+    val Row(_, features: Vector, _, _) = row
+    val rowHashCode = row.hashCode()
+    val featureValue = features match {
+      case denseVector: DenseVector =>
+        featureValueOfDenseVector(rowHashCode, denseVector)
+      case sparseVector: SparseVector =>
+        featureValueOfSparseVector(rowHashCode, sparseVector)
+    }
+    math.abs((rowHashCode.toLong + featureValue).toString.hashCode % numPartitions)
+  }
+
+  private def attachPartitionKey(
+      row: Row,
+      deterministicPartition: Boolean,
+      numWorkers: Int,
+      xgbLp: XGBLabeledPoint): (Int, XGBLabeledPoint) = {
+    if (deterministicPartition) {
+      (calculatePartitionKey(row, numWorkers), xgbLp)
+    } else {
+      (1, xgbLp)
+    }
+  }
+
+  private def repartitionRDDs(
+      deterministicPartition: Boolean,
+      numWorkers: Int,
+      arrayOfRDDs: Array[RDD[(Int, XGBLabeledPoint)]]): Array[RDD[XGBLabeledPoint]] = {
+    if (deterministicPartition) {
+      arrayOfRDDs.map { rdd => rdd.partitionBy(new HashPartitioner(numWorkers)) }.map {
+        rdd => rdd.map(_._2)
+      }
+    } else {
+      arrayOfRDDs.map(rdd => {
+        if (rdd.getNumPartitions != numWorkers) {
+          rdd.map(_._2).repartition(numWorkers)
+        } else {
+          rdd.map(_._2)
+        }
+      })
+    }
+  }
+
   private[spark] def convertDataFrameToXGBLabeledPointRDDs(
       labelCol: Column,
       featuresCol: Column,
       weight: Column,
       baseMargin: Column,
       group: Option[Column],
+      numWorkers: Int,
+      deterministicPartition: Boolean,
       dataFrames: DataFrame*): Array[RDD[XGBLabeledPoint]] = {
     val selectedColumns = group.map(groupCol => Seq(labelCol.cast(FloatType),
       featuresCol,
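The partition key above depends only on a row's own contents (its hashCode mixed with one of its feature values), never on partition layout or arrival order, which is what makes the assignment repeatable across runs. A minimal standalone sketch of the same arithmetic, using made-up hash and feature values (plain Scala, no Spark needed; PartitionKeySketch is an illustrative name, not part of the patch):

object PartitionKeySketch extends App {
  val rowHashCode = -1893572931          // stand-in for Row.hashCode()
  val features = Array(0.5, 1.25, 3.0)   // stand-in for a dense feature vector

  // Pick a feature index from the hash code; the +1 before math.abs guards
  // Int.MinValue, whose absolute value overflows Int.
  val featureId =
    if (rowHashCode > 0) rowHashCode % features.length
    else math.abs(rowHashCode + 1) % features.length
  val featureValue = features(featureId).toFloat

  // Same expression as calculatePartitionKey: mix the hash with a feature
  // value, render it as a string, and hash that into [0, numPartitions).
  val numPartitions = 4
  val key = math.abs((rowHashCode.toLong + featureValue).toString.hashCode % numPartitions)
  println(s"feature index: $featureId, partition key: $key") // identical on every run
}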
@@ -88,22 +158,26 @@ object DataUtils extends Serializable {
       featuresCol,
       weight.cast(FloatType),
       baseMargin.cast(FloatType)))
-    dataFrames.toArray.map {
+    val arrayOfRDDs = dataFrames.toArray.map {
       df => df.select(selectedColumns: _*).rdd.map {
-        case Row(label: Float, features: Vector, weight: Float, group: Int, baseMargin: Float) =>
+        case row @ Row(label: Float, features: Vector, weight: Float, group: Int,
+            baseMargin: Float) =>
           val (indices, values) = features match {
             case v: SparseVector => (v.indices, v.values.map(_.toFloat))
             case v: DenseVector => (null, v.values.map(_.toFloat))
           }
-          XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
-        case Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
+          val xgbLp = XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
+          attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
+        case row @ Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
           val (indices, values) = features match {
             case v: SparseVector => (v.indices, v.values.map(_.toFloat))
             case v: DenseVector => (null, v.values.map(_.toFloat))
           }
-          XGBLabeledPoint(label, indices, values, weight, baseMargin = baseMargin)
+          val xgbLp = XGBLabeledPoint(label, indices, values, weight, baseMargin = baseMargin)
+          attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
       }
     }
+    repartitionRDDs(deterministicPartition, numWorkers, arrayOfRDDs)
   }
 
 }
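With deterministicPartition = true, row placement is delegated to Spark's HashPartitioner, whose getPartition is a pure function of the key. A small sketch of that property (assumes only spark-core on the classpath; no SparkContext is required to evaluate the partitioner itself):

import org.apache.spark.HashPartitioner

object PartitionerSketch extends App {
  val partitioner = new HashPartitioner(4)
  // For Int keys, hashCode is the value itself, so 10 maps to 10 % 4 = 2.
  Seq(0, 1, 10, 42).foreach { key =>
    println(s"key $key -> partition ${partitioner.getPartition(key)}")
  }
  // Because getPartition depends only on the key, partitionBy over the
  // (key, labeledPoint) pairs built by attachPartitionKey reproduces the
  // same layout on every run of the job.
}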
XGBoost.scala
@@ -475,8 +475,7 @@ object XGBoost extends Serializable {
       Left(cacheData(ifCacheDataBoolean, repartitionedData).
         asInstanceOf[RDD[Array[XGBLabeledPoint]]])
     } else {
-      val repartitionedData = repartitionForTraining(trainingData, nWorkers)
-      Right(cacheData(ifCacheDataBoolean, repartitionedData).asInstanceOf[RDD[XGBLabeledPoint]])
+      Right(cacheData(ifCacheDataBoolean, trainingData).asInstanceOf[RDD[XGBLabeledPoint]])
     }
   }
 
@@ -568,15 +567,6 @@ object XGBoost extends Serializable {
     }
   }
 
-  private[spark] def repartitionForTraining(trainingData: RDD[XGBLabeledPoint], nWorkers: Int) = {
-    if (trainingData.getNumPartitions != nWorkers) {
-      logger.info(s"repartitioning training set to $nWorkers partitions")
-      trainingData.repartition(nWorkers)
-    } else {
-      trainingData
-    }
-  }
-
   private def aggByGroupInfo(trainingData: RDD[XGBLabeledPoint]) = {
     val normalGroups: RDD[Array[XGBLabeledPoint]] = trainingData.mapPartitions(
       // LabeledPointGroupIterator returns (Boolean, Array[XGBLabeledPoint])
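The eager repartition pass is removed because rdd.repartition assigns rows to partitions based on how they happen to arrive from upstream, so a recomputed input (for example, after a failure and restart) can land the same row on a different worker; the keyed partitionBy path in DataUtils replaces it whenever checkpointing is on. A sketch contrasting the two, assuming a local SparkSession (the object name is illustrative, not part of the patch):

import org.apache.spark.HashPartitioner
import org.apache.spark.sql.SparkSession

object RepartitionVsPartitionBy extends App {
  val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
  val data = spark.sparkContext.parallelize(1 to 100, 5)

  // repartition: rows are dealt out round-robin from whatever order and
  // layout the input currently has, so the result need not repeat across runs.
  val shuffled = data.repartition(4)

  // partitionBy: placement is a pure function of the key, so it does repeat.
  val keyed = data.map(x => (x, x)).partitionBy(new HashPartitioner(4)).values

  println(shuffled.glom().map(_.length).collect().mkString(","))
  println(keyed.glom().map(_.sorted.mkString("|")).collect().mkString(","))
  spark.stop()
}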
XGBoostClassifier.scala
@@ -37,10 +37,6 @@ import org.json4s.DefaultFormats
 import scala.collection.JavaConverters._
 import scala.collection.{AbstractIterator, Iterator, mutable}
 
-private[spark] trait XGBoostClassifierParams extends GeneralParams with LearningTaskParams
-  with BoosterParams with HasWeightCol with HasBaseMarginCol with HasNumClass with ParamMapFuncs
-  with HasLeafPredictionCol with HasContribPredictionCol with NonParamVariables
-
 class XGBoostClassifier (
   override val uid: String,
   private val xgboostParams: Map[String, Any])
@@ -182,11 +178,11 @@ class XGBoostClassifier (
 
     val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
       col($(labelCol)), col($(featuresCol)), weight, baseMargin,
-      None, dataset.asInstanceOf[DataFrame]).head
+      None, $(numWorkers), needDeterministicRepartitioning, dataset.asInstanceOf[DataFrame]).head
     val evalRDDMap = getEvalSets(xgboostParams).map {
       case (name, dataFrame) => (name,
         DataUtils.convertDataFrameToXGBLabeledPointRDDs(col($(labelCol)), col($(featuresCol)),
-          weight, baseMargin, None, dataFrame).head)
+          weight, baseMargin, None, $(numWorkers), needDeterministicRepartitioning, dataFrame).head)
     }
     transformSchema(dataset.schema, logging = true)
     val derivedXGBParamMap = MLlib2XGBoostParams
New file (adds the XGBoostEstimatorCommon trait and the relocated per-estimator param traits)
@@ -0,0 +1,37 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import ml.dmlc.xgboost4j.scala.spark.params._
+
+import org.apache.spark.ml.param.shared.HasWeightCol
+
+private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
+  with BoosterParams with ParamMapFuncs with NonParamVariables {
+
+  def needDeterministicRepartitioning: Boolean = {
+    getCheckpointPath.nonEmpty && getCheckpointInterval > 0
+  }
+}
+
+private[spark] trait XGBoostClassifierParams extends HasWeightCol with HasBaseMarginCol
+  with HasNumClass with HasLeafPredictionCol with HasContribPredictionCol
+  with XGBoostEstimatorCommon
+
+private[spark] trait XGBoostRegressorParams extends HasBaseMarginCol with HasWeightCol
+  with HasGroupCol with HasLeafPredictionCol with HasContribPredictionCol
+  with XGBoostEstimatorCommon
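Both estimators now derive the switch purely from the two checkpoint parameters. A quick illustration of the rule with made-up values (this snippet has to live in the ml.dmlc.xgboost4j.scala.spark package because the trait is private[spark], and it assumes the parameter defaults apply when the checkpoint params are absent):

package ml.dmlc.xgboost4j.scala.spark

object NeedsDeterministicRepartitioningSketch extends App {
  // true only when BOTH are set: a non-empty checkpoint_path and a positive
  // checkpoint_interval (number of rounds between saved checkpoints).
  val withCheckpoint = new XGBoostClassifier(Map(
    "checkpoint_path" -> "/tmp/xgb-ckpt", // hypothetical path
    "checkpoint_interval" -> 2))
  assert(withCheckpoint.needDeterministicRepartitioning)

  val withoutCheckpoint = new XGBoostClassifier(Map.empty[String, Any])
  assert(!withoutCheckpoint.needDeterministicRepartitioning)
}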
XGBoostRegressor.scala
@@ -41,10 +41,6 @@ import scala.collection.mutable.ListBuffer
 
 import org.apache.spark.broadcast.Broadcast
 
-private[spark] trait XGBoostRegressorParams extends GeneralParams with BoosterParams
-  with LearningTaskParams with HasBaseMarginCol with HasWeightCol with HasGroupCol
-  with ParamMapFuncs with HasLeafPredictionCol with HasContribPredictionCol with NonParamVariables
-
 class XGBoostRegressor (
   override val uid: String,
   private val xgboostParams: Map[String, Any])
@@ -178,11 +174,12 @@ class XGBoostRegressor (
     val group = if (!isDefined(groupCol) || $(groupCol).isEmpty) lit(-1) else col($(groupCol))
     val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
       col($(labelCol)), col($(featuresCol)), weight, baseMargin, Some(group),
-      dataset.asInstanceOf[DataFrame]).head
+      $(numWorkers), needDeterministicRepartitioning, dataset.asInstanceOf[DataFrame]).head
     val evalRDDMap = getEvalSets(xgboostParams).map {
       case (name, dataFrame) => (name,
         DataUtils.convertDataFrameToXGBLabeledPointRDDs(col($(labelCol)), col($(featuresCol)),
-          weight, baseMargin, Some(group), dataFrame).head)
+          weight, baseMargin, Some(group), $(numWorkers), needDeterministicRepartitioning,
+          dataFrame).head)
     }
     transformSchema(dataset.schema, logging = true)
     val derivedXGBParamMap = MLlib2XGBoostParams
New test file (DeterministicPartitioningSuite)
@@ -0,0 +1,82 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.functions._
+
+class DeterministicPartitioningSuite extends FunSuite with TmpFolderPerSuite with PerTest {
+
+  test("perform deterministic partitioning when checkpointInterval and" +
+    " checkpointPath is set (Classifier)") {
+    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
+    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
+      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
+      "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
+    val xgbClassifier = new XGBoostClassifier(paramMap)
+    assert(xgbClassifier.needDeterministicRepartitioning)
+  }
+
+  test("perform deterministic partitioning when checkpointInterval and" +
+    " checkpointPath is set (Regressor)") {
+    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
+    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
+      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
+      "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
+    val xgbRegressor = new XGBoostRegressor(paramMap)
+    assert(xgbRegressor.needDeterministicRepartitioning)
+  }
+
+  test("deterministic partitioning takes effect with various parts of data") {
+    val trainingDF = buildDataFrame(Classification.train)
+    // the test idea: apply a chain of repartitions over the training DataFrame;
+    // all of them must still produce identical RDDs
+    val transformedDFs = (1 until 6).map(shuffleCount => {
+      var resultDF = trainingDF
+      for (i <- 0 until shuffleCount) {
+        resultDF = resultDF.repartition(numWorkers)
+      }
+      resultDF
+    })
+    val transformedRDDs = transformedDFs.map(df => DataUtils.convertDataFrameToXGBLabeledPointRDDs(
+      col("label"),
+      col("features"),
+      lit(1.0),
+      lit(Float.NaN),
+      None,
+      numWorkers,
+      deterministicPartition = true,
+      df
+    ).head)
+    val resultsMaps = transformedRDDs.map(rdd => rdd.mapPartitionsWithIndex {
+      case (partitionIndex, labelPoints) =>
+        Iterator((partitionIndex, labelPoints.toList))
+    }.collect().toMap)
+    resultsMaps.foldLeft(resultsMaps.head) { case (map1, map2) =>
+      assert(map1.keys.toSet === map2.keys.toSet)
+      for ((parIdx, labeledPoints) <- map1) {
+        val sortedA = labeledPoints.sortBy(_.hashCode())
+        val sortedB = map2(parIdx).sortBy(_.hashCode())
+        assert(sortedA.length === sortedB.length)
+        assert(sortedA.indices.forall(idx =>
+          sortedA(idx).values.toSet === sortedB(idx).values.toSet))
+      }
+      map2
+    }
+  }
+}
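For context, a hedged end-to-end sketch of what the change enables: with both checkpoint parameters set, a resubmitted job converts the same DataFrame into identically partitioned RDDs, so training can resume from the last saved checkpoint over consistent shards. Every name and path below is an assumption for illustration, not part of the patch:

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
import org.apache.spark.sql.SparkSession

object CheckpointedTrainingSketch extends App {
  val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
  // hypothetical dataset with "features" and "label" columns
  val trainingDF = spark.read.parquet("/data/train.parquet")

  val xgb = new XGBoostClassifier(Map(
    "objective" -> "binary:logistic",
    "num_round" -> 100,
    "num_workers" -> 2,
    "checkpoint_path" -> "/tmp/xgb-ckpt", // hypothetical HDFS/local path
    "checkpoint_interval" -> 10))         // save a checkpoint every 10 rounds

  // With checkpointing on, needDeterministicRepartitioning is true, so the
  // converted training RDD is partitioned by row content rather than by
  // arrival order, and a restart sees the same shards.
  val model = xgb.fit(trainingDF)
  spark.stop()
}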