improve test of save/load model (#1515)

This commit is contained in:
Nan Zhu 2016-08-27 17:16:22 -04:00 committed by GitHub
parent 53ce511be3
commit d7f79255ec

View File

@ -150,11 +150,13 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0", val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap "objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers) val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
assert(eval.eval(xgBoostModel.predict(testSetDMatrix), testSetDMatrix) < 0.1) val evalResults = eval.eval(xgBoostModel.predict(testSetDMatrix), testSetDMatrix)
assert(evalResults < 0.1)
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath) xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
val loadedXGBooostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath) val loadedXGBooostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
val predicts = loadedXGBooostModel.predict(testSetDMatrix) val predicts = loadedXGBooostModel.predict(testSetDMatrix)
assert(eval.eval(predicts, testSetDMatrix) < 0.1) val loadedEvalResults = eval.eval(predicts, testSetDMatrix)
assert(loadedEvalResults == evalResults)
} }
test("nthread configuration must be equal to spark.task.cpus") { test("nthread configuration must be equal to spark.task.cpus") {