diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala index 14fa3c0f6..75a91e64c 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala @@ -32,8 +32,12 @@ class XGBoostModel(_booster: Booster)(implicit val sc: SparkContext) extends Ser import DataUtils._ val broadcastBooster = testSet.sparkContext.broadcast(_booster) testSet.mapPartitions { testSamples => - val dMatrix = new DMatrix(new JDMatrix(testSamples, null)) - Iterator(broadcastBooster.value.predict(dMatrix)) + if (testSamples.hasNext) { + val dMatrix = new DMatrix(new JDMatrix(testSamples, null)) + Iterator(broadcastBooster.value.predict(dMatrix)) + } else { + Iterator() + } } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala index 6f4e98aa3..711ea35f0 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.ListBuffer import scala.io.Source import org.apache.commons.logging.LogFactory -import org.apache.spark.mllib.linalg.DenseVector +import org.apache.spark.mllib.linalg.{Vector => SparkVector, Vectors, DenseVector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} @@ -190,4 +190,24 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter { assert(eval.eval(xgBoostModel.predict(testSetDMatrix), testSetDMatrix) < 0.1) customSparkContext.stop() } + + test("test with empty partition") { 
+
+    // Builds an RDD with no samples, so each of its numWorkers partitions is empty.
+    def buildEmptyRDD(sparkContext: Option[SparkContext] = None): RDD[SparkVector] = {
+      val sampleList = new ListBuffer[SparkVector]
+      sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
+    }
+
+    val trainingRDD = buildTrainingRDD()
+    val testRDD = buildEmptyRDD()
+    import DataUtils._
+    val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
+      "objective" -> "binary:logistic").toMap
+    val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
+
+    // predict is lazy (mapPartitions); collect() forces the empty partitions to be evaluated.
+    val predictions = xgBoostModel.predict(testRDD).collect()
+    assert(predictions.isEmpty)
+  }
 }