[jvm-packages] move dmatrix building into rabit context for cpu pipeline (#7908)

This commit is contained in:
Bobby Wang
2022-05-17 14:52:25 +08:00
committed by GitHub
parent 77d4a53c32
commit b41cf92dc2
5 changed files with 20 additions and 36 deletions

View File

@@ -61,15 +61,14 @@ class GpuPreXGBoost extends PreXGBoostProvider {
* @param estimator [[XGBoostClassifier]] or [[XGBoostRegressor]]
* @param dataset the training data
* @param params all user defined and defaulted params
* @return [[XGBoostExecutionParams]] => (Boolean, RDD[[() => Watches]], Option[ RDD[_] ])
* Boolean if building DMatrix in rabit context
* @return [[XGBoostExecutionParams]] => (RDD[[() => Watches]], Option[ RDD[_] ])
* RDD[() => Watches] will be used as the training input
* Option[ RDD[_] ] is the optional cached RDD
*/
override def buildDatasetToRDD(estimator: Estimator[_],
dataset: Dataset[_],
params: Map[String, Any]):
XGBoostExecutionParams => (Boolean, RDD[() => Watches], Option[RDD[_]]) = {
XGBoostExecutionParams => (RDD[() => Watches], Option[RDD[_]]) = {
GpuPreXGBoost.buildDatasetToRDD(estimator, dataset, params)
}
@@ -123,8 +122,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
* @param estimator supports XGBoostClassifier and XGBoostRegressor
* @param dataset the training data
* @param params all user defined and defaulted params
* @return [[XGBoostExecutionParams]] => (Boolean, RDD[[() => Watches]], Option[ RDD[_] ])
* Boolean if building DMatrix in rabit context
* @return [[XGBoostExecutionParams]] => (RDD[[() => Watches]], Option[ RDD[_] ])
* RDD[() => Watches] will be used as the training input to build DMatrix
* Option[ RDD[_] ] is the optional cached RDD
*/
@@ -132,7 +130,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
estimator: Estimator[_],
dataset: Dataset[_],
params: Map[String, Any]):
XGBoostExecutionParams => (Boolean, RDD[() => Watches], Option[RDD[_]]) = {
XGBoostExecutionParams => (RDD[() => Watches], Option[RDD[_]]) = {
val (Seq(labelName, weightName, marginName), feturesCols, groupName, evalSets) =
estimator match {
@@ -170,7 +168,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
xgbExecParams: XGBoostExecutionParams =>
val dataMap = prepareInputData(trainingData, evalDataMap, xgbExecParams.numWorkers,
xgbExecParams.cacheTrainingSet)
(true, buildRDDWatches(dataMap, xgbExecParams, evalDataMap.isEmpty), None)
(buildRDDWatches(dataMap, xgbExecParams, evalDataMap.isEmpty), None)
}
/**