[jvm-packages] move the dmatrix building into rabit context (#7823)
This fixes the QuantileDeviceDMatrix in distributed environment.
This commit is contained in:
parent
f0f76259c9
commit
c45665a55a
@ -56,18 +56,20 @@ class GpuPreXGBoost extends PreXGBoostProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert the Dataset[_] to RDD[Watches] which will be fed to XGBoost
|
* Convert the Dataset[_] to RDD[() => Watches] which will be fed to XGBoost
|
||||||
*
|
*
|
||||||
* @param estimator [[XGBoostClassifier]] or [[XGBoostRegressor]]
|
* @param estimator [[XGBoostClassifier]] or [[XGBoostRegressor]]
|
||||||
* @param dataset the training data
|
* @param dataset the training data
|
||||||
* @param params all user defined and defaulted params
|
* @param params all user defined and defaulted params
|
||||||
* @return [[XGBoostExecutionParams]] => (RDD[[Watches]], Option[ RDD[_] ])
|
* @return [[XGBoostExecutionParams]] => (Boolean, RDD[[() => Watches]], Option[ RDD[_] ])
|
||||||
* RDD[Watches] will be used as the training input
|
* Boolean if building DMatrix in rabit context
|
||||||
|
* RDD[() => Watches] will be used as the training input
|
||||||
* Option[ RDD[_] ] is the optional cached RDD
|
* Option[ RDD[_] ] is the optional cached RDD
|
||||||
*/
|
*/
|
||||||
override def buildDatasetToRDD(estimator: Estimator[_],
|
override def buildDatasetToRDD(estimator: Estimator[_],
|
||||||
dataset: Dataset[_],
|
dataset: Dataset[_],
|
||||||
params: Map[String, Any]): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]) = {
|
params: Map[String, Any]):
|
||||||
|
XGBoostExecutionParams => (Boolean, RDD[() => Watches], Option[RDD[_]]) = {
|
||||||
GpuPreXGBoost.buildDatasetToRDD(estimator, dataset, params)
|
GpuPreXGBoost.buildDatasetToRDD(estimator, dataset, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -116,19 +118,21 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert the Dataset[_] to RDD[Watches] which will be fed to XGBoost
|
* Convert the Dataset[_] to RDD[() => Watches] which will be fed to XGBoost
|
||||||
*
|
*
|
||||||
* @param estimator supports XGBoostClassifier and XGBoostRegressor
|
* @param estimator supports XGBoostClassifier and XGBoostRegressor
|
||||||
* @param dataset the training data
|
* @param dataset the training data
|
||||||
* @param params all user defined and defaulted params
|
* @param params all user defined and defaulted params
|
||||||
* @return [[XGBoostExecutionParams]] => (RDD[[Watches]], Option[ RDD[_] ])
|
* @return [[XGBoostExecutionParams]] => (Boolean, RDD[[() => Watches]], Option[ RDD[_] ])
|
||||||
* RDD[Watches] will be used as the training input
|
* Boolean if building DMatrix in rabit context
|
||||||
|
* RDD[() => Watches] will be used as the training input to build DMatrix
|
||||||
* Option[ RDD[_] ] is the optional cached RDD
|
* Option[ RDD[_] ] is the optional cached RDD
|
||||||
*/
|
*/
|
||||||
override def buildDatasetToRDD(
|
override def buildDatasetToRDD(
|
||||||
estimator: Estimator[_],
|
estimator: Estimator[_],
|
||||||
dataset: Dataset[_],
|
dataset: Dataset[_],
|
||||||
params: Map[String, Any]): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]) = {
|
params: Map[String, Any]):
|
||||||
|
XGBoostExecutionParams => (Boolean, RDD[() => Watches], Option[RDD[_]]) = {
|
||||||
|
|
||||||
val (Seq(labelName, weightName, marginName), feturesCols, groupName, evalSets) =
|
val (Seq(labelName, weightName, marginName), feturesCols, groupName, evalSets) =
|
||||||
estimator match {
|
estimator match {
|
||||||
@ -166,7 +170,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
|||||||
xgbExecParams: XGBoostExecutionParams =>
|
xgbExecParams: XGBoostExecutionParams =>
|
||||||
val dataMap = prepareInputData(trainingData, evalDataMap, xgbExecParams.numWorkers,
|
val dataMap = prepareInputData(trainingData, evalDataMap, xgbExecParams.numWorkers,
|
||||||
xgbExecParams.cacheTrainingSet)
|
xgbExecParams.cacheTrainingSet)
|
||||||
(buildRDDWatches(dataMap, xgbExecParams, evalDataMap.isEmpty), None)
|
(true, buildRDDWatches(dataMap, xgbExecParams, evalDataMap.isEmpty), None)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -448,7 +452,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
|||||||
private def buildRDDWatches(
|
private def buildRDDWatches(
|
||||||
dataMap: Map[String, ColumnDataBatch],
|
dataMap: Map[String, ColumnDataBatch],
|
||||||
xgbExeParams: XGBoostExecutionParams,
|
xgbExeParams: XGBoostExecutionParams,
|
||||||
noEvalSet: Boolean): RDD[Watches] = {
|
noEvalSet: Boolean): RDD[() => Watches] = {
|
||||||
|
|
||||||
val sc = dataMap(TRAIN_NAME).rawDF.sparkSession.sparkContext
|
val sc = dataMap(TRAIN_NAME).rawDF.sparkSession.sparkContext
|
||||||
val maxBin = xgbExeParams.toMap.getOrElse("max_bin", 256).asInstanceOf[Int]
|
val maxBin = xgbExeParams.toMap.getOrElse("max_bin", 256).asInstanceOf[Int]
|
||||||
@ -459,7 +463,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
|||||||
GpuUtils.toColumnarRdd(dataMap(TRAIN_NAME).rawDF).mapPartitions({
|
GpuUtils.toColumnarRdd(dataMap(TRAIN_NAME).rawDF).mapPartitions({
|
||||||
iter =>
|
iter =>
|
||||||
val iterColBatch = iter.map(table => new GpuColumnBatch(table, null))
|
val iterColBatch = iter.map(table => new GpuColumnBatch(table, null))
|
||||||
Iterator(buildWatches(
|
Iterator(() => buildWatches(
|
||||||
PreXGBoost.getCacheDirName(xgbExeParams.useExternalMemory), xgbExeParams.missing,
|
PreXGBoost.getCacheDirName(xgbExeParams.useExternalMemory), xgbExeParams.missing,
|
||||||
colIndicesForTrain, iterColBatch, maxBin))
|
colIndicesForTrain, iterColBatch, maxBin))
|
||||||
})
|
})
|
||||||
@ -469,7 +473,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
|||||||
val nameAndColIndices = dataMap.map(nc => (nc._1, nc._2.colIndices))
|
val nameAndColIndices = dataMap.map(nc => (nc._1, nc._2.colIndices))
|
||||||
coPartitionForGpu(dataMap, sc, xgbExeParams.numWorkers).mapPartitions {
|
coPartitionForGpu(dataMap, sc, xgbExeParams.numWorkers).mapPartitions {
|
||||||
nameAndColumnBatchIter =>
|
nameAndColumnBatchIter =>
|
||||||
Iterator(buildWatchesWithEval(
|
Iterator(() => buildWatchesWithEval(
|
||||||
PreXGBoost.getCacheDirName(xgbExeParams.useExternalMemory), xgbExeParams.missing,
|
PreXGBoost.getCacheDirName(xgbExeParams.useExternalMemory), xgbExeParams.missing,
|
||||||
nameAndColIndices, nameAndColumnBatchIter, maxBin))
|
nameAndColIndices, nameAndColumnBatchIter, maxBin))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -96,19 +96,21 @@ object PreXGBoost extends PreXGBoostProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert the Dataset[_] to RDD[Watches] which will be fed to XGBoost
|
* Convert the Dataset[_] to RDD[() => Watches] which will be fed to XGBoost
|
||||||
*
|
*
|
||||||
* @param estimator supports XGBoostClassifier and XGBoostRegressor
|
* @param estimator supports XGBoostClassifier and XGBoostRegressor
|
||||||
* @param dataset the training data
|
* @param dataset the training data
|
||||||
* @param params all user defined and defaulted params
|
* @param params all user defined and defaulted params
|
||||||
* @return [[XGBoostExecutionParams]] => (RDD[[Watches]], Option[ RDD[_] ])
|
* @return [[XGBoostExecutionParams]] => (Boolean, RDD[[() => Watches]], Option[ RDD[_] ])
|
||||||
* RDD[Watches] will be used as the training input
|
* Boolean if building DMatrix in rabit context
|
||||||
|
* RDD[() => Watches] will be used as the training input
|
||||||
* Option[RDD[_]\] is the optional cached RDD
|
* Option[RDD[_]\] is the optional cached RDD
|
||||||
*/
|
*/
|
||||||
override def buildDatasetToRDD(
|
override def buildDatasetToRDD(
|
||||||
estimator: Estimator[_],
|
estimator: Estimator[_],
|
||||||
dataset: Dataset[_],
|
dataset: Dataset[_],
|
||||||
params: Map[String, Any]): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]) = {
|
params: Map[String, Any]): XGBoostExecutionParams =>
|
||||||
|
(Boolean, RDD[() => Watches], Option[RDD[_]]) = {
|
||||||
|
|
||||||
if (optionProvider.isDefined && optionProvider.get.providerEnabled(Some(dataset))) {
|
if (optionProvider.isDefined && optionProvider.get.providerEnabled(Some(dataset))) {
|
||||||
return optionProvider.get.buildDatasetToRDD(estimator, dataset, params)
|
return optionProvider.get.buildDatasetToRDD(estimator, dataset, params)
|
||||||
@ -170,12 +172,12 @@ object PreXGBoost extends PreXGBoostProvider {
|
|||||||
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
|
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
|
||||||
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
|
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
|
||||||
} else None
|
} else None
|
||||||
(trainForRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
|
(false, trainForRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
|
||||||
case Right(trainingData) =>
|
case Right(trainingData) =>
|
||||||
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
|
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
|
||||||
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
|
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
|
||||||
} else None
|
} else None
|
||||||
(trainForNonRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
|
(false, trainForNonRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -311,17 +313,18 @@ object PreXGBoost extends PreXGBoostProvider {
|
|||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Converting the RDD[XGBLabeledPoint] to the function to build RDD[Watches]
|
* Converting the RDD[XGBLabeledPoint] to the function to build RDD[() => Watches]
|
||||||
*
|
*
|
||||||
* @param trainingSet the input training RDD[XGBLabeledPoint]
|
* @param trainingSet the input training RDD[XGBLabeledPoint]
|
||||||
* @param evalRDDMap the eval set
|
* @param evalRDDMap the eval set
|
||||||
* @param hasGroup if has group
|
* @param hasGroup if has group
|
||||||
* @return function to build (RDD[Watches], the cached RDD)
|
* @return function to build (RDD[() => Watches], the cached RDD)
|
||||||
*/
|
*/
|
||||||
private[spark] def buildRDDLabeledPointToRDDWatches(
|
private[spark] def buildRDDLabeledPointToRDDWatches(
|
||||||
trainingSet: RDD[XGBLabeledPoint],
|
trainingSet: RDD[XGBLabeledPoint],
|
||||||
evalRDDMap: Map[String, RDD[XGBLabeledPoint]] = Map(),
|
evalRDDMap: Map[String, RDD[XGBLabeledPoint]] = Map(),
|
||||||
hasGroup: Boolean = false): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]) = {
|
hasGroup: Boolean = false):
|
||||||
|
XGBoostExecutionParams => (Boolean, RDD[() => Watches], Option[RDD[_]]) = {
|
||||||
|
|
||||||
xgbExecParams: XGBoostExecutionParams =>
|
xgbExecParams: XGBoostExecutionParams =>
|
||||||
composeInputData(trainingSet, hasGroup, xgbExecParams.numWorkers) match {
|
composeInputData(trainingSet, hasGroup, xgbExecParams.numWorkers) match {
|
||||||
@ -329,12 +332,12 @@ object PreXGBoost extends PreXGBoostProvider {
|
|||||||
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
|
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
|
||||||
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
|
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
|
||||||
} else None
|
} else None
|
||||||
(trainForRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
|
(false, trainForRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
|
||||||
case Right(trainingData) =>
|
case Right(trainingData) =>
|
||||||
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
|
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
|
||||||
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
|
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
|
||||||
} else None
|
} else None
|
||||||
(trainForNonRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
|
(false, trainForNonRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -374,34 +377,34 @@ object PreXGBoost extends PreXGBoostProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Build RDD[Watches] for Ranking
|
* Build RDD[() => Watches] for Ranking
|
||||||
* @param trainingData the training data RDD
|
* @param trainingData the training data RDD
|
||||||
* @param xgbExecutionParams xgboost execution params
|
* @param xgbExecutionParams xgboost execution params
|
||||||
* @param evalSetsMap the eval RDD
|
* @param evalSetsMap the eval RDD
|
||||||
* @return RDD[Watches]
|
* @return RDD[() => Watches]
|
||||||
*/
|
*/
|
||||||
private def trainForRanking(
|
private def trainForRanking(
|
||||||
trainingData: RDD[Array[XGBLabeledPoint]],
|
trainingData: RDD[Array[XGBLabeledPoint]],
|
||||||
xgbExecutionParam: XGBoostExecutionParams,
|
xgbExecutionParam: XGBoostExecutionParams,
|
||||||
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[Watches] = {
|
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[() => Watches] = {
|
||||||
if (evalSetsMap.isEmpty) {
|
if (evalSetsMap.isEmpty) {
|
||||||
trainingData.mapPartitions(labeledPointGroups => {
|
trainingData.mapPartitions(labeledPointGroups => {
|
||||||
val watches = Watches.buildWatchesWithGroup(xgbExecutionParam,
|
val buildWatches = () => Watches.buildWatchesWithGroup(xgbExecutionParam,
|
||||||
DataUtils.processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing,
|
DataUtils.processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing,
|
||||||
xgbExecutionParam.allowNonZeroForMissing),
|
xgbExecutionParam.allowNonZeroForMissing),
|
||||||
getCacheDirName(xgbExecutionParam.useExternalMemory))
|
getCacheDirName(xgbExecutionParam.useExternalMemory))
|
||||||
Iterator.single(watches)
|
Iterator.single(buildWatches)
|
||||||
}).cache()
|
}).cache()
|
||||||
} else {
|
} else {
|
||||||
coPartitionGroupSets(trainingData, evalSetsMap, xgbExecutionParam.numWorkers).mapPartitions(
|
coPartitionGroupSets(trainingData, evalSetsMap, xgbExecutionParam.numWorkers).mapPartitions(
|
||||||
labeledPointGroupSets => {
|
labeledPointGroupSets => {
|
||||||
val watches = Watches.buildWatchesWithGroup(
|
val buildWatches = () => Watches.buildWatchesWithGroup(
|
||||||
labeledPointGroupSets.map {
|
labeledPointGroupSets.map {
|
||||||
case (name, iter) => (name, DataUtils.processMissingValuesWithGroup(iter,
|
case (name, iter) => (name, DataUtils.processMissingValuesWithGroup(iter,
|
||||||
xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing))
|
xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing))
|
||||||
},
|
},
|
||||||
getCacheDirName(xgbExecutionParam.useExternalMemory))
|
getCacheDirName(xgbExecutionParam.useExternalMemory))
|
||||||
Iterator.single(watches)
|
Iterator.single(buildWatches)
|
||||||
}).cache()
|
}).cache()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -462,35 +465,35 @@ object PreXGBoost extends PreXGBoostProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Build RDD[Watches] for Non-Ranking
|
* Build RDD[() => Watches] for Non-Ranking
|
||||||
* @param trainingData the training data RDD
|
* @param trainingData the training data RDD
|
||||||
* @param xgbExecutionParams xgboost execution params
|
* @param xgbExecutionParams xgboost execution params
|
||||||
* @param evalSetsMap the eval RDD
|
* @param evalSetsMap the eval RDD
|
||||||
* @return RDD[Watches]
|
* @return RDD[() => Watches]
|
||||||
*/
|
*/
|
||||||
private def trainForNonRanking(
|
private def trainForNonRanking(
|
||||||
trainingData: RDD[XGBLabeledPoint],
|
trainingData: RDD[XGBLabeledPoint],
|
||||||
xgbExecutionParams: XGBoostExecutionParams,
|
xgbExecutionParams: XGBoostExecutionParams,
|
||||||
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[Watches] = {
|
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[() => Watches] = {
|
||||||
if (evalSetsMap.isEmpty) {
|
if (evalSetsMap.isEmpty) {
|
||||||
trainingData.mapPartitions { labeledPoints => {
|
trainingData.mapPartitions { labeledPoints => {
|
||||||
val watches = Watches.buildWatches(xgbExecutionParams,
|
val buildWatches = () => Watches.buildWatches(xgbExecutionParams,
|
||||||
DataUtils.processMissingValues(labeledPoints, xgbExecutionParams.missing,
|
DataUtils.processMissingValues(labeledPoints, xgbExecutionParams.missing,
|
||||||
xgbExecutionParams.allowNonZeroForMissing),
|
xgbExecutionParams.allowNonZeroForMissing),
|
||||||
getCacheDirName(xgbExecutionParams.useExternalMemory))
|
getCacheDirName(xgbExecutionParams.useExternalMemory))
|
||||||
Iterator.single(watches)
|
Iterator.single(buildWatches)
|
||||||
}}.cache()
|
}}.cache()
|
||||||
} else {
|
} else {
|
||||||
coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
|
coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
|
||||||
mapPartitions {
|
mapPartitions {
|
||||||
nameAndLabeledPointSets =>
|
nameAndLabeledPointSets =>
|
||||||
val watches = Watches.buildWatches(
|
val buildWatches = () => Watches.buildWatches(
|
||||||
nameAndLabeledPointSets.map {
|
nameAndLabeledPointSets.map {
|
||||||
case (name, iter) => (name, DataUtils.processMissingValues(iter,
|
case (name, iter) => (name, DataUtils.processMissingValues(iter,
|
||||||
xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing))
|
xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing))
|
||||||
},
|
},
|
||||||
getCacheDirName(xgbExecutionParams.useExternalMemory))
|
getCacheDirName(xgbExecutionParams.useExternalMemory))
|
||||||
Iterator.single(watches)
|
Iterator.single(buildWatches)
|
||||||
}.cache()
|
}.cache()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2021 by Contributors
|
Copyright (c) 2021-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -45,19 +45,21 @@ private[scala] trait PreXGBoostProvider {
|
|||||||
def transformSchema(xgboostEstimator: XGBoostEstimatorCommon, schema: StructType): StructType
|
def transformSchema(xgboostEstimator: XGBoostEstimatorCommon, schema: StructType): StructType
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert the Dataset[_] to RDD[Watches] which will be fed to XGBoost
|
* Convert the Dataset[_] to RDD[() => Watches] which will be fed to XGBoost
|
||||||
*
|
*
|
||||||
* @param estimator supports XGBoostClassifier and XGBoostRegressor
|
* @param estimator supports XGBoostClassifier and XGBoostRegressor
|
||||||
* @param dataset the training data
|
* @param dataset the training data
|
||||||
* @param params all user defined and defaulted params
|
* @param params all user defined and defaulted params
|
||||||
* @return [[XGBoostExecutionParams]] => (RDD[[Watches]], Option[ RDD[_] ])
|
* @return [[XGBoostExecutionParams]] => (Boolean, RDD[[() => Watches]], Option[ RDD[_] ])
|
||||||
* RDD[Watches] will be used as the training input
|
* Boolean if building DMatrix in rabit context
|
||||||
|
* RDD[() => Watches] will be used as the training input to build DMatrix
|
||||||
* Option[ RDD[_] ] is the optional cached RDD
|
* Option[ RDD[_] ] is the optional cached RDD
|
||||||
*/
|
*/
|
||||||
def buildDatasetToRDD(
|
def buildDatasetToRDD(
|
||||||
estimator: Estimator[_],
|
estimator: Estimator[_],
|
||||||
dataset: Dataset[_],
|
dataset: Dataset[_],
|
||||||
params: Map[String, Any]): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]])
|
params: Map[String, Any]):
|
||||||
|
XGBoostExecutionParams => (Boolean, RDD[() => Watches], Option[RDD[_]])
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Transform Dataset
|
* Transform Dataset
|
||||||
|
|||||||
@ -283,13 +283,8 @@ object XGBoost extends Serializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private def buildDistributedBooster(
|
private def buildWatchesAndCheck(buildWatchesFun: () => Watches): Watches = {
|
||||||
watches: Watches,
|
val watches = buildWatchesFun()
|
||||||
xgbExecutionParam: XGBoostExecutionParams,
|
|
||||||
rabitEnv: java.util.Map[String, String],
|
|
||||||
obj: ObjectiveTrait,
|
|
||||||
eval: EvalTrait,
|
|
||||||
prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
|
|
||||||
// to workaround the empty partitions in training dataset,
|
// to workaround the empty partitions in training dataset,
|
||||||
// this might not be the best efficient implementation, see
|
// this might not be the best efficient implementation, see
|
||||||
// (https://github.com/dmlc/xgboost/issues/1277)
|
// (https://github.com/dmlc/xgboost/issues/1277)
|
||||||
@ -298,14 +293,39 @@ object XGBoost extends Serializable {
|
|||||||
s"detected an empty partition in the training data, partition ID:" +
|
s"detected an empty partition in the training data, partition ID:" +
|
||||||
s" ${TaskContext.getPartitionId()}")
|
s" ${TaskContext.getPartitionId()}")
|
||||||
}
|
}
|
||||||
|
watches
|
||||||
|
}
|
||||||
|
|
||||||
|
private def buildDistributedBooster(
|
||||||
|
buildDMatrixInRabit: Boolean,
|
||||||
|
buildWatches: () => Watches,
|
||||||
|
xgbExecutionParam: XGBoostExecutionParams,
|
||||||
|
rabitEnv: java.util.Map[String, String],
|
||||||
|
obj: ObjectiveTrait,
|
||||||
|
eval: EvalTrait,
|
||||||
|
prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
|
||||||
|
|
||||||
|
var watches: Watches = null
|
||||||
|
if (!buildDMatrixInRabit) {
|
||||||
|
// for CPU pipeline, we need to build DMatrix out of rabit context
|
||||||
|
watches = buildWatchesAndCheck(buildWatches)
|
||||||
|
}
|
||||||
|
|
||||||
val taskId = TaskContext.getPartitionId().toString
|
val taskId = TaskContext.getPartitionId().toString
|
||||||
val attempt = TaskContext.get().attemptNumber.toString
|
val attempt = TaskContext.get().attemptNumber.toString
|
||||||
rabitEnv.put("DMLC_TASK_ID", taskId)
|
rabitEnv.put("DMLC_TASK_ID", taskId)
|
||||||
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
|
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
|
||||||
val numRounds = xgbExecutionParam.numRounds
|
val numRounds = xgbExecutionParam.numRounds
|
||||||
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
|
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
|
||||||
|
|
||||||
try {
|
try {
|
||||||
Rabit.init(rabitEnv)
|
Rabit.init(rabitEnv)
|
||||||
|
|
||||||
|
if (buildDMatrixInRabit) {
|
||||||
|
// for GPU pipeline, we need to move dmatrix building into rabit context
|
||||||
|
watches = buildWatchesAndCheck(buildWatches)
|
||||||
|
}
|
||||||
|
|
||||||
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
|
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
|
||||||
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
|
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
|
||||||
val externalCheckpointParams = xgbExecutionParam.checkpointParam
|
val externalCheckpointParams = xgbExecutionParam.checkpointParam
|
||||||
@ -338,7 +358,7 @@ object XGBoost extends Serializable {
|
|||||||
throw xgbException
|
throw xgbException
|
||||||
} finally {
|
} finally {
|
||||||
Rabit.shutdown()
|
Rabit.shutdown()
|
||||||
watches.delete()
|
if (watches != null) watches.delete()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -364,7 +384,7 @@ object XGBoost extends Serializable {
|
|||||||
@throws(classOf[XGBoostError])
|
@throws(classOf[XGBoostError])
|
||||||
private[spark] def trainDistributed(
|
private[spark] def trainDistributed(
|
||||||
sc: SparkContext,
|
sc: SparkContext,
|
||||||
buildTrainingData: XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]),
|
buildTrainingData: XGBoostExecutionParams => (Boolean, RDD[() => Watches], Option[RDD[_]]),
|
||||||
params: Map[String, Any]):
|
params: Map[String, Any]):
|
||||||
(Booster, Map[String, Array[Float]]) = {
|
(Booster, Map[String, Array[Float]]) = {
|
||||||
|
|
||||||
@ -383,7 +403,7 @@ object XGBoost extends Serializable {
|
|||||||
}.orNull
|
}.orNull
|
||||||
|
|
||||||
// Get the training data RDD and the cachedRDD
|
// Get the training data RDD and the cachedRDD
|
||||||
val (trainingRDD, optionalCachedRDD) = buildTrainingData(xgbExecParams)
|
val (buildDMatrixInRabit, trainingRDD, optionalCachedRDD) = buildTrainingData(xgbExecParams)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Train for every ${savingRound} rounds and save the partially completed booster
|
// Train for every ${savingRound} rounds and save the partially completed booster
|
||||||
@ -398,15 +418,16 @@ object XGBoost extends Serializable {
|
|||||||
val rabitEnv = tracker.getWorkerEnvs
|
val rabitEnv = tracker.getWorkerEnvs
|
||||||
|
|
||||||
val boostersAndMetrics = trainingRDD.mapPartitions { iter => {
|
val boostersAndMetrics = trainingRDD.mapPartitions { iter => {
|
||||||
var optionWatches: Option[Watches] = None
|
var optionWatches: Option[() => Watches] = None
|
||||||
|
|
||||||
// take the first Watches to train
|
// take the first Watches to train
|
||||||
if (iter.hasNext) {
|
if (iter.hasNext) {
|
||||||
optionWatches = Some(iter.next())
|
optionWatches = Some(iter.next())
|
||||||
}
|
}
|
||||||
|
|
||||||
optionWatches.map { watches => buildDistributedBooster(watches, xgbExecParams, rabitEnv,
|
optionWatches.map { buildWatches => buildDistributedBooster(buildDMatrixInRabit,
|
||||||
xgbExecParams.obj, xgbExecParams.eval, prevBooster)}
|
buildWatches, xgbExecParams, rabitEnv, xgbExecParams.obj,
|
||||||
|
xgbExecParams.eval, prevBooster)}
|
||||||
.getOrElse(throw new RuntimeException("No Watches to train"))
|
.getOrElse(throw new RuntimeException("No Watches to train"))
|
||||||
|
|
||||||
}}.cache()
|
}}.cache()
|
||||||
|
|||||||
@ -119,6 +119,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test("test SparkContext should not be killed ") {
|
test("test SparkContext should not be killed ") {
|
||||||
|
cancel("For some reason, sparkContext can't cancel the job locally in the CI env," +
|
||||||
|
"which will be resolved when introducing barrier mode")
|
||||||
val training = buildDataFrame(Classification.train)
|
val training = buildDataFrame(Classification.train)
|
||||||
// mock rank 0 failure during 8th allreduce synchronization
|
// mock rank 0 failure during 8th allreduce synchronization
|
||||||
Rabit.mockList = Array("0,8,0,0").toList.asJava
|
Rabit.mockList = Array("0,8,0,0").toList.asJava
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user