explicitly throw exception when detecting empty partition in training dataset (#1281)
This commit is contained in:
parent
465e5dfb87
commit
c9a73fe2a9
@ -16,20 +16,16 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.nio.file.Paths
|
||||
|
||||
import scala.collection.mutable
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
import org.apache.hadoop.fs.{Path, FileSystem}
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker, XGBoostError}
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.spark.{SparkContext, TaskContext}
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError, Rabit, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import org.apache.spark.{SparkContext, TaskContext}
|
||||
|
||||
object XGBoost extends Serializable {
|
||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
||||
@ -58,22 +54,33 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
}
|
||||
val appName = partitionedData.context.appName
|
||||
// to work around the empty partitions in the training dataset,
|
||||
// this might not be the most efficient implementation, see
|
||||
// (https://github.com/dmlc/xgboost/issues/1277)
|
||||
partitionedData.mapPartitions {
|
||||
trainingSamples =>
|
||||
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
val cacheFileName: String = {
|
||||
if (useExternalMemory && trainingSamples.hasNext) {
|
||||
s"$appName-dtrain_cache-${TaskContext.getPartitionId()}"
|
||||
} else {
|
||||
null
|
||||
var booster: Booster = null
|
||||
if (trainingSamples.hasNext) {
|
||||
val cacheFileName: String = {
|
||||
if (useExternalMemory && trainingSamples.hasNext) {
|
||||
s"$appName-dtrain_cache-${TaskContext.getPartitionId()}"
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
val trainingSet = new DMatrix(new JDMatrix(trainingSamples, cacheFileName))
|
||||
booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
|
||||
watches = new mutable.HashMap[String, DMatrix] {
|
||||
put("train", trainingSet)
|
||||
}.toMap, obj, eval)
|
||||
Rabit.shutdown()
|
||||
} else {
|
||||
Rabit.shutdown()
|
||||
throw new XGBoostError(s"detect the empty partition in training dataset, partition ID:" +
|
||||
s" ${TaskContext.getPartitionId().toString}")
|
||||
}
|
||||
val trainingSet = new DMatrix(new JDMatrix(trainingSamples, cacheFileName))
|
||||
val booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
|
||||
watches = new mutable.HashMap[String, DMatrix]{put("train", trainingSet)}.toMap,
|
||||
obj, eval)
|
||||
Rabit.shutdown()
|
||||
Iterator(booster)
|
||||
}.cache()
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user