diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 5903bd2c9..a9e771ae7 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -16,20 +16,16 @@ package ml.dmlc.xgboost4j.scala.spark -import java.nio.file.Paths - -import scala.collection.mutable import scala.collection.JavaConverters._ +import scala.collection.mutable -import org.apache.hadoop.fs.{Path, FileSystem} - +import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker, XGBoostError} +import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} import org.apache.commons.logging.LogFactory -import org.apache.spark.{SparkContext, TaskContext} +import org.apache.hadoop.fs.Path import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD - -import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError, Rabit, RabitTracker} -import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} +import org.apache.spark.{SparkContext, TaskContext} object XGBoost extends Serializable { private val logger = LogFactory.getLog("XGBoostSpark") @@ -58,22 +54,33 @@ object XGBoost extends Serializable { } } val appName = partitionedData.context.appName + // to workaround the empty partitions in training dataset, + // this might not be the best efficient implementation, see + // (https://github.com/dmlc/xgboost/issues/1277) partitionedData.mapPartitions { trainingSamples => rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString) Rabit.init(rabitEnv.asJava) - val cacheFileName: String = { - if (useExternalMemory && trainingSamples.hasNext) { - s"$appName-dtrain_cache-${TaskContext.getPartitionId()}" - } else { - null + var booster: Booster = null + if (trainingSamples.hasNext) { + val cacheFileName: String = { + if (useExternalMemory && trainingSamples.hasNext) { + s"$appName-dtrain_cache-${TaskContext.getPartitionId()}" + } else { + null + } } + val trainingSet = new DMatrix(new JDMatrix(trainingSamples, cacheFileName)) + booster = SXGBoost.train(trainingSet, xgBoostConfMap, round, + watches = new mutable.HashMap[String, DMatrix] { + put("train", trainingSet) + }.toMap, obj, eval) + Rabit.shutdown() + } else { + Rabit.shutdown() + throw new XGBoostError(s"detect the empty partition in training dataset, partition ID:" + + s" ${TaskContext.getPartitionId().toString}") } - val trainingSet = new DMatrix(new JDMatrix(trainingSamples, cacheFileName)) - val booster = SXGBoost.train(trainingSet, xgBoostConfMap, round, - watches = new mutable.HashMap[String, DMatrix]{put("train", trainingSet)}.toMap, - obj, eval) - Rabit.shutdown() Iterator(booster) }.cache() }