[jvm-packages] test consistency of prediction functions with DMatrix and RDD (#1518)
* test consistency of prediction functions between DMatrix and RDD
* fix the failed test cases
parent d7f79255ec
commit 6d65aae091
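In summary, the commit changes `XGBoostModel.predict(testSet: DMatrix)` to return transformed predictions instead of raw margins, and adds suite tests asserting that the RDD-based and DMatrix-based prediction paths agree. The sketch below restates the property under test using names from the diff; it is illustrative, not a verbatim excerpt from the commit:

    // Illustrative restatement of the consistency property (names taken from
    // the diff below): distributed prediction over an RDD of feature vectors
    // must agree elementwise with local prediction over a DMatrix built from
    // the same rows.
    val fromRDD = xgBoostModel.predict(testRDD).collect()(0)
    val fromDMatrix = xgBoostModel.predict(new DMatrix(testSet.iterator))
    for (i <- fromRDD.indices; j <- fromRDD(i).indices) {
      assert(fromRDD(i)(j) === fromDMatrix(i)(j))
    }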
@@ -128,7 +128,7 @@ class XGBoostModel(_booster: Booster) extends Serializable {
    * @param testSet test set represented as DMatrix
    */
   def predict(testSet: DMatrix): Array[Array[Float]] = {
-    _booster.predict(testSet, true, 0)
+    _booster.predict(testSet)
   }
 
   /**
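The removed call forced `outPutMargin = true` (and `treeLimit = 0`), so `XGBoostModel.predict` leaked raw margins from a method whose callers expect transformed predictions. For `binary:logistic` the two outputs are related by the sigmoid; a minimal sketch of the distinction, assuming only the two `Booster.predict` overloads that appear in this diff:

    // Sketch only; `booster` is an xgboost4j Booster as in the diff.
    // Default overload: transformed predictions (probabilities for binary:logistic).
    val probs = booster.predict(testSetDMatrix)
    // Overload with outPutMargin = true: raw, untransformed margins.
    val margins = booster.predict(testSetDMatrix, outPutMargin = true)
    for ((p, m) <- probs.flatten.zip(margins.flatten)) {
      // For binary:logistic, probability = sigmoid(margin).
      assert(math.abs(p - 1.0f / (1.0f + math.exp(-m).toFloat)) < 1e-5f)
    }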
@@ -128,18 +128,113 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
       List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
         "objective" -> "binary:logistic").toMap,
       new scala.collection.mutable.HashMap[String, String],
-      numWorkers = 2, round = 5, null, null, false)
+      numWorkers = 2, round = 5, null, null, useExternalMemory = false)
     val boosterCount = boosterRDD.count()
     assert(boosterCount === 2)
     val boosters = boosterRDD.collect()
+    val eval = new EvalError()
     for (booster <- boosters) {
-      val predicts = booster.predict(testSetDMatrix, true)
-      assert(new EvalError().eval(predicts, testSetDMatrix) < 0.17)
+      val predicts = booster.predict(testSetDMatrix, outPutMargin = true)
+      assert(eval.eval(predicts, testSetDMatrix) < 0.17)
     }
   }
 
+  test("training with external memory cache") {
+    sc.stop()
+    sc = null
+    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
+    val customSparkContext = new SparkContext(sparkConf)
+    val eval = new EvalError()
+    val trainingRDD = buildTrainingRDD(Some(customSparkContext))
+    val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
+    import DataUtils._
+    val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
+    val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
+      "objective" -> "binary:logistic").toMap
+    val xgBoostModel = XGBoost.train(trainingRDD, paramMap, round = 5,
+      nWorkers = numWorkers, useExternalMemory = true)
+    assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
+      testSetDMatrix) < 0.1)
+    customSparkContext.stop()
+    // clean
+    val dir = new File(".")
+    for (file <- dir.listFiles() if file.getName.startsWith("XGBoostSuite-0-dtrain_cache")) {
+      file.delete()
+    }
+  }
+
-  test("save and load model") {
+  test("test with dense vectors containing missing value") {
+    def buildDenseRDD(): RDD[LabeledPoint] = {
+      val nrow = 100
+      val ncol = 5
+      val data0 = Array.ofDim[Double](nrow, ncol)
+      // put random nums
+      for (r <- 0 until nrow; c <- 0 until ncol) {
+        data0(r)(c) = {
+          if (c == ncol - 1) {
+            -0.1
+          } else {
+            Random.nextDouble()
+          }
+        }
+      }
+      // create label
+      val label0 = new Array[Double](nrow)
+      for (i <- label0.indices) {
+        label0(i) = Random.nextDouble()
+      }
+      val points = new ListBuffer[LabeledPoint]
+      for (r <- 0 until nrow) {
+        points += LabeledPoint(label0(r), Vectors.dense(data0(r)))
+      }
+      sc.parallelize(points)
+    }
+    val trainingRDD = buildDenseRDD().repartition(4)
+    val testRDD = buildDenseRDD().repartition(4)
+    val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
+      "objective" -> "binary:logistic").toMap
+    val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
+    xgBoostModel.predict(testRDD.map(_.features.toDense), missingValue = -0.1f).collect()
+  }
+
+  test("test consistency of prediction functions with RDD") {
+    val trainingRDD = buildTrainingRDD()
+    val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile)
+    val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)
+    val testCollection = testRDD.collect()
+    for (i <- testSet.indices) {
+      assert(testCollection(i).toDense.values.sameElements(testSet(i).features.toDense.values))
+    }
+    val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
+      "objective" -> "binary:logistic").toMap
+    val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
+    val predRDD = xgBoostModel.predict(testRDD)
+    val predResult1 = predRDD.collect()(0)
+    import DataUtils._
+    val predResult2 = xgBoostModel.predict(new DMatrix(testSet.iterator))
+    for (i <- predResult1.indices; j <- predResult1(i).indices) {
+      assert(predResult1(i)(j) === predResult2(i)(j))
+    }
+  }
+
+  test("test prediction functionality with empty partition") {
+    def buildEmptyRDD(sparkContext: Option[SparkContext] = None): RDD[SparkVector] = {
+      val sampleList = new ListBuffer[SparkVector]
+      sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
+    }
+
+    val trainingRDD = buildTrainingRDD()
+    val testRDD = buildEmptyRDD()
+    import DataUtils._
+    val tempDir = Files.createTempDirectory("xgboosttest-")
+    val tempFile = Files.createTempFile(tempDir, "", "")
+    val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
+      "objective" -> "binary:logistic").toMap
+    val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
+    println(xgBoostModel.predict(testRDD).collect().length === 0)
+  }
+
+  test("test model consistency after save and load") {
     val eval = new EvalError()
     val trainingRDD = buildTrainingRDD()
     val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
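The new consistency test above compares the two prediction paths with exact float equality, which holds because both paths hand identical rows to the same native predictor. As an aside (not part of the commit), a tolerance-based helper is the safer pattern if the paths ever diverge in row ordering or float rounding; `assertPredictionsMatch` below is a hypothetical name:

    // Hypothetical helper, not from the commit: elementwise comparison of two
    // prediction matrices with an absolute tolerance.
    def assertPredictionsMatch(
        a: Array[Array[Float]],
        b: Array[Array[Float]],
        eps: Float = 1e-6f): Unit = {
      require(a.length == b.length, "row counts differ")
      for ((rowA, rowB) <- a.zip(b); (x, y) <- rowA.zip(rowB)) {
        assert(math.abs(x - y) <= eps, s"prediction mismatch: $x vs $y")
      }
    }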
@@ -150,11 +245,12 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
     val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
       "objective" -> "binary:logistic").toMap
     val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
-    val evalResults = eval.eval(xgBoostModel.predict(testSetDMatrix), testSetDMatrix)
+    val evalResults = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
+      testSetDMatrix)
     assert(evalResults < 0.1)
     xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
     val loadedXGBooostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
-    val predicts = loadedXGBooostModel.predict(testSetDMatrix)
+    val predicts = loadedXGBooostModel.booster.predict(testSetDMatrix, outPutMargin = true)
     val loadedEvalResults = eval.eval(predicts, testSetDMatrix)
     assert(loadedEvalResults == evalResults)
   }
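The renamed save/load test now scores both the original and the reloaded model via `booster.predict(testSetDMatrix, outPutMargin = true)` and requires the two eval results to be exactly equal, i.e. a Hadoop-file round trip must not perturb the booster. A minimal usage sketch of the same API, with a placeholder path:

    // Usage sketch assuming the API shown in this diff; "/tmp/xgb-model" is a
    // placeholder path.
    val model = XGBoost.train(trainingRDD, paramMap, round = 5, nWorkers = numWorkers)
    model.saveModelAsHadoopFile("/tmp/xgb-model")
    val restored = XGBoost.loadModelFromHadoopFile("/tmp/xgb-model")
    // Margins from the restored booster should be bit-identical to the original's.
    val before = model.booster.predict(testSetDMatrix, outPutMargin = true)
    val after = restored.booster.predict(testSetDMatrix, outPutMargin = true)
    assert(before.flatten.sameElements(after.flatten))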
@@ -190,83 +286,8 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
     val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
       "objective" -> "binary:logistic").toMap
     val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
-    assert(eval.eval(xgBoostModel.predict(testSetDMatrix), testSetDMatrix) < 0.1)
+    assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
+      testSetDMatrix) < 0.1)
     customSparkContext.stop()
   }
 
-  test("test with empty partition") {
-
-    def buildEmptyRDD(sparkContext: Option[SparkContext] = None): RDD[SparkVector] = {
-      val sampleList = new ListBuffer[SparkVector]
-      sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
-    }
-
-    val eval = new EvalError()
-    val trainingRDD = buildTrainingRDD()
-    val testRDD = buildEmptyRDD()
-    import DataUtils._
-    val tempDir = Files.createTempDirectory("xgboosttest-")
-    val tempFile = Files.createTempFile(tempDir, "", "")
-    val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
-      "objective" -> "binary:logistic").toMap
-    val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
-
-    println(xgBoostModel.predict(testRDD).collect())
-  }
-
-  test("test with dense vectors containing missing value") {
-    def buildDenseRDD(): RDD[LabeledPoint] = {
-      val nrow = 100
-      val ncol = 5
-      val data0 = Array.ofDim[Double](nrow, ncol)
-      // put random nums
-      for (r <- 0 until nrow; c <- 0 until ncol) {
-        data0(r)(c) = {
-          if (c == ncol - 1) {
-            -0.1
-          } else {
-            Random.nextDouble()
-          }
-        }
-      }
-      // create label
-      val label0 = new Array[Double](nrow)
-      for (i <- label0.indices) {
-        label0(i) = Random.nextDouble()
-      }
-      val points = new ListBuffer[LabeledPoint]
-      for (r <- 0 until nrow) {
-        points += LabeledPoint(label0(r), Vectors.dense(data0(r)))
-      }
-      sc.parallelize(points)
-    }
-    val trainingRDD = buildDenseRDD().repartition(4)
-    val testRDD = buildDenseRDD().repartition(4)
-    val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
-      "objective" -> "binary:logistic").toMap
-    val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, 4)
-    xgBoostModel.predict(testRDD.map(_.features.toDense), missingValue = -0.1f).collect()
-  }
-
-  test("training with external memory cache") {
-    sc.stop()
-    sc = null
-    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
-    val customSparkContext = new SparkContext(sparkConf)
-    val eval = new EvalError()
-    val trainingRDD = buildTrainingRDD(Some(customSparkContext))
-    val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
-    import DataUtils._
-    val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
-    val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
-      "objective" -> "binary:logistic").toMap
-    val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers, useExternalMemory = true)
-    assert(eval.eval(xgBoostModel.predict(testSetDMatrix), testSetDMatrix) < 0.1)
-    customSparkContext.stop()
-    // clean
-    val dir = new File(".")
-    for (file <- dir.listFiles() if file.getName.startsWith("XGBoostSuite-dtrain_cache")) {
-      file.delete()
-    }
-  }
 }
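Across the suite, every assertion that feeds `eval.eval` switched from `xgBoostModel.predict(testSetDMatrix)` to `xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true)`, presumably because `EvalError` thresholds raw margins at zero; probabilities in [0, 1] would land on one side of that threshold almost everywhere. A hypothetical sketch of such a metric (the real `EvalError` lives elsewhere in the test sources and may differ):

    // Hypothetical margin-thresholded error metric; illustrative only.
    def marginErrorRate(predicts: Array[Array[Float]], labels: Array[Float]): Float = {
      require(predicts.length == labels.length)
      val wrong = predicts.zip(labels).count { case (p, y) =>
        val predicted = if (p(0) > 0.0f) 1.0f else 0.0f // margin > 0 => positive class
        predicted != y
      }
      wrong.toFloat / labels.length
    }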