[jvm-packages] move the dmatrix building into rabit context (#7823)

This fixes the QuantileDeviceDMatrix in a distributed environment.
This commit is contained in:
Bobby Wang
2022-04-23 00:06:50 +08:00
committed by GitHub
parent f0f76259c9
commit c45665a55a
5 changed files with 87 additions and 55 deletions

View File

@@ -56,18 +56,20 @@ class GpuPreXGBoost extends PreXGBoostProvider {
}
/**
* Convert the Dataset[_] to RDD[Watches] which will be fed to XGBoost
* Convert the Dataset[_] to RDD[() => Watches] which will be fed to XGBoost
*
* @param estimator [[XGBoostClassifier]] or [[XGBoostRegressor]]
* @param dataset the training data
* @param params all user defined and defaulted params
* @return [[XGBoostExecutionParams]] => (RDD[[Watches]], Option[ RDD[_] ])
* RDD[Watches] will be used as the training input
* @return [[XGBoostExecutionParams]] => (Boolean, RDD[[() => Watches]], Option[ RDD[_] ])
* Boolean indicates whether the DMatrix is built in the rabit context
* RDD[() => Watches] will be used as the training input
* Option[ RDD[_] ] is the optional cached RDD
*/
override def buildDatasetToRDD(estimator: Estimator[_],
dataset: Dataset[_],
params: Map[String, Any]): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]) = {
params: Map[String, Any]):
XGBoostExecutionParams => (Boolean, RDD[() => Watches], Option[RDD[_]]) = {
GpuPreXGBoost.buildDatasetToRDD(estimator, dataset, params)
}
@@ -116,19 +118,21 @@ object GpuPreXGBoost extends PreXGBoostProvider {
}
/**
* Convert the Dataset[_] to RDD[Watches] which will be fed to XGBoost
* Convert the Dataset[_] to RDD[() => Watches] which will be fed to XGBoost
*
* @param estimator supports XGBoostClassifier and XGBoostRegressor
* @param dataset the training data
* @param params all user defined and defaulted params
* @return [[XGBoostExecutionParams]] => (RDD[[Watches]], Option[ RDD[_] ])
* RDD[Watches] will be used as the training input
* @return [[XGBoostExecutionParams]] => (Boolean, RDD[[() => Watches]], Option[ RDD[_] ])
* Boolean indicates whether the DMatrix is built in the rabit context
* RDD[() => Watches] will be used as the training input to build DMatrix
* Option[ RDD[_] ] is the optional cached RDD
*/
override def buildDatasetToRDD(
estimator: Estimator[_],
dataset: Dataset[_],
params: Map[String, Any]): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]) = {
params: Map[String, Any]):
XGBoostExecutionParams => (Boolean, RDD[() => Watches], Option[RDD[_]]) = {
val (Seq(labelName, weightName, marginName), feturesCols, groupName, evalSets) =
estimator match {
@@ -166,7 +170,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
xgbExecParams: XGBoostExecutionParams =>
val dataMap = prepareInputData(trainingData, evalDataMap, xgbExecParams.numWorkers,
xgbExecParams.cacheTrainingSet)
(buildRDDWatches(dataMap, xgbExecParams, evalDataMap.isEmpty), None)
(true, buildRDDWatches(dataMap, xgbExecParams, evalDataMap.isEmpty), None)
}
/**
@@ -448,7 +452,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
private def buildRDDWatches(
dataMap: Map[String, ColumnDataBatch],
xgbExeParams: XGBoostExecutionParams,
noEvalSet: Boolean): RDD[Watches] = {
noEvalSet: Boolean): RDD[() => Watches] = {
val sc = dataMap(TRAIN_NAME).rawDF.sparkSession.sparkContext
val maxBin = xgbExeParams.toMap.getOrElse("max_bin", 256).asInstanceOf[Int]
@@ -459,7 +463,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
GpuUtils.toColumnarRdd(dataMap(TRAIN_NAME).rawDF).mapPartitions({
iter =>
val iterColBatch = iter.map(table => new GpuColumnBatch(table, null))
Iterator(buildWatches(
Iterator(() => buildWatches(
PreXGBoost.getCacheDirName(xgbExeParams.useExternalMemory), xgbExeParams.missing,
colIndicesForTrain, iterColBatch, maxBin))
})
@@ -469,7 +473,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
val nameAndColIndices = dataMap.map(nc => (nc._1, nc._2.colIndices))
coPartitionForGpu(dataMap, sc, xgbExeParams.numWorkers).mapPartitions {
nameAndColumnBatchIter =>
Iterator(buildWatchesWithEval(
Iterator(() => buildWatchesWithEval(
PreXGBoost.getCacheDirName(xgbExeParams.useExternalMemory), xgbExeParams.missing,
nameAndColIndices, nameAndColumnBatchIter, maxBin))
}