diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala
index 7f011e9f2..36c76ce15 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala
@@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
+import org.apache.spark.HashPartitioner
 import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.Param
@@ -73,12 +74,81 @@ object DataUtils extends Serializable {
     }
   }
 
+  private def featureValueOfDenseVector(rowHashCode: Int, features: DenseVector): Float = {
+    val featureId = {
+      if (rowHashCode > 0) {
+        rowHashCode % features.size
+      } else {
+        // math.abs(Int.MinValue) overflows, so shift by one before taking abs
+        math.abs(rowHashCode + 1) % features.size
+      }
+    }
+    features.values(featureId).toFloat
+  }
+
+  private def featureValueOfSparseVector(rowHashCode: Int, features: SparseVector): Float = {
+    val featureId = {
+      if (rowHashCode > 0) {
+        rowHashCode % features.indices.length
+      } else {
+        // math.abs(Int.MinValue) overflows, so shift by one before taking abs
+        math.abs(rowHashCode + 1) % features.indices.length
+      }
+    }
+    features.values(featureId).toFloat
+  }
+
+  private def calculatePartitionKey(row: Row, numPartitions: Int): Int = {
+    // features is always the second selected column; read it by position, since
+    // rows have four fields without a group column and five fields with one
+    val features = row.getAs[Vector](1)
+    val rowHashCode = row.hashCode()
+    val featureValue = features match {
+      case denseVector: DenseVector =>
+        featureValueOfDenseVector(rowHashCode, denseVector)
+      case sparseVector: SparseVector =>
+        featureValueOfSparseVector(rowHashCode, sparseVector)
+    }
+    math.abs((rowHashCode.toLong + featureValue).toString.hashCode % numPartitions)
+  }
+
+  private def attachPartitionKey(
+      row: Row,
+      deterministicPartition: Boolean,
+      numWorkers: Int,
+      xgbLp: XGBLabeledPoint): (Int, XGBLabeledPoint) = {
+    if (deterministicPartition) {
+      (calculatePartitionKey(row, numWorkers), xgbLp)
+    } else {
+      (1, xgbLp)
+    }
+  }
+
+  private def repartitionRDDs(
+      deterministicPartition: Boolean,
+      numWorkers: Int,
+      arrayOfRDDs: Array[RDD[(Int, XGBLabeledPoint)]]): Array[RDD[XGBLabeledPoint]] = {
+    if (deterministicPartition) {
+      arrayOfRDDs.map { rdd => rdd.partitionBy(new HashPartitioner(numWorkers)) }.map {
+        rdd => rdd.map(_._2)
+      }
+    } else {
+      arrayOfRDDs.map(rdd => {
+        if (rdd.getNumPartitions != numWorkers) {
+          rdd.map(_._2).repartition(numWorkers)
+        } else {
+          rdd.map(_._2)
+        }
+      })
+    }
+  }
+
   private[spark] def convertDataFrameToXGBLabeledPointRDDs(
       labelCol: Column,
       featuresCol: Column,
       weight: Column,
       baseMargin: Column,
       group: Option[Column],
+      numWorkers: Int,
+      deterministicPartition: Boolean,
       dataFrames: DataFrame*): Array[RDD[XGBLabeledPoint]] = {
     val selectedColumns = group.map(groupCol => Seq(labelCol.cast(FloatType),
       featuresCol,
@@ -88,22 +158,26 @@ object DataUtils extends Serializable {
       featuresCol,
       weight.cast(FloatType),
       baseMargin.cast(FloatType)))
-    dataFrames.toArray.map {
+    val arrayOfRDDs = dataFrames.toArray.map {
       df => df.select(selectedColumns: _*).rdd.map {
-        case Row(label: Float, features: Vector, weight: Float, group: Int, baseMargin: Float) =>
+        case row @ Row(label: Float, features: Vector, weight: Float, group: Int,
+          baseMargin: Float) =>
           val (indices, values) = features match {
             case v: SparseVector => (v.indices, v.values.map(_.toFloat))
             case v: DenseVector => (null, v.values.map(_.toFloat))
           }
-          XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
-        case Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
+          val xgbLp = XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
+          attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
+        case row @ Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
           val (indices, values) = features match {
             case v: SparseVector => (v.indices, v.values.map(_.toFloat))
             case v: DenseVector => (null, v.values.map(_.toFloat))
           }
-          XGBLabeledPoint(label, indices, values, weight, baseMargin = baseMargin)
+          val xgbLp = XGBLabeledPoint(label, indices, values, weight, baseMargin = baseMargin)
+          attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
       }
     }
+    repartitionRDDs(deterministicPartition, numWorkers, arrayOfRDDs)
   }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index b96a446d1..5bd847e0f 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -475,8 +475,7 @@ object XGBoost extends Serializable {
       Left(cacheData(ifCacheDataBoolean, repartitionedData).
         asInstanceOf[RDD[Array[XGBLabeledPoint]]])
     } else {
-      val repartitionedData = repartitionForTraining(trainingData, nWorkers)
-      Right(cacheData(ifCacheDataBoolean, repartitionedData).asInstanceOf[RDD[XGBLabeledPoint]])
+      Right(cacheData(ifCacheDataBoolean, trainingData).asInstanceOf[RDD[XGBLabeledPoint]])
     }
   }
 
@@ -568,15 +567,6 @@ object XGBoost extends Serializable {
     }
   }
 
