[jvm-packages] move the dmatrix building into rabit context (#7823)
This fixes the QuantileDeviceDMatrix in a distributed environment.
This commit is contained in:
@@ -56,18 +56,20 @@ class GpuPreXGBoost extends PreXGBoostProvider {
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert the Dataset[_] to RDD[Watches] which will be fed to XGBoost
|
||||
* Convert the Dataset[_] to RDD[() => Watches] which will be fed to XGBoost
|
||||
*
|
||||
* @param estimator [[XGBoostClassifier]] or [[XGBoostRegressor]]
|
||||
* @param dataset the training data
|
||||
* @param params all user defined and defaulted params
|
||||
* @return [[XGBoostExecutionParams]] => (RDD[[Watches]], Option[ RDD[_] ])
|
||||
* RDD[Watches] will be used as the training input
|
||||
* @return [[XGBoostExecutionParams]] => (Boolean, RDD[[() => Watches]], Option[ RDD[_] ])
|
||||
* Boolean: whether the DMatrix is built in the rabit context
|
||||
* RDD[() => Watches] will be used as the training input
|
||||
* Option[ RDD[_] ] is the optional cached RDD
|
||||
*/
|
||||
override def buildDatasetToRDD(estimator: Estimator[_],
|
||||
dataset: Dataset[_],
|
||||
params: Map[String, Any]): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]) = {
|
||||
params: Map[String, Any]):
|
||||
XGBoostExecutionParams => (Boolean, RDD[() => Watches], Option[RDD[_]]) = {
|
||||
GpuPreXGBoost.buildDatasetToRDD(estimator, dataset, params)
|
||||
}
|
||||
|
||||
@@ -116,19 +118,21 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert the Dataset[_] to RDD[Watches] which will be fed to XGBoost
|
||||
* Convert the Dataset[_] to RDD[() => Watches] which will be fed to XGBoost
|
||||
*
|
||||
* @param estimator supports XGBoostClassifier and XGBoostRegressor
|
||||
* @param dataset the training data
|
||||
* @param params all user defined and defaulted params
|
||||
* @return [[XGBoostExecutionParams]] => (RDD[[Watches]], Option[ RDD[_] ])
|
||||
* RDD[Watches] will be used as the training input
|
||||
* @return [[XGBoostExecutionParams]] => (Boolean, RDD[[() => Watches]], Option[ RDD[_] ])
|
||||
* Boolean: whether the DMatrix is built in the rabit context
|
||||
* RDD[() => Watches] will be used as the training input to build DMatrix
|
||||
* Option[ RDD[_] ] is the optional cached RDD
|
||||
*/
|
||||
override def buildDatasetToRDD(
|
||||
estimator: Estimator[_],
|
||||
dataset: Dataset[_],
|
||||
params: Map[String, Any]): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]) = {
|
||||
params: Map[String, Any]):
|
||||
XGBoostExecutionParams => (Boolean, RDD[() => Watches], Option[RDD[_]]) = {
|
||||
|
||||
val (Seq(labelName, weightName, marginName), feturesCols, groupName, evalSets) =
|
||||
estimator match {
|
||||
@@ -166,7 +170,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
||||
xgbExecParams: XGBoostExecutionParams =>
|
||||
val dataMap = prepareInputData(trainingData, evalDataMap, xgbExecParams.numWorkers,
|
||||
xgbExecParams.cacheTrainingSet)
|
||||
(buildRDDWatches(dataMap, xgbExecParams, evalDataMap.isEmpty), None)
|
||||
(true, buildRDDWatches(dataMap, xgbExecParams, evalDataMap.isEmpty), None)
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -448,7 +452,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
||||
private def buildRDDWatches(
|
||||
dataMap: Map[String, ColumnDataBatch],
|
||||
xgbExeParams: XGBoostExecutionParams,
|
||||
noEvalSet: Boolean): RDD[Watches] = {
|
||||
noEvalSet: Boolean): RDD[() => Watches] = {
|
||||
|
||||
val sc = dataMap(TRAIN_NAME).rawDF.sparkSession.sparkContext
|
||||
val maxBin = xgbExeParams.toMap.getOrElse("max_bin", 256).asInstanceOf[Int]
|
||||
@@ -459,7 +463,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
||||
GpuUtils.toColumnarRdd(dataMap(TRAIN_NAME).rawDF).mapPartitions({
|
||||
iter =>
|
||||
val iterColBatch = iter.map(table => new GpuColumnBatch(table, null))
|
||||
Iterator(buildWatches(
|
||||
Iterator(() => buildWatches(
|
||||
PreXGBoost.getCacheDirName(xgbExeParams.useExternalMemory), xgbExeParams.missing,
|
||||
colIndicesForTrain, iterColBatch, maxBin))
|
||||
})
|
||||
@@ -469,7 +473,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
||||
val nameAndColIndices = dataMap.map(nc => (nc._1, nc._2.colIndices))
|
||||
coPartitionForGpu(dataMap, sc, xgbExeParams.numWorkers).mapPartitions {
|
||||
nameAndColumnBatchIter =>
|
||||
Iterator(buildWatchesWithEval(
|
||||
Iterator(() => buildWatchesWithEval(
|
||||
PreXGBoost.getCacheDirName(xgbExeParams.useExternalMemory), xgbExeParams.missing,
|
||||
nameAndColIndices, nameAndColumnBatchIter, maxBin))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user