[jvm-packages] enable deterministic repartitioning when checkpoint is enabled (#4807)

* do repartitioning in DataUtils
* keep previous behavior of partitioning without checkpoint
* deterministic repartitioning
* change

This commit is contained in:
parent 277e25797b
commit fc8c9b0521
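For context before the diff, here is a minimal usage sketch of the scenario this change targets: when both checkpoint_path and checkpoint_interval are set, a failed job resumes from the last saved booster, so the training input must land on the same partitions each time it is rebuilt. Apart from the parameter names that appear in this patch (checkpoint_path, checkpoint_interval, num_workers), everything below is an illustrative assumption, not part of the commit.

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

object CheckpointedTrainingSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("checkpoint-demo").getOrCreate()
    import spark.implicits._

    // Toy two-row dataset; real training data would be far larger.
    val train = Seq(
      (0.0, Vectors.dense(1.0, 2.0)),
      (1.0, Vectors.dense(3.0, 4.0))
    ).toDF("label", "features")

    // Setting a checkpoint path and interval is what makes needDeterministicRepartitioning
    // (introduced in this commit) evaluate to true during fit().
    val classifier = new XGBoostClassifier(Map(
      "eta" -> 0.1,
      "max_depth" -> 2,
      "objective" -> "binary:logistic",
      "num_round" -> 10,
      "num_workers" -> 1,
      "checkpoint_path" -> "/tmp/xgb-checkpoints",   // hypothetical path
      "checkpoint_interval" -> 2))

    val model = classifier.fit(train)
    spark.stop()
  }
}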
@@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 
+import org.apache.spark.HashPartitioner
 import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.Param
@@ -73,12 +74,81 @@ object DataUtils extends Serializable {
     }
   }
 
+  private def featureValueOfDenseVector(rowHashCode: Int, features: DenseVector): Float = {
+    val featureId = {
+      if (rowHashCode > 0) {
+        rowHashCode % features.size
+      } else {
+        // prevent overflow
+        math.abs(rowHashCode + 1) % features.size
+      }
+    }
+    features.values(featureId).toFloat
+  }
+
+  private def featureValueOfSparseVector(rowHashCode: Int, features: SparseVector): Float = {
+    val featureId = {
+      if (rowHashCode > 0) {
+        rowHashCode % features.indices.length
+      } else {
+        // prevent overflow
+        math.abs(rowHashCode + 1) % features.indices.length
+      }
+    }
+    features.values(featureId).toFloat
+  }
+
+  private def calculatePartitionKey(row: Row, numPartitions: Int): Int = {
+    val Row(_, features: Vector, _, _) = row
+    val rowHashCode = row.hashCode()
+    val featureValue = features match {
+      case denseVector: DenseVector =>
+        featureValueOfDenseVector(rowHashCode, denseVector)
+      case sparseVector: SparseVector =>
+        featureValueOfSparseVector(rowHashCode, sparseVector)
+    }
+    math.abs((rowHashCode.toLong + featureValue).toString.hashCode % numPartitions)
+  }
+
+  private def attachPartitionKey(
+      row: Row,
+      deterministicPartition: Boolean,
+      numWorkers: Int,
+      xgbLp: XGBLabeledPoint): (Int, XGBLabeledPoint) = {
+    if (deterministicPartition) {
+      (calculatePartitionKey(row, numWorkers), xgbLp)
+    } else {
+      (1, xgbLp)
+    }
+  }
+
+  private def repartitionRDDs(
+      deterministicPartition: Boolean,
+      numWorkers: Int,
+      arrayOfRDDs: Array[RDD[(Int, XGBLabeledPoint)]]): Array[RDD[XGBLabeledPoint]] = {
+    if (deterministicPartition) {
+      arrayOfRDDs.map { rdd => rdd.partitionBy(new HashPartitioner(numWorkers)) }.map {
+        rdd => rdd.map(_._2)
+      }
+    } else {
+      arrayOfRDDs.map(rdd => {
+        if (rdd.getNumPartitions != numWorkers) {
+          rdd.map(_._2).repartition(numWorkers)
+        } else {
+          rdd.map(_._2)
+        }
+      })
+    }
+  }
+
   private[spark] def convertDataFrameToXGBLabeledPointRDDs(
       labelCol: Column,
       featuresCol: Column,
       weight: Column,
       baseMargin: Column,
       group: Option[Column],
+      numWorkers: Int,
+      deterministicPartition: Boolean,
       dataFrames: DataFrame*): Array[RDD[XGBLabeledPoint]] = {
     val selectedColumns = group.map(groupCol => Seq(labelCol.cast(FloatType),
       featuresCol,
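The new helpers above (featureValueOfDenseVector/featureValueOfSparseVector, calculatePartitionKey) derive the partition key purely from row content, so the HashPartitioner shuffle in repartitionRDDs places each record on the same partition no matter how the input DataFrame happened to be partitioned beforehand. A standalone sketch of that idea, independent of the patch; the dataset, key function, and partition counts are made-up illustrations:

import org.apache.spark.HashPartitioner
import org.apache.spark.sql.SparkSession

object StablePartitioningSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[4]").appName("stable-partitioning").getOrCreate()
    val sc = spark.sparkContext
    val records = (1 to 1000).map(i => s"row-$i")

    // Build the same logical dataset through a varying number of shuffles, then key every
    // record by a value derived only from its content and hash-partition on that key.
    def contentKeyed(shuffles: Int) = {
      var rdd = sc.parallelize(records, 7)            // arbitrary initial partitioning
      for (_ <- 0 until shuffles) {
        rdd = rdd.repartition(4)                      // non-deterministic placement
      }
      rdd.map(r => (math.abs(r.hashCode % 4), r)).partitionBy(new HashPartitioner(4))
    }

    // Per-partition contents are identical no matter how many shuffles happened first.
    val a = contentKeyed(1).glom().collect().map(_.toSet)
    val b = contentKeyed(3).glom().collect().map(_.toSet)
    assert(a.zip(b).forall { case (x, y) => x == y })
    spark.stop()
  }
}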
@@ -88,22 +158,26 @@ object DataUtils extends Serializable {
       featuresCol,
       weight.cast(FloatType),
       baseMargin.cast(FloatType)))
-    dataFrames.toArray.map {
+    val arrayOfRDDs = dataFrames.toArray.map {
       df => df.select(selectedColumns: _*).rdd.map {
-        case Row(label: Float, features: Vector, weight: Float, group: Int, baseMargin: Float) =>
+        case row @ Row(label: Float, features: Vector, weight: Float, group: Int,
+          baseMargin: Float) =>
           val (indices, values) = features match {
             case v: SparseVector => (v.indices, v.values.map(_.toFloat))
             case v: DenseVector => (null, v.values.map(_.toFloat))
           }
-          XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
-        case Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
+          val xgbLp = XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
+          attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
+        case row @ Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
           val (indices, values) = features match {
             case v: SparseVector => (v.indices, v.values.map(_.toFloat))
             case v: DenseVector => (null, v.values.map(_.toFloat))
           }
-          XGBLabeledPoint(label, indices, values, weight, baseMargin = baseMargin)
+          val xgbLp = XGBLabeledPoint(label, indices, values, weight, baseMargin = baseMargin)
+          attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
       }
     }
+    repartitionRDDs(deterministicPartition, numWorkers, arrayOfRDDs)
   }
 
 }
@@ -475,8 +475,7 @@ object XGBoost extends Serializable {
       Left(cacheData(ifCacheDataBoolean, repartitionedData).
         asInstanceOf[RDD[Array[XGBLabeledPoint]]])
     } else {
-      val repartitionedData = repartitionForTraining(trainingData, nWorkers)
-      Right(cacheData(ifCacheDataBoolean, repartitionedData).asInstanceOf[RDD[XGBLabeledPoint]])
+      Right(cacheData(ifCacheDataBoolean, trainingData).asInstanceOf[RDD[XGBLabeledPoint]])
     }
   }
 
@@ -568,15 +567,6 @@ object XGBoost extends Serializable {
     }
   }
 
-  private[spark] def repartitionForTraining(trainingData: RDD[XGBLabeledPoint], nWorkers: Int) = {
-    if (trainingData.getNumPartitions != nWorkers) {
-      logger.info(s"repartitioning training set to $nWorkers partitions")
-      trainingData.repartition(nWorkers)
-    } else {
-      trainingData
-    }
-  }
-
   private def aggByGroupInfo(trainingData: RDD[XGBLabeledPoint]) = {
     val normalGroups: RDD[Array[XGBLabeledPoint]] = trainingData.mapPartitions(
       // LabeledPointGroupIterator returns (Boolean, Array[XGBLabeledPoint])
@@ -37,10 +37,6 @@ import org.json4s.DefaultFormats
 import scala.collection.JavaConverters._
 import scala.collection.{AbstractIterator, Iterator, mutable}
 
-private[spark] trait XGBoostClassifierParams extends GeneralParams with LearningTaskParams
-  with BoosterParams with HasWeightCol with HasBaseMarginCol with HasNumClass with ParamMapFuncs
-  with HasLeafPredictionCol with HasContribPredictionCol with NonParamVariables
-
 class XGBoostClassifier (
   override val uid: String,
   private val xgboostParams: Map[String, Any])
@@ -182,11 +178,11 @@ class XGBoostClassifier (
 
     val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
       col($(labelCol)), col($(featuresCol)), weight, baseMargin,
-      None, dataset.asInstanceOf[DataFrame]).head
+      None, $(numWorkers), needDeterministicRepartitioning, dataset.asInstanceOf[DataFrame]).head
     val evalRDDMap = getEvalSets(xgboostParams).map {
       case (name, dataFrame) => (name,
         DataUtils.convertDataFrameToXGBLabeledPointRDDs(col($(labelCol)), col($(featuresCol)),
-          weight, baseMargin, None, dataFrame).head)
+          weight, baseMargin, None, $(numWorkers), needDeterministicRepartitioning, dataFrame).head)
     }
     transformSchema(dataset.schema, logging = true)
     val derivedXGBParamMap = MLlib2XGBoostParams
@@ -0,0 +1,37 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import ml.dmlc.xgboost4j.scala.spark.params._
+
+import org.apache.spark.ml.param.shared.HasWeightCol
+
+private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
+  with BoosterParams with ParamMapFuncs with NonParamVariables {
+
+  def needDeterministicRepartitioning: Boolean = {
+    getCheckpointPath.nonEmpty && getCheckpointInterval > 0
+  }
+}
+
+private[spark] trait XGBoostClassifierParams extends HasWeightCol with HasBaseMarginCol
+  with HasNumClass with HasLeafPredictionCol with HasContribPredictionCol
+  with XGBoostEstimatorCommon
+
+private[spark] trait XGBoostRegressorParams extends HasBaseMarginCol with HasWeightCol
+  with HasGroupCol with HasLeafPredictionCol with HasContribPredictionCol
+  with XGBoostEstimatorCommon
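A hedged illustration of the new predicate in this file; it would have to run inside the ml.dmlc.xgboost4j.scala.spark package (as DeterministicPartitioningSuite below does), and the behavior of the second assertion relies on assumed defaults for the unset checkpoint parameters (empty path, non-positive interval):

// Both checkpoint settings present: the new code path repartitions deterministically.
val withCheckpoint = new XGBoostClassifier(Map(
  "num_workers" -> 2,
  "checkpoint_path" -> "/tmp/ckpt",      // hypothetical path
  "checkpoint_interval" -> 2))
assert(withCheckpoint.needDeterministicRepartitioning)

// No checkpoint configured: partitioning keeps the previous repartition(numWorkers) behavior.
// (Assumes the default checkpoint_path is empty and checkpoint_interval is not positive.)
val withoutCheckpoint = new XGBoostClassifier(Map("num_workers" -> 2))
assert(!withoutCheckpoint.needDeterministicRepartitioning)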
@@ -41,10 +41,6 @@ import scala.collection.mutable.ListBuffer
 
 import org.apache.spark.broadcast.Broadcast
 
-private[spark] trait XGBoostRegressorParams extends GeneralParams with BoosterParams
-  with LearningTaskParams with HasBaseMarginCol with HasWeightCol with HasGroupCol
-  with ParamMapFuncs with HasLeafPredictionCol with HasContribPredictionCol with NonParamVariables
-
 class XGBoostRegressor (
   override val uid: String,
   private val xgboostParams: Map[String, Any])
@@ -178,11 +174,12 @@ class XGBoostRegressor (
     val group = if (!isDefined(groupCol) || $(groupCol).isEmpty) lit(-1) else col($(groupCol))
     val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
       col($(labelCol)), col($(featuresCol)), weight, baseMargin, Some(group),
-      dataset.asInstanceOf[DataFrame]).head
+      $(numWorkers), needDeterministicRepartitioning, dataset.asInstanceOf[DataFrame]).head
     val evalRDDMap = getEvalSets(xgboostParams).map {
       case (name, dataFrame) => (name,
         DataUtils.convertDataFrameToXGBLabeledPointRDDs(col($(labelCol)), col($(featuresCol)),
-          weight, baseMargin, Some(group), dataFrame).head)
+          weight, baseMargin, Some(group), $(numWorkers), needDeterministicRepartitioning,
+          dataFrame).head)
     }
     transformSchema(dataset.schema, logging = true)
     val derivedXGBParamMap = MLlib2XGBoostParams
@@ -0,0 +1,82 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.functions._
+
+class DeterministicPartitioningSuite extends FunSuite with TmpFolderPerSuite with PerTest {
+
+  test("perform deterministic partitioning when checkpointInternal and" +
+    " checkpointPath is set (Classifier)") {
+    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
+    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
+      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
+      "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
+    val xgbClassifier = new XGBoostClassifier(paramMap)
+    assert(xgbClassifier.needDeterministicRepartitioning)
+  }
+
+  test("perform deterministic partitioning when checkpointInternal and" +
+    " checkpointPath is set (Regressor)") {
+    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
+    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
+      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
+      "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
+    val xgbRegressor = new XGBoostRegressor(paramMap)
+    assert(xgbRegressor.needDeterministicRepartitioning)
+  }
+
+  test("deterministic partitioning takes effect with various parts of data") {
+    val trainingDF = buildDataFrame(Classification.train)
+    // the test idea is that, we apply a chain of repartitions over trainingDFs but they
+    // have to produce the identical RDDs
+    val transformedDFs = (1 until 6).map(shuffleCount => {
+      var resultDF = trainingDF
+      for (i <- 0 until shuffleCount) {
+        resultDF = resultDF.repartition(numWorkers)
+      }
+      resultDF
+    })
+    val transformedRDDs = transformedDFs.map(df => DataUtils.convertDataFrameToXGBLabeledPointRDDs(
+      col("label"),
+      col("features"),
+      lit(1.0),
+      lit(Float.NaN),
+      None,
+      numWorkers,
+      deterministicPartition = true,
+      df
+    ).head)
+    val resultsMaps = transformedRDDs.map(rdd => rdd.mapPartitionsWithIndex {
+      case (partitionIndex, labelPoints) =>
+        Iterator((partitionIndex, labelPoints.toList))
+    }.collect().toMap)
+    resultsMaps.foldLeft(resultsMaps.head) { case (map1, map2) =>
+      assert(map1.keys.toSet === map2.keys.toSet)
+      for ((parIdx, labeledPoints) <- map1) {
+        val sortedA = labeledPoints.sortBy(_.hashCode())
+        val sortedB = map2(parIdx).sortBy(_.hashCode())
+        assert(sortedA.length === sortedB.length)
+        assert(sortedA.indices.forall(idx =>
+          sortedA(idx).values.toSet === sortedB(idx).values.toSet))
+      }
+      map2
+    }
+  }
+}