[BLOCKING][jvm-packages] fix non-deterministic order within a partition (in the case of an upstream shuffle) on prediction (#4388)

* [jvm-packages][hot-fix] fix column mismatch caused by zip actions at XGBooostModel.transformInternal

* apply minibatch in prediction

* an iterator-compatible minibatch prediction

* regressor impl

* continuous working on mini-batch prediction of xgboost4j-spark

* Update Booster.java
This commit is contained in:
Xu Xiao
2019-04-27 02:09:20 +08:00
committed by Nan Zhu
parent 503cc42f48
commit 2d875ec019
7 changed files with 232 additions and 100 deletions

View File

@@ -19,11 +19,12 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql._
import org.scalatest.{BeforeAndAfterEach, FunSuite}
import scala.util.Random
trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
protected val numWorkers: Int = Runtime.getRuntime.availableProcessors()
@@ -80,6 +81,18 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
.toDF("id", "label", "features")
}
protected def buildDataFrameWithRandSort(
labeledPoints: Seq[XGBLabeledPoint],
numPartitions: Int = numWorkers): DataFrame = {
val df = buildDataFrame(labeledPoints, numPartitions)
val rndSortedRDD = df.rdd.mapPartitions { iter =>
iter.map(_ -> Random.nextDouble()).toList
.sortBy(_._2)
.map(_._1).iterator
}
ss.createDataFrame(rndSortedRDD, df.schema)
}
protected def buildDataFrameWithGroup(
labeledPoints: Seq[XGBLabeledPoint],
numPartitions: Int = numWorkers): DataFrame = {

View File

@@ -27,13 +27,28 @@ import org.apache.spark.Partitioner
class XGBoostClassifierSuite extends FunSuite with PerTest {
test("XGBoost-Spark XGBoostClassifier ouput should match XGBoost4j") {
test("XGBoost-Spark XGBoostClassifier output should match XGBoost4j") {
val trainingDM = new DMatrix(Classification.train.iterator)
val testDM = new DMatrix(Classification.test.iterator)
val trainingDF = buildDataFrame(Classification.train)
val testDF = buildDataFrame(Classification.test)
val round = 5
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
}
test("XGBoostClassifier should make correct predictions after upstream random sort") {
val trainingDM = new DMatrix(Classification.train.iterator)
val testDM = new DMatrix(Classification.test.iterator)
val trainingDF = buildDataFrameWithRandSort(Classification.train)
val testDF = buildDataFrameWithRandSort(Classification.test)
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
}
private def checkResultsWithXGBoost4j(
trainingDM: DMatrix,
testDM: DMatrix,
trainingDF: DataFrame,
testDF: DataFrame,
round: Int = 5): Unit = {
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
@@ -47,7 +62,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
"num_workers" -> numWorkers)).fit(trainingDF)
val prediction2 = model2.transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probability"))).toMap
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probability"))).toMap
assert(testDF.count() === prediction2.size)
// the vector length in probability column is 2 since we have to fit to the evaluator in Spark
@@ -60,7 +75,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
val prediction3 = model1.predict(testDM, outPutMargin = true)
val prediction4 = model2.transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap
assert(testDF.count() === prediction4.size)
// the vector length in rawPrediction column is 2 since we have to fit to the evaluator in Spark
@@ -73,7 +88,9 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
// check the equality of single instance prediction
val firstOfDM = testDM.slice(Array(0))
val firstOfDF = testDF.head().getAs[Vector]("features")
val firstOfDF = testDF.filter(_.getAs[Int]("id") == 0)
.head()
.getAs[Vector]("features")
val prediction5 = math.round(model1.predict(firstOfDM)(0)(0))
val prediction6 = model2.predict(firstOfDF)
assert(prediction5 === prediction6)

View File

@@ -463,4 +463,42 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(0))
assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(1))
}
test("infer with different batch sizes") {
val regModel = new XGBoostRegressor(Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "reg:squarederror",
"num_round" -> 5,
"num_workers" -> numWorkers))
.fit(buildDataFrame(Regression.train))
val regDF = buildDataFrame(Regression.test)
val regRet1 = regModel.transform(regDF).collect()
val regRet2 = regModel.setInferBatchSize(1).transform(regDF).collect()
val regRet3 = regModel.setInferBatchSize(10).transform(regDF).collect()
val regRet4 = regModel.setInferBatchSize(32 << 15).transform(regDF).collect()
assert(regRet1 sameElements regRet2)
assert(regRet1 sameElements regRet3)
assert(regRet1 sameElements regRet4)
val clsModel = new XGBoostClassifier(Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "binary:logistic",
"num_round" -> 5,
"num_workers" -> numWorkers))
.fit(buildDataFrame(Classification.train))
val clsDF = buildDataFrame(Classification.test)
val clsRet1 = clsModel.transform(clsDF).collect()
val clsRet2 = clsModel.setInferBatchSize(1).transform(clsDF).collect()
val clsRet3 = clsModel.setInferBatchSize(10).transform(clsDF).collect()
val clsRet4 = clsModel.setInferBatchSize(32 << 15).transform(clsDF).collect()
assert(clsRet1 sameElements clsRet2)
assert(clsRet1 sameElements clsRet3)
assert(clsRet1 sameElements clsRet4)
}
}

View File

@@ -19,19 +19,34 @@ package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.functions._
import org.apache.spark.sql.Row
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types._
import org.scalatest.FunSuite
class XGBoostRegressorSuite extends FunSuite with PerTest {
test("XGBoost-Spark XGBoostRegressor ouput should match XGBoost4j: regression") {
test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
val trainingDM = new DMatrix(Regression.train.iterator)
val testDM = new DMatrix(Regression.test.iterator)
val trainingDF = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val round = 5
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
}
test("XGBoostRegressor should make correct predictions after upstream random sort") {
val trainingDM = new DMatrix(Regression.train.iterator)
val testDM = new DMatrix(Regression.test.iterator)
val trainingDF = buildDataFrameWithRandSort(Regression.train)
val testDF = buildDataFrameWithRandSort(Regression.test)
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
}
private def checkResultsWithXGBoost4j(
trainingDM: DMatrix,
testDM: DMatrix,
trainingDF: DataFrame,
testDF: DataFrame,
round: Int = 5): Unit = {
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
@@ -45,7 +60,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
"num_workers" -> numWorkers)).fit(trainingDF)
val prediction2 = model2.transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[Double]("prediction"))).toMap
collect().map(row => (row.getAs[Int]("id"), row.getAs[Double]("prediction"))).toMap
assert(prediction1.indices.count { i =>
math.abs(prediction1(i)(0) - prediction2(i)) > 0.01
@@ -54,7 +69,9 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
// check the equality of single instance prediction
val firstOfDM = testDM.slice(Array(0))
val firstOfDF = testDF.head().getAs[Vector]("features")
val firstOfDF = testDF.filter(_.getAs[Int]("id") == 0)
.head()
.getAs[Vector]("features")
val prediction3 = model1.predict(firstOfDM)(0)(0)
val prediction4 = model2.predict(firstOfDF)
assert(math.abs(prediction3 - prediction4) <= 0.01f)