[jvm-packages] enable deterministic repartitioning when checkpoint is enabled (#4807)

* do repartitioning in DataUtils

* keep previous behavior of partitioning without checkpoint

* deterministic repartitioning

* change
author Nan Zhu 2019-09-19 15:21:05 -07:00 (committed by GitHub)
parent 277e25797b
commit fc8c9b0521
6 changed files with 204 additions and 28 deletions
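
In brief: when a checkpoint path and a positive checkpoint interval are configured, every training row is keyed by a hash derived from its own content (the row's hashCode mixed with one of its feature values) and the RDDs are placed with a HashPartitioner, so a job restarted from a checkpoint rebuilds exactly the same partitions. Without checkpointing, the previous repartition-only-on-mismatch behavior is kept. A minimal sketch of why keyed placement is reproducible while plain repartition() is not (illustrative only; `sc` is an assumed SparkContext, not part of this commit):

import org.apache.spark.HashPartitioner

val rows = sc.parallelize(Seq("a", "b", "c", "d"), 2)
// repartition() spreads elements round-robin starting from a position that
// depends on the input partition layout, not on element content.
val positional = rows.repartition(2)
// Keying by content and using partitionBy() makes placement a pure function
// of the key, hence identical on every run and after checkpoint recovery.
val contentKeyed = rows
  .map(r => (r.hashCode, r))
  .partitionBy(new HashPartitioner(2))
  .map(_._2)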

File: jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala

@@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.HashPartitioner
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.Param
@@ -73,12 +74,81 @@ object DataUtils extends Serializable {
}
}
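// Pick one feature value out of the vector, indexed by the row's hashCode;
// the selected value is mixed into the deterministic partition key below.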
private def featureValueOfDenseVector(rowHashCode: Int, features: DenseVector): Float = {
val featureId = {
if (rowHashCode > 0) {
rowHashCode % features.size
} else {
// prevent overflow: add 1 before abs, since math.abs(Int.MinValue) is still negative
math.abs(rowHashCode + 1) % features.size
}
}
features.values(featureId).toFloat
}
private def featureValueOfSparseVector(rowHashCode: Int, features: SparseVector): Float = {
val featureId = {
if (rowHashCode > 0) {
rowHashCode % features.indices.length
} else {
// prevent overflow: add 1 before abs, since math.abs(Int.MinValue) is still negative
math.abs(rowHashCode + 1) % features.indices.length
}
}
features.values(featureId).toFloat
}
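// Derive a partition key that is a pure function of the row's content:
// mix the selected feature value back into the row's hashCode and map the
// result onto [0, numPartitions).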
private def calculatePartitionKey(row: Row, numPartitions: Int): Int = {
// features is always the second selected column; use positional access since
// rows carry 4 fields (no group) or 5 fields (with group)
val features = row.getAs[Vector](1)
val rowHashCode = row.hashCode()
val featureValue = features match {
case denseVector: DenseVector =>
featureValueOfDenseVector(rowHashCode, denseVector)
case sparseVector: SparseVector =>
featureValueOfSparseVector(rowHashCode, sparseVector)
}
math.abs((rowHashCode.toLong + featureValue).toString.hashCode % numPartitions)
}
private def attachPartitionKey(
row: Row,
deterministicPartition: Boolean,
numWorkers: Int,
xgbLp: XGBLabeledPoint): (Int, XGBLabeledPoint) = {
if (deterministicPartition) {
(calculatePartitionKey(row, numWorkers), xgbLp)
} else {
(1, xgbLp)
}
}
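// With deterministic partitioning, place rows by their computed key via a
// HashPartitioner; otherwise keep the previous behavior and repartition only
// when the partition count differs from the number of workers.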
private def repartitionRDDs(
deterministicPartition: Boolean,
numWorkers: Int,
arrayOfRDDs: Array[RDD[(Int, XGBLabeledPoint)]]): Array[RDD[XGBLabeledPoint]] = {
if (deterministicPartition) {
arrayOfRDDs.map { rdd => rdd.partitionBy(new HashPartitioner(numWorkers)).map(_._2) }
} else {
arrayOfRDDs.map(rdd => {
if (rdd.getNumPartitions != numWorkers) {
rdd.map(_._2).repartition(numWorkers)
} else {
rdd.map(_._2)
}
})
}
}
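// Convert DataFrames into RDDs of XGBLabeledPoint, tagging each row with a
// partition key and then applying the partitioning strategy chosen above.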
private[spark] def convertDataFrameToXGBLabeledPointRDDs(
labelCol: Column,
featuresCol: Column,
weight: Column,
baseMargin: Column,
group: Option[Column],
numWorkers: Int,
deterministicPartition: Boolean,
dataFrames: DataFrame*): Array[RDD[XGBLabeledPoint]] = {
val selectedColumns = group.map(groupCol => Seq(labelCol.cast(FloatType),
featuresCol,
@@ -88,22 +158,26 @@ object DataUtils extends Serializable {
featuresCol,
weight.cast(FloatType),
baseMargin.cast(FloatType)))
dataFrames.toArray.map {
val arrayOfRDDs = dataFrames.toArray.map {
df => df.select(selectedColumns: _*).rdd.map {
case Row(label: Float, features: Vector, weight: Float, group: Int, baseMargin: Float) =>
case row @ Row(label: Float, features: Vector, weight: Float, group: Int,
baseMargin: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
}
XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
case Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
val xgbLp = XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
case row @ Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
}
XGBLabeledPoint(label, indices, values, weight, baseMargin = baseMargin)
val xgbLp = XGBLabeledPoint(label, indices, values, weight, baseMargin = baseMargin)
attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
}
}
repartitionRDDs(deterministicPartition, numWorkers, arrayOfRDDs)
}
}
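
For reference, a call of the widened API, mirroring the test suite at the end of this commit (column names and numWorkers as used there):

val trainRDD = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
  col("label"), col("features"), lit(1.0), lit(Float.NaN),
  None, numWorkers, deterministicPartition = true, df).head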

File: jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala

@@ -475,8 +475,7 @@ object XGBoost extends Serializable {
Left(cacheData(ifCacheDataBoolean, repartitionedData).
asInstanceOf[RDD[Array[XGBLabeledPoint]]])
} else {
val repartitionedData = repartitionForTraining(trainingData, nWorkers)
Right(cacheData(ifCacheDataBoolean, repartitionedData).asInstanceOf[RDD[XGBLabeledPoint]])
Right(cacheData(ifCacheDataBoolean, trainingData).asInstanceOf[RDD[XGBLabeledPoint]])
}
}
@@ -568,15 +567,6 @@ object XGBoost extends Serializable {
}
}
private[spark] def repartitionForTraining(trainingData: RDD[XGBLabeledPoint], nWorkers: Int) = {
if (trainingData.getNumPartitions != nWorkers) {
logger.info(s"repartitioning training set to $nWorkers partitions")
trainingData.repartition(nWorkers)
} else {
trainingData
}
}
private def aggByGroupInfo(trainingData: RDD[XGBLabeledPoint]) = {
val normalGroups: RDD[Array[XGBLabeledPoint]] = trainingData.mapPartitions(
// LabeledPointGroupIterator returns (Boolean, Array[XGBLabeledPoint])

File: jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala

@@ -37,10 +37,6 @@ import org.json4s.DefaultFormats
import scala.collection.JavaConverters._
import scala.collection.{AbstractIterator, Iterator, mutable}
private[spark] trait XGBoostClassifierParams extends GeneralParams with LearningTaskParams
with BoosterParams with HasWeightCol with HasBaseMarginCol with HasNumClass with ParamMapFuncs
with HasLeafPredictionCol with HasContribPredictionCol with NonParamVariables
class XGBoostClassifier (
override val uid: String,
private val xgboostParams: Map[String, Any])
@@ -182,11 +178,11 @@ class XGBoostClassifier (
val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
col($(labelCol)), col($(featuresCol)), weight, baseMargin,
None, dataset.asInstanceOf[DataFrame]).head
None, $(numWorkers), needDeterministicRepartitioning, dataset.asInstanceOf[DataFrame]).head
val evalRDDMap = getEvalSets(xgboostParams).map {
case (name, dataFrame) => (name,
DataUtils.convertDataFrameToXGBLabeledPointRDDs(col($(labelCol)), col($(featuresCol)),
weight, baseMargin, None, dataFrame).head)
weight, baseMargin, None, $(numWorkers), needDeterministicRepartitioning, dataFrame).head)
}
transformSchema(dataset.schema, logging = true)
val derivedXGBParamMap = MLlib2XGBoostParams

File (new): jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorCommon.scala

@@ -0,0 +1,37 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.spark.params._
import org.apache.spark.ml.param.shared.HasWeightCol
private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
with BoosterParams with ParamMapFuncs with NonParamVariables {
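// Deterministic repartitioning is needed only while checkpointing is active,
// i.e. a checkpoint path is configured and a positive interval is set.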
def needDeterministicRepartitioning: Boolean = {
getCheckpointPath.nonEmpty && getCheckpointInterval > 0
}
}
private[spark] trait XGBoostClassifierParams extends HasWeightCol with HasBaseMarginCol
with HasNumClass with HasLeafPredictionCol with HasContribPredictionCol
with XGBoostEstimatorCommon
private[spark] trait XGBoostRegressorParams extends HasBaseMarginCol with HasWeightCol
with HasGroupCol with HasLeafPredictionCol with HasContribPredictionCol
with XGBoostEstimatorCommon
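
A quick illustration of the new predicate (a sketch using the same parameter keys the test suite below exercises; the path is a placeholder):

val withCheckpoint = new XGBoostClassifier(Map(
  "checkpoint_path" -> "/tmp/ckpt",   // non-empty path
  "checkpoint_interval" -> 2))        // positive interval
withCheckpoint.needDeterministicRepartitioning  // true

val noCheckpoint = new XGBoostClassifier()
noCheckpoint.needDeterministicRepartitioning    // false under default params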

File: jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala

@@ -41,10 +41,6 @@ import scala.collection.mutable.ListBuffer
import org.apache.spark.broadcast.Broadcast
private[spark] trait XGBoostRegressorParams extends GeneralParams with BoosterParams
with LearningTaskParams with HasBaseMarginCol with HasWeightCol with HasGroupCol
with ParamMapFuncs with HasLeafPredictionCol with HasContribPredictionCol with NonParamVariables
class XGBoostRegressor (
override val uid: String,
private val xgboostParams: Map[String, Any])
@@ -178,11 +174,12 @@ class XGBoostRegressor (
val group = if (!isDefined(groupCol) || $(groupCol).isEmpty) lit(-1) else col($(groupCol))
val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
col($(labelCol)), col($(featuresCol)), weight, baseMargin, Some(group),
dataset.asInstanceOf[DataFrame]).head
$(numWorkers), needDeterministicRepartitioning, dataset.asInstanceOf[DataFrame]).head
val evalRDDMap = getEvalSets(xgboostParams).map {
case (name, dataFrame) => (name,
DataUtils.convertDataFrameToXGBLabeledPointRDDs(col($(labelCol)), col($(featuresCol)),
weight, baseMargin, Some(group), dataFrame).head)
weight, baseMargin, Some(group), $(numWorkers), needDeterministicRepartitioning,
dataFrame).head)
}
transformSchema(dataset.schema, logging = true)
val derivedXGBParamMap = MLlib2XGBoostParams

File (new): jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala

@@ -0,0 +1,82 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import org.scalatest.FunSuite
import org.apache.spark.sql.functions._
class DeterministicPartitioningSuite extends FunSuite with TmpFolderPerSuite with PerTest {
test("perform deterministic partitioning when checkpointInternal and" +
" checkpointPath is set (Classifier)") {
val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
val paramMap = Map("eta" -> "1", "max_depth" -> 2,
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
"checkpoint_interval" -> 2, "num_workers" -> numWorkers)
val xgbClassifier = new XGBoostClassifier(paramMap)
assert(xgbClassifier.needDeterministicRepartitioning)
}
test("perform deterministic partitioning when checkpointInternal and" +
" checkpointPath is set (Regressor)") {
val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
val paramMap = Map("eta" -> "1", "max_depth" -> 2,
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
"checkpoint_interval" -> 2, "num_workers" -> numWorkers)
val xgbRegressor = new XGBoostRegressor(paramMap)
assert(xgbRegressor.needDeterministicRepartitioning)
}
test("deterministic partitioning takes effect with various parts of data") {
val trainingDF = buildDataFrame(Classification.train)
// the idea: apply chains of different numbers of repartition() calls to the
// training DF; deterministic partitioning must still produce identical RDDs
val transformedDFs = (1 until 6).map(shuffleCount => {
var resultDF = trainingDF
for (i <- 0 until shuffleCount) {
resultDF = resultDF.repartition(numWorkers)
}
resultDF
})
val transformedRDDs = transformedDFs.map(df => DataUtils.convertDataFrameToXGBLabeledPointRDDs(
col("label"),
col("features"),
lit(1.0),
lit(Float.NaN),
None,
numWorkers,
deterministicPartition = true,
df
).head)
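// Collect each RDD as partitionIndex -> contents so that per-partition
// placement can be compared across the differently shuffled inputs.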
val resultsMaps = transformedRDDs.map(rdd => rdd.mapPartitionsWithIndex {
case (partitionIndex, labelPoints) =>
Iterator((partitionIndex, labelPoints.toList))
}.collect().toMap)
resultsMaps.foldLeft(resultsMaps.head) { case (map1, map2) =>
assert(map1.keys.toSet === map2.keys.toSet)
for ((parIdx, labeledPoints) <- map1) {
val sortedA = labeledPoints.sortBy(_.hashCode())
val sortedB = map2(parIdx).sortBy(_.hashCode())
assert(sortedA.length === sortedB.length)
assert(sortedA.indices.forall(idx =>
sortedA(idx).values.toSet === sortedB(idx).values.toSet))
}
map2
}
}
}