diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 68a42fd12..7bebf223e 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -377,7 +377,7 @@ object XGBoost extends Serializable { // to workaround the empty partitions in training dataset, // this might not be the best efficient implementation, see // (https://github.com/dmlc/xgboost/issues/1277) - if (watches.toMap("train").rowNum == 0) { + if (!watches.toMap.contains("train")) { throw new XGBoostError( s"detected an empty partition in the training data, partition ID:" + s" ${TaskContext.getPartitionId()}") diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala index 3ee2b21f2..76040ac63 100755 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala @@ -16,10 +16,12 @@ package ml.dmlc.xgboost4j.scala.spark +import ml.dmlc.xgboost4j.java.XGBoostError + import scala.util.Random import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} import ml.dmlc.xgboost4j.scala.DMatrix -import org.apache.spark.{TaskContext} +import org.apache.spark.TaskContext import org.scalatest.FunSuite import org.apache.spark.ml.feature.VectorAssembler import org.apache.spark.sql.functions.lit @@ -367,4 +369,16 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest { df2.collect() } + test("throw exception for empty partition in trainingset") { + val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "multi:softmax", "num_class" -> "2", "num_round" -> 5, + "num_workers" -> numWorkers, "tree_method" -> "auto") + // The Dmatrix will be empty + val trainingDF = buildDataFrame(Seq(XGBLabeledPoint(1.0f, 1, Array(), Array()))) + val xgb = new XGBoostClassifier(paramMap) + intercept[XGBoostError] { + val model = xgb.fit(trainingDF) + } + } + }