explicitly throw exception when detecting empty partition in training dataset (#1281)

Author: Nan Zhu, 2016-06-15 16:03:37 -04:00 (committed by GitHub)
parent 465e5dfb87
commit c9a73fe2a9
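
Before this change, an empty training partition reached new JDMatrix(trainingSamples, cacheFileName) with an exhausted iterator; after it, the worker shuts Rabit down and throws an XGBoostError that names the offending partition ID. Empty partitions typically appear when the training RDD holds fewer records than it has partitions. The following standalone sketch is not part of the commit (the object name and local master URL are illustrative only); it shows how such partitions arise and how to list them:

import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical demo: 3 records spread over 8 partitions leaves at least 5 partitions empty,
// which is exactly the situation the new check reports through XGBoostError.
object EmptyPartitionDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("empty-partition-demo").setMaster("local[4]"))
    val rdd = sc.parallelize(Seq(1.0, 2.0, 3.0), numSlices = 8)
    val emptyPartitionIds = rdd.mapPartitionsWithIndex { (id, iter) =>
      if (iter.isEmpty) Iterator(id) else Iterator.empty
    }.collect()
    println(s"empty partition IDs: ${emptyPartitionIds.mkString(", ")}")
    sc.stop()
  }
}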


@@ -16,20 +16,16 @@
 package ml.dmlc.xgboost4j.scala.spark
 import java.nio.file.Paths
-import scala.collection.mutable
 import scala.collection.JavaConverters._
+import scala.collection.mutable
+import org.apache.hadoop.fs.{Path, FileSystem}
+import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker, XGBoostError}
+import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import org.apache.commons.logging.LogFactory
-import org.apache.hadoop.fs.Path
+import org.apache.spark.{SparkContext, TaskContext}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
-import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError, Rabit, RabitTracker}
-import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
-import org.apache.spark.{SparkContext, TaskContext}
 object XGBoost extends Serializable {
   private val logger = LogFactory.getLog("XGBoostSpark")
@@ -58,22 +54,33 @@ object XGBoost extends Serializable {
       }
     }
     val appName = partitionedData.context.appName
+    // to workaround the empty partitions in training dataset,
+    // this might not be the best efficient implementation, see
+    // (https://github.com/dmlc/xgboost/issues/1277)
     partitionedData.mapPartitions {
       trainingSamples =>
        rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
        Rabit.init(rabitEnv.asJava)
-       val cacheFileName: String = {
-         if (useExternalMemory && trainingSamples.hasNext) {
-           s"$appName-dtrain_cache-${TaskContext.getPartitionId()}"
-         } else {
-           null
-         }
-       }
-       val trainingSet = new DMatrix(new JDMatrix(trainingSamples, cacheFileName))
-       val booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
-         watches = new mutable.HashMap[String, DMatrix]{put("train", trainingSet)}.toMap,
-         obj, eval)
-       Rabit.shutdown()
+       var booster: Booster = null
+       if (trainingSamples.hasNext) {
+         val cacheFileName: String = {
+           if (useExternalMemory && trainingSamples.hasNext) {
+             s"$appName-dtrain_cache-${TaskContext.getPartitionId()}"
+           } else {
+             null
+           }
+         }
+         val trainingSet = new DMatrix(new JDMatrix(trainingSamples, cacheFileName))
+         booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
+           watches = new mutable.HashMap[String, DMatrix] {
+             put("train", trainingSet)
+           }.toMap, obj, eval)
+         Rabit.shutdown()
+       } else {
+         Rabit.shutdown()
+         throw new XGBoostError(s"detect the empty partition in training dataset, partition ID:" +
+           s" ${TaskContext.getPartitionId().toString}")
+       }
        Iterator(booster)
     }.cache()
  }
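
The comment added above the mapPartitions call points at https://github.com/dmlc/xgboost/issues/1277; with this commit an empty partition is reported explicitly instead of reaching DMatrix construction. A caller can still avoid triggering the new XGBoostError by keeping the requested partition count no larger than the number of training records before handing the RDD to XGBoost, which is the usual way to keep partitions from coming up empty. A hedged caller-side sketch, not part of xgboost4j-spark (the object and method names are assumptions):

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Hypothetical helper: cap the partition count at the record count before training,
// so the repartitioned RDD has enough records to populate every partition.
object PartitionGuard {
  def withoutEmptyPartitions(trainingData: RDD[LabeledPoint], numWorkers: Int): RDD[LabeledPoint] = {
    val numRecords = trainingData.count()
    val safePartitions = math.max(1L, math.min(numWorkers.toLong, numRecords)).toInt
    trainingData.repartition(safePartitions)
  }
}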