diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala index a2ea44443..95f7fc9ea 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala @@ -59,19 +59,18 @@ abstract class XGBoostModel(protected var _booster: Booster) * * @param testSet test set represented as RDD */ - def predictLeaves(testSet: RDD[MLVector]): RDD[Array[Array[Float]]] = { + def predictLeaves(testSet: RDD[MLVector]): RDD[Array[Float]] = { import DataUtils._ val broadcastBooster = testSet.sparkContext.broadcast(_booster) testSet.mapPartitions { testSamples => val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString) Rabit.init(rabitEnv.asJava) - if (testSamples.hasNext) { + if (testSamples.nonEmpty) { val dMatrix = new DMatrix(new JDMatrix(testSamples, null)) try { - val res = broadcastBooster.value.predictLeaf(dMatrix) - Rabit.shutdown() - Iterator(res) + broadcastBooster.value.predictLeaf(dMatrix).iterator } finally { + Rabit.shutdown() dMatrix.delete() } } else { @@ -151,7 +150,7 @@ abstract class XGBoostModel(protected var _booster: Booster) * @param testSet test set represented as RDD * @param missingValue the specified value to represent the missing value */ - def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Array[Float]]] = { + def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Float]] = { val broadcastBooster = testSet.sparkContext.broadcast(_booster) testSet.mapPartitions { testSamples => val sampleArray = testSamples.toList @@ -169,7 +168,7 @@ abstract class XGBoostModel(protected var _booster: Booster) } val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue) try { - Iterator(broadcastBooster.value.predict(dMatrix)) + broadcastBooster.value.predict(dMatrix).iterator } finally { Rabit.shutdown() dMatrix.delete() @@ -188,7 +187,7 @@ abstract class XGBoostModel(protected var _booster: Booster) def predict( testSet: RDD[MLVector], useExternalCache: Boolean = false, - outputMargin: Boolean = false): RDD[Array[Array[Float]]] = { + outputMargin: Boolean = false): RDD[Array[Float]] = { val broadcastBooster = testSet.sparkContext.broadcast(_booster) val appName = testSet.context.appName testSet.mapPartitions { testSamples => @@ -205,7 +204,7 @@ abstract class XGBoostModel(protected var _booster: Booster) } val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName)) try { - Iterator(broadcastBooster.value.predict(dMatrix)) + broadcastBooster.value.predict(dMatrix).iterator } finally { Rabit.shutdown() dMatrix.delete() diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala index d4007401b..83ee6da9a 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala @@ -252,7 +252,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { "objective" -> "binary:logistic") val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers) val predRDD = xgBoostModel.predict(testRDD) - val predResult1 = predRDD.collect()(0) + val predResult1 = predRDD.collect() assert(testRDD.count() === predResult1.length) import DataUtils._ val predResult2 = xgBoostModel.booster.predict(new DMatrix(testSet.iterator)) @@ -273,14 +273,11 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { test("test prediction functionality with empty partition") { def buildEmptyRDD(sparkContext: Option[SparkContext] = None): RDD[SparkVector] = { - val sampleList = new ListBuffer[SparkVector] - sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers) + sparkContext.getOrElse(sc).parallelize(List[SparkVector](), numWorkers) } val trainingRDD = buildTrainingRDD(sc) val testRDD = buildEmptyRDD() - val tempDir = Files.createTempDirectory("xgboosttest-") - val tempFile = Files.createTempFile(tempDir, "", "") val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1", "objective" -> "binary:logistic").toMap val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers) @@ -358,7 +355,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 1) val predRDD = xgBoostModel.predict(testRDD) - val predResult1: Array[Array[Float]] = predRDD.collect()(0) + val predResult1: Array[Array[Float]] = predRDD.collect() assert(testRDD.count() === predResult1.length) val avgMetric = xgBoostModel.eval(trainingRDD, "test", iter = 0, groupData = trainGroupData) @@ -386,7 +383,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 2) val predRDD = xgBoostModel.predict(testRDD) - val predResult1: Array[Array[Float]] = predRDD.collect()(0) + val predResult1: Array[Array[Float]] = predRDD.collect() assert(testRDD.count() === predResult1.length) } @@ -403,7 +400,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { val trainMargin = { XGBoost.trainWithRDD(trainRDD, paramMap, round = 1, nWorkers = 2) .predict(trainRDD.map(_.features), outputMargin = true) - .flatMap { _.flatten.iterator } + .map { case Array(m) => m } } val xgBoostModel = XGBoost.trainWithRDD( @@ -413,6 +410,6 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { nWorkers = 2, baseMargin = trainMargin) - assert(testRDD.count() === xgBoostModel.predict(testRDD).first().length) + assert(testRDD.count() === xgBoostModel.predict(testRDD).count()) } }