[BLOCKING][jvm-packages] fix non-deterministic order within a partition (in the case of an upstream shuffle) on prediction (#4388)

* [jvm-packages][hot-fix] fix column mismatch caused by zip actions in XGBoostModel.transformInternal

* apply minibatch in prediction

* an iterator-compatible minibatch prediction

* regressor impl

* continued work on mini-batch prediction in xgboost4j-spark

* Update Booster.java
Authored by Xu Xiao 2019-04-27 02:09:20 +08:00; committed by Nan Zhu
parent 503cc42f48
commit 2d875ec019
7 changed files with 232 additions and 100 deletions
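In outline, the bug: transformInternal evaluated the upstream dataset twice — once for the input rows, once for the predictions — and re-paired the two passes positionally with zipPartitions. A minimal standalone sketch of why positional pairing of two passes is unsafe (plain Scala, not code from the patch; the random shuffle stands in for a non-deterministic shuffle fetch order):

// Sketch: two independent enumerations of "the same" rows that disagree on
// within-partition order silently mis-pair under positional zipping.
import scala.util.Random

val rows = Vector("a", "b", "c", "d")

// Pass 1: the rows as the first evaluation enumerates them.
val inputIter = rows.iterator

// Pass 2: a recomputation whose order is not guaranteed to match pass 1.
val predIter = Random.shuffle(rows).iterator.map(r => s"prediction($r)")

// Pairing by position mismatches whenever the two orders diverge,
// e.g. ("a", "prediction(c)").
inputIter.zip(predIter).foreach(println)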

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala

@@ -16,30 +16,26 @@
 package ml.dmlc.xgboost4j.scala.spark
 
-import scala.collection.Iterator
-import scala.collection.JavaConverters._
-import scala.collection.mutable
-
 import ml.dmlc.xgboost4j.java.Rabit
-import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
-import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
 import ml.dmlc.xgboost4j.scala.spark.params._
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait, ObjectiveTrait, XGBoost => SXGBoost}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import org.apache.hadoop.fs.Path
 import org.apache.spark.TaskContext
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.classification._
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.util._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
-import org.apache.spark.sql._
 import org.json4s.DefaultFormats
-import org.apache.spark.broadcast.Broadcast
+
+import scala.collection.JavaConverters._
+import scala.collection.{AbstractIterator, Iterator, mutable}
 
 private[spark] trait XGBoostClassifierParams extends GeneralParams with LearningTaskParams
   with BoosterParams with HasWeightCol with HasBaseMarginCol with HasNumClass with ParamMapFuncs
@@ -216,7 +212,8 @@ class XGBoostClassificationModel private[ml](
     override val numClasses: Int,
     private[spark] val _booster: Booster)
   extends ProbabilisticClassificationModel[Vector, XGBoostClassificationModel]
-    with XGBoostClassifierParams with MLWritable with Serializable {
+    with XGBoostClassifierParams with InferenceParams
+    with MLWritable with Serializable {
 
   import XGBoostClassificationModel._
@@ -250,6 +247,8 @@ class XGBoostClassificationModel private[ml](
 
   def setTreeLimit(value: Int): this.type = set(treeLimit, value)
 
+  def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
+
   /**
    * Single instance prediction.
    * Note: The performance is not ideal, use it carefully!
@@ -287,46 +286,53 @@ class XGBoostClassificationModel private[ml](
     val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
     val appName = dataset.sparkSession.sparkContext.appName
 
-    val inputRDD = dataset.asInstanceOf[Dataset[Row]].rdd
-    val predictionRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
-      if (rowIterator.hasNext) {
-        val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
-        Rabit.init(rabitEnv.asJava)
-        val featuresIterator = rowIterator.map(row => row.getAs[Vector](
-          $(featuresCol))).toList.iterator
-        import DataUtils._
-        val cacheInfo = {
-          if ($(useExternalMemory)) {
-            s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
-          } else {
-            null
-          }
-        }
-        val dm = new DMatrix(
-          XGBoost.processMissingValues(featuresIterator.map(_.asXGB), $(missing)),
-          cacheInfo)
-        try {
-          val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
-            producePredictionItrs(bBooster, dm)
-          Rabit.shutdown()
-          Iterator(rawPredictionItr, probabilityItr, predLeafItr,
-            predContribItr)
-        } finally {
-          dm.delete()
-        }
-      } else {
-        Iterator()
-      }
-    }
-    val resultRDD = inputRDD.zipPartitions(predictionRDD, preservesPartitioning = true) {
-      case (inputIterator, predictionItr) =>
-        if (inputIterator.hasNext) {
-          produceResultIterator(inputIterator, predictionItr.next(), predictionItr.next(),
-            predictionItr.next(), predictionItr.next())
-        } else {
-          Iterator()
-        }
-    }
+    val resultRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+      new AbstractIterator[Row] {
+        private var batchCnt = 0
+
+        private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
+          if (batchCnt == 0) {
+            val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
+            Rabit.init(rabitEnv.asJava)
+          }
+
+          val features = batchRow.iterator.map(row => row.getAs[Vector]($(featuresCol)))
+
+          import DataUtils._
+          val cacheInfo = {
+            if ($(useExternalMemory)) {
+              s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
+                s"${TaskContext.getPartitionId()}-batch-$batchCnt"
+            } else {
+              null
+            }
+          }
+
+          val dm = new DMatrix(
+            XGBoost.processMissingValues(features.map(_.asXGB), $(missing)),
+            cacheInfo)
+          try {
+            val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
+              producePredictionItrs(bBooster, dm)
+            produceResultIterator(batchRow.iterator,
+              rawPredictionItr, probabilityItr, predLeafItr, predContribItr)
+          } finally {
+            batchCnt += 1
+            dm.delete()
+          }
+        }
+
+        override def hasNext: Boolean = batchIterImpl.hasNext
+
+        override def next(): Row = {
+          val ret = batchIterImpl.next()
+          if (!batchIterImpl.hasNext) {
+            Rabit.shutdown()
+          }
+          ret
+        }
+      }
+    }
 
     bBooster.unpersist(blocking = false)
     dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
@@ -527,4 +533,3 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
     }
   }
 }
-
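The regressor below applies the same batching pattern. Distilled out of Spark, the pattern is: group the row iterator into batches, run setup before the first batch, and tear down only once the merged output is fully drained. A self-contained sketch (names are illustrative, not from the patch):

// Batched-iterator pattern: setup on first batch, teardown after last element.
import scala.collection.AbstractIterator

def batchedMap[A, B](input: Iterator[A], batchSize: Int)
    (setup: () => Unit, process: Seq[A] => Iterator[B], teardown: () => Unit): Iterator[B] =
  new AbstractIterator[B] {
    private var batchCnt = 0
    private val impl = input.grouped(batchSize).flatMap { batch =>
      if (batchCnt == 0) setup()   // e.g. Rabit.init before the first batch
      batchCnt += 1
      process(batch)               // e.g. build a DMatrix and predict on it
    }
    override def hasNext: Boolean = impl.hasNext
    override def next(): B = {
      val ret = impl.next()
      if (!impl.hasNext) teardown() // e.g. Rabit.shutdown after the last row
      ret
    }
  }

// Usage: square numbers in batches of 3; setup/teardown fire exactly once.
val out = batchedMap((1 to 10).iterator, 3)(
  () => println("setup"),
  batch => batch.iterator.map(x => x * x),
  () => println("teardown"))
println(out.toList)

Because each batch's predictions are joined back to the very rows of that batch inside one pass, no positional re-pairing across separately computed RDDs is needed, which is what makes the result insensitive to within-partition ordering.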

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala

@@ -16,10 +16,10 @@
 package ml.dmlc.xgboost4j.scala.spark
 
-import scala.collection.Iterator
+import scala.collection.{AbstractIterator, Iterator, mutable}
 import scala.collection.JavaConverters._
 
-import ml.dmlc.xgboost4j.java.Rabit
+import ml.dmlc.xgboost4j.java.{Rabit, XGBoost => JXGBoost}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
@@ -37,7 +37,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.json4s.DefaultFormats
-import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
 import org.apache.spark.broadcast.Broadcast
@@ -207,7 +207,8 @@ class XGBoostRegressionModel private[ml] (
     override val uid: String,
     private[spark] val _booster: Booster)
   extends PredictionModel[Vector, XGBoostRegressionModel]
-    with XGBoostRegressorParams with MLWritable with Serializable {
+    with XGBoostRegressorParams with InferenceParams
+    with MLWritable with Serializable {
 
   import XGBoostRegressionModel._
@@ -241,6 +242,8 @@ class XGBoostRegressionModel private[ml] (
 
   def setTreeLimit(value: Int): this.type = set(treeLimit, value)
 
+  def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
+
   /**
    * Single instance prediction.
    * Note: The performance is not ideal, use it carefully!
@@ -259,45 +262,53 @@ class XGBoostRegressionModel private[ml] (
     val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
     val appName = dataset.sparkSession.sparkContext.appName
 
-    val inputRDD = dataset.asInstanceOf[Dataset[Row]].rdd
-    val predictionRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
-      if (rowIterator.hasNext) {
-        val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
-        Rabit.init(rabitEnv.asJava)
-        val featuresIterator = rowIterator.map(row => row.getAs[Vector](
-          $(featuresCol))).toList.iterator
-        import DataUtils._
-        val cacheInfo = {
-          if ($(useExternalMemory)) {
-            s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
-          } else {
-            null
-          }
-        }
-        val dm = new DMatrix(
-          XGBoost.processMissingValues(featuresIterator.map(_.asXGB), $(missing)),
-          cacheInfo)
-        try {
-          val Array(originalPredictionItr, predLeafItr, predContribItr) =
-            producePredictionItrs(bBooster, dm)
-          Rabit.shutdown()
-          Iterator(originalPredictionItr, predLeafItr, predContribItr)
-        } finally {
-          dm.delete()
-        }
-      } else {
-        Iterator()
-      }
-    }
-    val resultRDD = inputRDD.zipPartitions(predictionRDD, preservesPartitioning = true) {
-      case (inputIterator, predictionItr) =>
-        if (inputIterator.hasNext) {
-          produceResultIterator(inputIterator, predictionItr.next(), predictionItr.next(),
-            predictionItr.next())
-        } else {
-          Iterator()
-        }
-    }
+    val resultRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+      new AbstractIterator[Row] {
+        private var batchCnt = 0
+
+        private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
+          if (batchCnt == 0) {
+            val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
+            Rabit.init(rabitEnv.asJava)
+          }
+
+          val features = batchRow.iterator.map(row => row.getAs[Vector]($(featuresCol)))
+
+          import DataUtils._
+          val cacheInfo = {
+            if ($(useExternalMemory)) {
+              s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
+                s"${TaskContext.getPartitionId()}-batch-$batchCnt"
+            } else {
+              null
+            }
+          }
+
+          val dm = new DMatrix(
+            XGBoost.processMissingValues(features.map(_.asXGB), $(missing)),
+            cacheInfo)
+          try {
+            val Array(rawPredictionItr, predLeafItr, predContribItr) =
+              producePredictionItrs(bBooster, dm)
+            produceResultIterator(batchRow.iterator, rawPredictionItr, predLeafItr, predContribItr)
+          } finally {
+            batchCnt += 1
+            dm.delete()
+          }
+        }
+
+        override def hasNext: Boolean = batchIterImpl.hasNext
+
+        override def next(): Row = {
+          val ret = batchIterImpl.next()
+          if (!batchIterImpl.hasNext) {
+            Rabit.shutdown()
+          }
+          ret
+        }
+      }
+    }
 
     bBooster.unpersist(blocking = false)
     dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
   }
@@ -347,14 +358,14 @@ class XGBoostRegressionModel private[ml] (
     resultSchema
   }
 
-  private def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
+  private def producePredictionItrs(booster: Broadcast[Booster], dm: DMatrix):
       Array[Iterator[Row]] = {
     val originalPredictionItr = {
-      broadcastBooster.value.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator
+      booster.value.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator
     }
     val predLeafItr = {
       if (isDefined(leafPredictionCol)) {
-        broadcastBooster.value.predictLeaf(dm, $(treeLimit)).
+        booster.value.predictLeaf(dm, $(treeLimit)).
           map(Row(_)).iterator
       } else {
         Iterator()
@@ -362,7 +373,7 @@ class XGBoostRegressionModel private[ml] (
     }
     val predContribItr = {
       if (isDefined(contribPredictionCol)) {
-        broadcastBooster.value.predictContrib(dm, $(treeLimit)).
+        booster.value.predictContrib(dm, $(treeLimit)).
           map(Row(_)).iterator
       } else {
         Iterator()
@@ -373,7 +384,6 @@ class XGBoostRegressionModel private[ml] (
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-
     // Output selected columns only.
     // This is a bit complicated since it tries to avoid repeated computation.
    var outputData = transformInternal(dataset)

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/InferenceParams.scala

@@ -0,0 +1,32 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark.params
+
+import org.apache.spark.ml.param.{IntParam, Params}
+
+private[spark] trait InferenceParams extends Params {
+
+  /**
+   * batch size of inference iteration
+   */
+  final val inferBatchSize = new IntParam(this, "batchSize", "batch size of inference iteration")
+
+  /** @group getParam */
+  final def getInferBatchSize: Int = $(inferBatchSize)
+
+  setDefault(inferBatchSize, 32 << 10)
+}
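A hedged usage sketch of the new parameter (`trainDF` and `testDF` are placeholders, not from the patch; the default of 32 << 10 = 32768 rows per batch comes from the setDefault above):

// Usage sketch: tune the per-batch row count on a fitted model.
// Smaller batches lower peak memory per partition; larger batches
// amortize the per-DMatrix construction overhead.
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

val model = new XGBoostClassifier(Map(
  "objective" -> "binary:logistic",
  "num_round" -> 5,
  "num_workers" -> 2)).fit(trainDF)

val predictions = model
  .setInferBatchSize(8 << 10) // predict 8192 rows per DMatrix
  .transform(testDF)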

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala

@@ -19,11 +19,12 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.io.File
 
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.sql._
 import org.scalatest.{BeforeAndAfterEach, FunSuite}
+import scala.util.Random
 
 trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
 
   protected val numWorkers: Int = Runtime.getRuntime.availableProcessors()
@@ -80,6 +81,18 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
       .toDF("id", "label", "features")
   }
 
+  protected def buildDataFrameWithRandSort(
+      labeledPoints: Seq[XGBLabeledPoint],
+      numPartitions: Int = numWorkers): DataFrame = {
+    val df = buildDataFrame(labeledPoints, numPartitions)
+    val rndSortedRDD = df.rdd.mapPartitions { iter =>
+      iter.map(_ -> Random.nextDouble()).toList
+        .sortBy(_._2)
+        .map(_._1).iterator
+    }
+    ss.createDataFrame(rndSortedRDD, df.schema)
+  }
+
   protected def buildDataFrameWithGroup(
       labeledPoints: Seq[XGBLabeledPoint],
       numPartitions: Int = numWorkers): DataFrame = {

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala

@@ -27,13 +27,28 @@ import org.apache.spark.Partitioner
 
 class XGBoostClassifierSuite extends FunSuite with PerTest {
 
-  test("XGBoost-Spark XGBoostClassifier ouput should match XGBoost4j") {
+  test("XGBoost-Spark XGBoostClassifier output should match XGBoost4j") {
     val trainingDM = new DMatrix(Classification.train.iterator)
     val testDM = new DMatrix(Classification.test.iterator)
     val trainingDF = buildDataFrame(Classification.train)
     val testDF = buildDataFrame(Classification.test)
-    val round = 5
+    checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
+  }
+
+  test("XGBoostClassifier should make correct predictions after upstream random sort") {
+    val trainingDM = new DMatrix(Classification.train.iterator)
+    val testDM = new DMatrix(Classification.test.iterator)
+    val trainingDF = buildDataFrameWithRandSort(Classification.train)
+    val testDF = buildDataFrameWithRandSort(Classification.test)
+    checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
+  }
+
+  private def checkResultsWithXGBoost4j(
+      trainingDM: DMatrix,
+      testDM: DMatrix,
+      trainingDF: DataFrame,
+      testDF: DataFrame,
+      round: Int = 5): Unit = {
     val paramMap = Map(
       "eta" -> "1",
       "max_depth" -> "6",
@@ -47,7 +62,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
       "num_workers" -> numWorkers)).fit(trainingDF)
 
     val prediction2 = model2.transform(testDF).
       collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probability"))).toMap
 
     assert(testDF.count() === prediction2.size)
     // the vector length in probability column is 2 since we have to fit to the evaluator in Spark
@@ -60,7 +75,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
     val prediction3 = model1.predict(testDM, outPutMargin = true)
 
     val prediction4 = model2.transform(testDF).
       collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap
 
     assert(testDF.count() === prediction4.size)
     // the vector length in rawPrediction column is 2 since we have to fit to the evaluator in Spark
@@ -73,7 +88,9 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
 
     // check the equality of single instance prediction
     val firstOfDM = testDM.slice(Array(0))
-    val firstOfDF = testDF.head().getAs[Vector]("features")
+    val firstOfDF = testDF.filter(_.getAs[Int]("id") == 0)
+      .head()
+      .getAs[Vector]("features")
     val prediction5 = math.round(model1.predict(firstOfDM)(0)(0))
     val prediction6 = model2.predict(firstOfDF)
     assert(prediction5 === prediction6)

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala

@@ -463,4 +463,42 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(0))
     assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(1))
   }
+
+  test("infer with different batch sizes") {
+    val regModel = new XGBoostRegressor(Map(
+      "eta" -> "1",
+      "max_depth" -> "6",
+      "silent" -> "1",
+      "objective" -> "reg:squarederror",
+      "num_round" -> 5,
+      "num_workers" -> numWorkers))
+      .fit(buildDataFrame(Regression.train))
+    val regDF = buildDataFrame(Regression.test)
+
+    val regRet1 = regModel.transform(regDF).collect()
+    val regRet2 = regModel.setInferBatchSize(1).transform(regDF).collect()
+    val regRet3 = regModel.setInferBatchSize(10).transform(regDF).collect()
+    val regRet4 = regModel.setInferBatchSize(32 << 15).transform(regDF).collect()
+    assert(regRet1 sameElements regRet2)
+    assert(regRet1 sameElements regRet3)
+    assert(regRet1 sameElements regRet4)
+
+    val clsModel = new XGBoostClassifier(Map(
+      "eta" -> "1",
+      "max_depth" -> "6",
+      "silent" -> "1",
+      "objective" -> "binary:logistic",
+      "num_round" -> 5,
+      "num_workers" -> numWorkers))
+      .fit(buildDataFrame(Classification.train))
+    val clsDF = buildDataFrame(Classification.test)
+
+    val clsRet1 = clsModel.transform(clsDF).collect()
+    val clsRet2 = clsModel.setInferBatchSize(1).transform(clsDF).collect()
+    val clsRet3 = clsModel.setInferBatchSize(10).transform(clsDF).collect()
+    val clsRet4 = clsModel.setInferBatchSize(32 << 15).transform(clsDF).collect()
+    assert(clsRet1 sameElements clsRet2)
+    assert(clsRet1 sameElements clsRet3)
+    assert(clsRet1 sameElements clsRet4)
+  }
 }

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala

@@ -19,19 +19,34 @@ package ml.dmlc.xgboost4j.scala.spark
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.types._
 import org.scalatest.FunSuite
 
 class XGBoostRegressorSuite extends FunSuite with PerTest {
 
-  test("XGBoost-Spark XGBoostRegressor ouput should match XGBoost4j: regression") {
+  test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
     val trainingDM = new DMatrix(Regression.train.iterator)
     val testDM = new DMatrix(Regression.test.iterator)
     val trainingDF = buildDataFrame(Regression.train)
     val testDF = buildDataFrame(Regression.test)
-    val round = 5
+    checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
+  }
+
+  test("XGBoostRegressor should make correct predictions after upstream random sort") {
+    val trainingDM = new DMatrix(Regression.train.iterator)
+    val testDM = new DMatrix(Regression.test.iterator)
+    val trainingDF = buildDataFrameWithRandSort(Regression.train)
+    val testDF = buildDataFrameWithRandSort(Regression.test)
+    checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
+  }
+
+  private def checkResultsWithXGBoost4j(
+      trainingDM: DMatrix,
+      testDM: DMatrix,
+      trainingDF: DataFrame,
+      testDF: DataFrame,
+      round: Int = 5): Unit = {
     val paramMap = Map(
       "eta" -> "1",
       "max_depth" -> "6",
@@ -45,7 +60,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
       "num_workers" -> numWorkers)).fit(trainingDF)
 
     val prediction2 = model2.transform(testDF).
       collect().map(row => (row.getAs[Int]("id"), row.getAs[Double]("prediction"))).toMap
 
     assert(prediction1.indices.count { i =>
       math.abs(prediction1(i)(0) - prediction2(i)) > 0.01
@@ -54,7 +69,9 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
 
     // check the equality of single instance prediction
     val firstOfDM = testDM.slice(Array(0))
-    val firstOfDF = testDF.head().getAs[Vector]("features")
+    val firstOfDF = testDF.filter(_.getAs[Int]("id") == 0)
+      .head()
+      .getAs[Vector]("features")
     val prediction3 = model1.predict(firstOfDM)(0)(0)
     val prediction4 = model2.predict(firstOfDF)
     assert(math.abs(prediction3 - prediction4) <= 0.01f)