diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
index bda9189b7..b83a01814 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
@@ -16,30 +16,26 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import scala.collection.Iterator
-import scala.collection.JavaConverters._
-import scala.collection.mutable
-
 import ml.dmlc.xgboost4j.java.Rabit
-import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
-import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
 import ml.dmlc.xgboost4j.scala.spark.params._
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait, ObjectiveTrait, XGBoost => SXGBoost}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import org.apache.hadoop.fs.Path
-
 import org.apache.spark.TaskContext
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.classification._
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.util._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
-import org.apache.spark.sql._
 import org.json4s.DefaultFormats
-import org.apache.spark.broadcast.Broadcast
+import scala.collection.JavaConverters._
+import scala.collection.{AbstractIterator, Iterator, mutable}
 
 private[spark] trait XGBoostClassifierParams extends GeneralParams with LearningTaskParams
   with BoosterParams with HasWeightCol with HasBaseMarginCol with HasNumClass with ParamMapFuncs
@@ -216,7 +212,8 @@ class XGBoostClassificationModel private[ml](
     override val numClasses: Int,
     private[spark] val _booster: Booster)
   extends ProbabilisticClassificationModel[Vector, XGBoostClassificationModel]
-    with XGBoostClassifierParams with MLWritable with Serializable {
+    with XGBoostClassifierParams with InferenceParams
+    with MLWritable with Serializable {
 
   import XGBoostClassificationModel._
 
@@ -250,6 +247,8 @@ class XGBoostClassificationModel private[ml](
 
   def setTreeLimit(value: Int): this.type = set(treeLimit, value)
 
+  def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
+
   /**
    * Single instance prediction.
    * Note: The performance is not ideal, use it carefully!
@@ -287,46 +286,53 @@ class XGBoostClassificationModel private[ml](
     val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
     val appName = dataset.sparkSession.sparkContext.appName
 
-    val inputRDD = dataset.asInstanceOf[Dataset[Row]].rdd
-    val predictionRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
-      if (rowIterator.hasNext) {
-        val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
-        Rabit.init(rabitEnv.asJava)
-        val featuresIterator = rowIterator.map(row => row.getAs[Vector](
-          $(featuresCol))).toList.iterator
-        import DataUtils._
-        val cacheInfo = {
-          if ($(useExternalMemory)) {
-            s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
-          } else {
-            null
+    val resultRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+      new AbstractIterator[Row] {
+        private var batchCnt = 0
+
+        private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
+          if (batchCnt == 0) {
+            val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
+            Rabit.init(rabitEnv.asJava)
+          }
+
+          val features = batchRow.iterator.map(row => row.getAs[Vector]($(featuresCol)))
+
+          import DataUtils._
+          val cacheInfo = {
+            if ($(useExternalMemory)) {
+              s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
+                s"${TaskContext.getPartitionId()}-batch-$batchCnt"
+            } else {
+              null
+            }
+          }
+
+          val dm = new DMatrix(
+            XGBoost.processMissingValues(features.map(_.asXGB), $(missing)),
+            cacheInfo)
+          try {
+            val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
+              producePredictionItrs(bBooster, dm)
+            produceResultIterator(batchRow.iterator,
+              rawPredictionItr, probabilityItr, predLeafItr, predContribItr)
+          } finally {
+            batchCnt += 1
+            dm.delete()
+          }
         }
-        val dm = new DMatrix(
-          XGBoost.processMissingValues(featuresIterator.map(_.asXGB), $(missing)),
-          cacheInfo)
-        try {
-          val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
-            producePredictionItrs(bBooster, dm)
-          Rabit.shutdown()
-          Iterator(rawPredictionItr, probabilityItr, predLeafItr,
-            predContribItr)
-        } finally {
-          dm.delete()
+
+        override def hasNext: Boolean = batchIterImpl.hasNext
+
+        override def next(): Row = {
+          val ret = batchIterImpl.next()
+          if (!batchIterImpl.hasNext) {
+            Rabit.shutdown()
+          }
+          ret
         }
-      } else {
-        Iterator()
       }
     }
 
-    val resultRDD = inputRDD.zipPartitions(predictionRDD, preservesPartitioning = true) {
-      case (inputIterator, predictionItr) =>
-        if (inputIterator.hasNext) {
-          produceResultIterator(inputIterator, predictionItr.next(), predictionItr.next(),
-            predictionItr.next(), predictionItr.next())
-        } else {
-          Iterator()
-        }
-    }
     bBooster.unpersist(blocking = false)
     dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
@@ -527,4 +533,3 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel
       }
     }
   }
-
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
index b47bca27b..9b4fd3742 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
@@ -16,10 +16,10 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import scala.collection.Iterator
+import scala.collection.{AbstractIterator, Iterator, mutable}
 import scala.collection.JavaConverters._
 
-import ml.dmlc.xgboost4j.java.Rabit
+import ml.dmlc.xgboost4j.java.{Rabit, XGBoost => JXGBoost}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
@@ -37,7 +37,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.json4s.DefaultFormats
-import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
 
 import org.apache.spark.broadcast.Broadcast
 
@@ -207,7 +207,8 @@ class XGBoostRegressionModel private[ml] (
     override val uid: String,
     private[spark] val _booster: Booster)
   extends PredictionModel[Vector, XGBoostRegressionModel]
-    with XGBoostRegressorParams with MLWritable with Serializable {
+    with XGBoostRegressorParams with InferenceParams
+    with MLWritable with Serializable {
 
   import XGBoostRegressionModel._
 
@@ -241,6 +242,8 @@ class XGBoostRegressionModel private[ml] (
 
   def setTreeLimit(value: Int): this.type = set(treeLimit, value)
 
+  def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
+
   /**
    * Single instance prediction.
    * Note: The performance is not ideal, use it carefully!
@@ -259,45 +262,53 @@ class XGBoostRegressionModel private[ml] (
     val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
     val appName = dataset.sparkSession.sparkContext.appName
 
-    val inputRDD = dataset.asInstanceOf[Dataset[Row]].rdd
-    val predictionRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
-      if (rowIterator.hasNext) {
-        val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
-        Rabit.init(rabitEnv.asJava)
-        val featuresIterator = rowIterator.map(row => row.getAs[Vector](
-          $(featuresCol))).toList.iterator
-        import DataUtils._
-        val cacheInfo = {
-          if ($(useExternalMemory)) {
-            s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
-          } else {
-            null
+
+    val resultRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+      new AbstractIterator[Row] {
+        private var batchCnt = 0
+
+        private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
+          if (batchCnt == 0) {
+            val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
+            Rabit.init(rabitEnv.asJava)
+          }
+
+          val features = batchRow.iterator.map(row => row.getAs[Vector]($(featuresCol)))
+
+          import DataUtils._
+          val cacheInfo = {
+            if ($(useExternalMemory)) {
+              s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
+                s"${TaskContext.getPartitionId()}-batch-$batchCnt"
+            } else {
+              null
+            }
+          }
+
+          val dm = new DMatrix(
+            XGBoost.processMissingValues(features.map(_.asXGB), $(missing)),
+            cacheInfo)
+          try {
+            val Array(rawPredictionItr, predLeafItr, predContribItr) =
+              producePredictionItrs(bBooster, dm)
+            produceResultIterator(batchRow.iterator, rawPredictionItr, predLeafItr, predContribItr)
+          } finally {
+            batchCnt += 1
+            dm.delete()
+          }
         }
-        val dm = new DMatrix(
-          XGBoost.processMissingValues(featuresIterator.map(_.asXGB), $(missing)),
-          cacheInfo)
-        try {
-          val Array(originalPredictionItr, predLeafItr, predContribItr) =
-            producePredictionItrs(bBooster, dm)
-          Rabit.shutdown()
-          Iterator(originalPredictionItr, predLeafItr, predContribItr)
-        } finally {
-          dm.delete()
+
+        override def hasNext: Boolean = batchIterImpl.hasNext
+
+        override def next(): Row = {
+          val ret = batchIterImpl.next()
+          if (!batchIterImpl.hasNext) {
+            Rabit.shutdown()
+          }
+          ret
         }
-      } else {
-        Iterator()
       }
     }
 
-    val resultRDD = inputRDD.zipPartitions(predictionRDD, preservesPartitioning = true) {
-      case (inputIterator, predictionItr) =>
-        if (inputIterator.hasNext) {
-          produceResultIterator(inputIterator, predictionItr.next(), predictionItr.next(),
-            predictionItr.next())
-        } else {
-          Iterator()
-        }
-    }
     bBooster.unpersist(blocking = false)
     dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
   }
@@ -347,14 +358,14 @@ class XGBoostRegressionModel private[ml] (
     resultSchema
   }
 
-  private def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
+  private def producePredictionItrs(booster: Broadcast[Booster], dm: DMatrix):
       Array[Iterator[Row]] = {
     val originalPredictionItr = {
-      broadcastBooster.value.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator
+      booster.value.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator
    }
     val predLeafItr = {
       if (isDefined(leafPredictionCol)) {
-        broadcastBooster.value.predictLeaf(dm, $(treeLimit)).
+        booster.value.predictLeaf(dm, $(treeLimit)).
           map(Row(_)).iterator
       } else {
         Iterator()
@@ -362,7 +373,7 @@ class XGBoostRegressionModel private[ml] (
     }
     val predContribItr = {
       if (isDefined(contribPredictionCol)) {
-        broadcastBooster.value.predictContrib(dm, $(treeLimit)).
+        booster.value.predictContrib(dm, $(treeLimit)).
           map(Row(_)).iterator
       } else {
         Iterator()
@@ -373,7 +384,6 @@ class XGBoostRegressionModel private[ml] (
 
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-
     // Output selected columns only.
     // This is a bit complicated since it tries to avoid repeated computation.
     var outputData = transformInternal(dataset)
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/InferenceParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/InferenceParams.scala
new file mode 100644
index 000000000..abfe777d3
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/InferenceParams.scala
@@ -0,0 +1,32 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark.params
+
+import org.apache.spark.ml.param.{IntParam, Params}
+
+private[spark] trait InferenceParams extends Params {
+
+  /**
+   * batch size of inference iteration
+   */
+  final val inferBatchSize = new IntParam(this, "batchSize", "batch size of inference iteration")
+
+  /** @group getParam */
+  final def getInferBatchSize: Int = $(inferBatchSize)
+
+  setDefault(inferBatchSize, 32 << 10)
+}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
index 7bba5f342..81ed33f20 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
@@ -19,11 +19,12 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.io.File
 
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
-
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.sql._
 import org.scalatest.{BeforeAndAfterEach, FunSuite}
 
+import scala.util.Random
+
 trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
 
   protected val numWorkers: Int = Runtime.getRuntime.availableProcessors()
@@ -80,6 +81,18 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
       .toDF("id", "label", "features")
   }
 
+  protected def buildDataFrameWithRandSort(
+      labeledPoints: Seq[XGBLabeledPoint],
+      numPartitions: Int = numWorkers): DataFrame = {
+    val df = buildDataFrame(labeledPoints, numPartitions)
+    val rndSortedRDD = df.rdd.mapPartitions { iter =>
+      iter.map(_ -> Random.nextDouble()).toList
+        .sortBy(_._2)
+        .map(_._1).iterator
+    }
+    ss.createDataFrame(rndSortedRDD, df.schema)
+  }
+
   protected def buildDataFrameWithGroup(
       labeledPoints: Seq[XGBLabeledPoint],
       numPartitions: Int = numWorkers): DataFrame = {
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
index 4e0a2a016..2e94442c4 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
@@ -27,13 +27,28 @@
 import org.apache.spark.Partitioner
 
 class XGBoostClassifierSuite extends FunSuite with PerTest {
 
-  test("XGBoost-Spark XGBoostClassifier ouput should match XGBoost4j") {
+  test("XGBoost-Spark XGBoostClassifier output should match XGBoost4j") {
     val trainingDM = new DMatrix(Classification.train.iterator)
     val testDM = new DMatrix(Classification.test.iterator)
     val trainingDF = buildDataFrame(Classification.train)
     val testDF = buildDataFrame(Classification.test)
-    val round = 5
+    checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
+  }
+  test("XGBoostClassifier should make correct predictions after upstream random sort") {
+    val trainingDM = new DMatrix(Classification.train.iterator)
+    val testDM = new DMatrix(Classification.test.iterator)
+    val trainingDF = buildDataFrameWithRandSort(Classification.train)
+    val testDF = buildDataFrameWithRandSort(Classification.test)
+    checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
+  }
+
+  private def checkResultsWithXGBoost4j(
+      trainingDM: DMatrix,
+      testDM: DMatrix,
+      trainingDF: DataFrame,
+      testDF: DataFrame,
+      round: Int = 5): Unit = {
     val paramMap = Map(
       "eta" -> "1",
"max_depth" -> "6", @@ -47,7 +62,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest { "num_workers" -> numWorkers)).fit(trainingDF) val prediction2 = model2.transform(testDF). - collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probability"))).toMap + collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probability"))).toMap assert(testDF.count() === prediction2.size) // the vector length in probability column is 2 since we have to fit to the evaluator in Spark @@ -60,7 +75,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest { val prediction3 = model1.predict(testDM, outPutMargin = true) val prediction4 = model2.transform(testDF). - collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap + collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap assert(testDF.count() === prediction4.size) // the vector length in rawPrediction column is 2 since we have to fit to the evaluator in Spark @@ -73,7 +88,9 @@ class XGBoostClassifierSuite extends FunSuite with PerTest { // check the equality of single instance prediction val firstOfDM = testDM.slice(Array(0)) - val firstOfDF = testDF.head().getAs[Vector]("features") + val firstOfDF = testDF.filter(_.getAs[Int]("id") == 0) + .head() + .getAs[Vector]("features") val prediction5 = math.round(model1.predict(firstOfDM)(0)(0)) val prediction6 = model2.predict(firstOfDF) assert(prediction5 === prediction6) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala index 1affe1474..70d2634c4 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala @@ -463,4 +463,42 @@ class XGBoostGeneralSuite extends FunSuite with PerTest { assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(0)) assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(1)) } + + test("infer with different batch sizes") { + val regModel = new XGBoostRegressor(Map( + "eta" -> "1", + "max_depth" -> "6", + "silent" -> "1", + "objective" -> "reg:squarederror", + "num_round" -> 5, + "num_workers" -> numWorkers)) + .fit(buildDataFrame(Regression.train)) + val regDF = buildDataFrame(Regression.test) + + val regRet1 = regModel.transform(regDF).collect() + val regRet2 = regModel.setInferBatchSize(1).transform(regDF).collect() + val regRet3 = regModel.setInferBatchSize(10).transform(regDF).collect() + val regRet4 = regModel.setInferBatchSize(32 << 15).transform(regDF).collect() + assert(regRet1 sameElements regRet2) + assert(regRet1 sameElements regRet3) + assert(regRet1 sameElements regRet4) + + val clsModel = new XGBoostClassifier(Map( + "eta" -> "1", + "max_depth" -> "6", + "silent" -> "1", + "objective" -> "binary:logistic", + "num_round" -> 5, + "num_workers" -> numWorkers)) + .fit(buildDataFrame(Classification.train)) + val clsDF = buildDataFrame(Classification.test) + + val clsRet1 = clsModel.transform(clsDF).collect() + val clsRet2 = clsModel.setInferBatchSize(1).transform(clsDF).collect() + val clsRet3 = clsModel.setInferBatchSize(10).transform(clsDF).collect() + val clsRet4 = clsModel.setInferBatchSize(32 << 15).transform(clsDF).collect() + assert(clsRet1 sameElements clsRet2) + 
assert(clsRet1 sameElements clsRet3) + assert(clsRet1 sameElements clsRet4) + } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala index b21148013..ff88ff328 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala @@ -19,19 +19,34 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost} import org.apache.spark.ml.linalg.Vector import org.apache.spark.sql.functions._ -import org.apache.spark.sql.Row +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types._ import org.scalatest.FunSuite class XGBoostRegressorSuite extends FunSuite with PerTest { - test("XGBoost-Spark XGBoostRegressor ouput should match XGBoost4j: regression") { + test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") { val trainingDM = new DMatrix(Regression.train.iterator) val testDM = new DMatrix(Regression.test.iterator) val trainingDF = buildDataFrame(Regression.train) val testDF = buildDataFrame(Regression.test) - val round = 5 + checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF) + } + test("XGBoostRegressor should make correct predictions after upstream random sort") { + val trainingDM = new DMatrix(Regression.train.iterator) + val testDM = new DMatrix(Regression.test.iterator) + val trainingDF = buildDataFrameWithRandSort(Regression.train) + val testDF = buildDataFrameWithRandSort(Regression.test) + checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF) + } + + private def checkResultsWithXGBoost4j( + trainingDM: DMatrix, + testDM: DMatrix, + trainingDF: DataFrame, + testDF: DataFrame, + round: Int = 5): Unit = { val paramMap = Map( "eta" -> "1", "max_depth" -> "6", @@ -45,7 +60,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest { "num_workers" -> numWorkers)).fit(trainingDF) val prediction2 = model2.transform(testDF). - collect().map(row => (row.getAs[Int]("id"), row.getAs[Double]("prediction"))).toMap + collect().map(row => (row.getAs[Int]("id"), row.getAs[Double]("prediction"))).toMap assert(prediction1.indices.count { i => math.abs(prediction1(i)(0) - prediction2(i)) > 0.01 @@ -54,7 +69,9 @@ class XGBoostRegressorSuite extends FunSuite with PerTest { // check the equality of single instance prediction val firstOfDM = testDM.slice(Array(0)) - val firstOfDF = testDF.head().getAs[Vector]("features") + val firstOfDF = testDF.filter(_.getAs[Int]("id") == 0) + .head() + .getAs[Vector]("features") val prediction3 = model1.predict(firstOfDM)(0)(0) val prediction4 = model2.predict(firstOfDF) assert(math.abs(prediction3 - prediction4) <= 0.01f)
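Usage sketch for the new parameter (illustrative only: `trainingDF` and `testDF` stand
for any DataFrames with the default "features"/"label" columns; only setInferBatchSize
comes from this patch):

    import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

    // inferBatchSize only affects transform(); training is unchanged.
    val model = new XGBoostClassifier(Map(
      "objective" -> "binary:logistic",
      "num_round" -> 5,
      "num_workers" -> 2))
      .fit(trainingDF)

    // Score each partition in groups of 8192 rows instead of materializing
    // the whole partition as one DMatrix; the default is 32 << 10 = 32768.
    val predictions = model
      .setInferBatchSize(8 << 10)
      .transform(testDF)

Smaller batches bound the transient native memory used per task, since each batch builds
and deletes its own DMatrix in the finally block, at the cost of more prediction calls.
Rabit is still initialized once per partition (on the first batch) and shut down only
after the last batch is consumed, matching the iterator logic in transformInternal above.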