explicitly throw exception when detecting empty partition in training dataset (#1281)
parent 465e5dfb87
commit c9a73fe2a9
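The diff below touches xgboost4j-spark's `XGBoost` object in two places: the first hunk regroups the imports, and the second makes each training task check whether its partition actually contains records. A non-empty partition trains as before; an empty one now shuts down Rabit and throws an `XGBoostError` carrying the offending partition ID, instead of constructing a `DMatrix` from an empty iterator (see https://github.com/dmlc/xgboost/issues/1277).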
@@ -16,20 +16,16 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import java.nio.file.Paths
-
-import scala.collection.mutable
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 
-import org.apache.hadoop.fs.{Path, FileSystem}
+import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker, XGBoostError}
+import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import org.apache.commons.logging.LogFactory
-import org.apache.spark.{SparkContext, TaskContext}
+import org.apache.hadoop.fs.Path
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
-
-import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError, Rabit, RabitTracker}
-import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
+import org.apache.spark.{SparkContext, TaskContext}
 
 object XGBoost extends Serializable {
   private val logger = LogFactory.getLog("XGBoostSpark")
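The hunk above only reorganizes imports: they are regrouped and alphabetized, `import java.nio.file.Paths` is dropped, and the Hadoop import narrows from `{Path, FileSystem}` to `Path`, apparently because the removed names are unused. The behavioral change is in the next hunk.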
@@ -58,22 +54,33 @@ object XGBoost extends Serializable {
       }
     }
     val appName = partitionedData.context.appName
+    // to workaround the empty partitions in training dataset,
+    // this might not be the best efficient implementation, see
+    // (https://github.com/dmlc/xgboost/issues/1277)
     partitionedData.mapPartitions {
       trainingSamples =>
        rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
        Rabit.init(rabitEnv.asJava)
-        val cacheFileName: String = {
-          if (useExternalMemory && trainingSamples.hasNext) {
-            s"$appName-dtrain_cache-${TaskContext.getPartitionId()}"
-          } else {
-            null
-          }
-        }
-        val trainingSet = new DMatrix(new JDMatrix(trainingSamples, cacheFileName))
-        val booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
-          watches = new mutable.HashMap[String, DMatrix]{put("train", trainingSet)}.toMap,
-          obj, eval)
-        Rabit.shutdown()
+        var booster: Booster = null
+        if (trainingSamples.hasNext) {
+          val cacheFileName: String = {
+            if (useExternalMemory && trainingSamples.hasNext) {
+              s"$appName-dtrain_cache-${TaskContext.getPartitionId()}"
+            } else {
+              null
+            }
+          }
+          val trainingSet = new DMatrix(new JDMatrix(trainingSamples, cacheFileName))
+          booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
+            watches = new mutable.HashMap[String, DMatrix] {
+              put("train", trainingSet)
+            }.toMap, obj, eval)
+          Rabit.shutdown()
+        } else {
+          Rabit.shutdown()
+          throw new XGBoostError(s"detect the empty partition in training dataset, partition ID:" +
+            s" ${TaskContext.getPartitionId().toString}")
+        }
        Iterator(booster)
     }.cache()
   }
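Note that `Rabit.shutdown()` now runs on both branches, so a worker leaves the Rabit ring cleanly whether it trained or threw. Callers who would rather fail before the distributed job launches can probe for empty partitions themselves. The sketch below is not part of the commit: `PartitionChecks` and `emptyPartitionIds` are hypothetical names, and it only assumes the same `RDD[LabeledPoint]` input used in this file.

    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.rdd.RDD

    object PartitionChecks {
      // Hypothetical pre-flight helper (not part of this commit): collect the
      // IDs of all empty partitions so the caller can repartition before a
      // training task would hit the new XGBoostError.
      def emptyPartitionIds(data: RDD[LabeledPoint]): Seq[Int] =
        data.mapPartitionsWithIndex { (id, it) =>
          // Iterator.isEmpty only peeks at hasNext; no records are consumed.
          if (it.isEmpty) Iterator.single(id) else Iterator.empty
        }.collect().toSeq
    }

If the returned sequence is non-empty, repartitioning to a partition count no larger than the record count is one way to redistribute the data, at the cost of a shuffle. Since the check itself runs a full Spark job, it is best suited to debugging the failure this commit now surfaces explicitly.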