From d7f79255ec68718c272f6f7af1ca207cde8f821b Mon Sep 17 00:00:00 2001 From: Nan Zhu Date: Sat, 27 Aug 2016 17:16:22 -0400 Subject: [PATCH] improve test of save/load model (#1515) --- .../scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala index d17c25b2e..9de130c72 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala @@ -150,11 +150,13 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter { val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0", "objective" -> "binary:logistic").toMap val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers) - assert(eval.eval(xgBoostModel.predict(testSetDMatrix), testSetDMatrix) < 0.1) + val evalResults = eval.eval(xgBoostModel.predict(testSetDMatrix), testSetDMatrix) + assert(evalResults < 0.1) xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath) val loadedXGBooostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath) val predicts = loadedXGBooostModel.predict(testSetDMatrix) - assert(eval.eval(predicts, testSetDMatrix) < 0.1) + val loadedEvalResults = eval.eval(predicts, testSetDMatrix) + assert(loadedEvalResults == evalResults) } test("nthread configuration must be equal to spark.task.cpus") {