-  private[spark] def repartitionForTraining(trainingData: RDD[XGBLabeledPoint], nWorkers: Int) = {
-    if (trainingData.getNumPartitions != nWorkers) {
-      logger.info(s"repartitioning training set to $nWorkers partitions")
-      trainingData.repartition(nWorkers)
-    } else {
-      trainingData
-    }
-  }
-
   private def aggByGroupInfo(trainingData: RDD[XGBLabeledPoint]) = {
     val normalGroups: RDD[Array[XGBLabeledPoint]] = trainingData.mapPartitions(
       // LabeledPointGroupIterator returns (Boolean, Array[XGBLabeledPoint])
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
index 832826bf8..db4936430 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
@@ -37,10 +37,6 @@ import org.json4s.DefaultFormats
 import scala.collection.JavaConverters._
 import scala.collection.{AbstractIterator, Iterator, mutable}
 
-private[spark] trait XGBoostClassifierParams extends GeneralParams with LearningTaskParams
-  with BoosterParams with HasWeightCol with HasBaseMarginCol with HasNumClass with ParamMapFuncs
-  with HasLeafPredictionCol with HasContribPredictionCol with NonParamVariables
-
 class XGBoostClassifier (
     override val uid: String,
     private val xgboostParams: Map[String, Any])
@@ -182,11 +178,11 @@ class XGBoostClassifier (
 
     val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
       col($(labelCol)), col($(featuresCol)), weight, baseMargin,
-      None, dataset.asInstanceOf[DataFrame]).head
+      None, $(numWorkers), needDeterministicRepartitioning, dataset.asInstanceOf[DataFrame]).head
     val evalRDDMap = getEvalSets(xgboostParams).map {
       case (name, dataFrame) => (name,
         DataUtils.convertDataFrameToXGBLabeledPointRDDs(col($(labelCol)), col($(featuresCol)),
-          weight, baseMargin, None, dataFrame).head)
+          weight, baseMargin, None, $(numWorkers), needDeterministicRepartitioning, dataFrame).head)
     }
     transformSchema(dataset.schema, logging = true)
     val derivedXGBParamMap = MLlib2XGBoostParams
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorCommon.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorCommon.scala
new file mode 100644
index 000000000..1213a8f72
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorCommon.scala
@@ -0,0 +1,37 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import ml.dmlc.xgboost4j.scala.spark.params._
+
+import org.apache.spark.ml.param.shared.HasWeightCol
+
+private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
+  with BoosterParams with ParamMapFuncs with NonParamVariables {
+
+  def needDeterministicRepartitioning: Boolean = {
+    getCheckpointPath.nonEmpty && getCheckpointInterval > 0
+  }
+}
+
+private[spark] trait XGBoostClassifierParams extends HasWeightCol with HasBaseMarginCol
+  with HasNumClass with HasLeafPredictionCol with HasContribPredictionCol
+  with XGBoostEstimatorCommon
+
+private[spark] trait XGBoostRegressorParams extends HasBaseMarginCol with HasWeightCol
+  with HasGroupCol with HasLeafPredictionCol with HasContribPredictionCol
+  with XGBoostEstimatorCommon
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
index f447042e2..e2f22c7af 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
@@ -41,10 +41,6 @@ import scala.collection.mutable.ListBuffer
 
 import org.apache.spark.broadcast.Broadcast
 
-private[spark] trait XGBoostRegressorParams extends GeneralParams with BoosterParams
-  with LearningTaskParams with HasBaseMarginCol with HasWeightCol with HasGroupCol
-  with ParamMapFuncs with HasLeafPredictionCol with HasContribPredictionCol with NonParamVariables
-
 class XGBoostRegressor (
     override val uid: String,
     private val xgboostParams: Map[String, Any])
@@ -178,11 +174,12 @@ class XGBoostRegressor (
     val group = if (!isDefined(groupCol) || $(groupCol).isEmpty) lit(-1) else col($(groupCol))
     val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
       col($(labelCol)), col($(featuresCol)), weight, baseMargin, Some(group),
-      dataset.asInstanceOf[DataFrame]).head
+      $(numWorkers), needDeterministicRepartitioning, dataset.asInstanceOf[DataFrame]).head
     val evalRDDMap = getEvalSets(xgboostParams).map {
       case (name, dataFrame) => (name,
         DataUtils.convertDataFrameToXGBLabeledPointRDDs(col($(labelCol)), col($(featuresCol)),
-          weight, baseMargin, Some(group), dataFrame).head)
+          weight, baseMargin, Some(group), $(numWorkers), needDeterministicRepartitioning,
+          dataFrame).head)
     }
     transformSchema(dataset.schema, logging = true)
     val derivedXGBParamMap = MLlib2XGBoostParams
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala
new file mode 100644
index 000000000..986b0843b
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala
@@ -0,0 +1,82 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.functions._
+
+class DeterministicPartitioningSuite extends FunSuite with TmpFolderPerSuite with PerTest {
+
+  test("perform deterministic partitioning when checkpointInterval and" +
+    " checkpointPath are set (Classifier)") {
+    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
+    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
+      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
+      "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
+    val xgbClassifier = new XGBoostClassifier(paramMap)
+    assert(xgbClassifier.needDeterministicRepartitioning)
+  }
+
+  test("perform deterministic partitioning when checkpointInterval and" +
+    " checkpointPath are set (Regressor)") {
+    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
+    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
+      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
+      "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
+    val xgbRegressor = new XGBoostRegressor(paramMap)
+    assert(xgbRegressor.needDeterministicRepartitioning)
+  }
+
+  test("deterministic partitioning takes effect with various parts of data") {
+    val trainingDF = buildDataFrame(Classification.train)
+    // the idea of the test: apply a chain of repartitions over the training DF;
+    // every length of chain must produce the identical RDD partitioning
+    val transformedDFs = (1 until 6).map(shuffleCount => {
+      var resultDF = trainingDF
+      for (i <- 0 until shuffleCount) {
+        resultDF = resultDF.repartition(numWorkers)
+      }
+      resultDF
+    })
+    val transformedRDDs = transformedDFs.map(df => DataUtils.convertDataFrameToXGBLabeledPointRDDs(
+      col("label"),
+      col("features"),
+      lit(1.0),
+      lit(Float.NaN),
+      None,
+      numWorkers,
+      deterministicPartition = true,
+      df
+    ).head)
+    val resultsMaps = transformedRDDs.map(rdd => rdd.mapPartitionsWithIndex {
+      case (partitionIndex, labelPoints) =>
+        Iterator((partitionIndex, labelPoints.toList))
+    }.collect().toMap)
+    resultsMaps.foldLeft(resultsMaps.head) { case (map1, map2) =>
+      assert(map1.keys.toSet === map2.keys.toSet)
+      for ((parIdx, labeledPoints) <- map1) {
+        val sortedA = labeledPoints.sortBy(_.hashCode())
+        val sortedB = map2(parIdx).sortBy(_.hashCode())
+        assert(sortedA.length === sortedB.length)
+        assert(sortedA.indices.forall(idx =>
+          sortedA(idx).values.toSet === sortedB(idx).values.toSet))
+      }
+      map2
+    }
+  }
+}