diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/DistTrainWithSpark.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/DistTrainWithSpark.scala
index 6cd418fde..b96089e42 100644
--- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/DistTrainWithSpark.scala
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/DistTrainWithSpark.scala
@@ -47,7 +47,7 @@ object DistTrainWithSpark {
       "objective" -> "binary:logistic").toMap
     val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
       useExternalMemory = true)
-    xgboostModel.predict(new DMatrix(testSet))
+    xgboostModel.booster.predict(new DMatrix(testSet))
     // save model to HDFS path
     xgboostModel.saveModelAsHadoopFile(outputModelPath)
   }
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
index 6a692ffb3..597f08031 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
@@ -122,15 +122,6 @@ class XGBoostModel(_booster: Booster) extends Serializable {
     }
   }
 
-  /**
-   * Predict result with the given test set (represented as DMatrix)
-   *
-   * @param testSet test set represented as DMatrix
-   */
-  def predict(testSet: DMatrix): Array[Array[Float]] = {
-    _booster.predict(testSet)
-  }
-
   /**
    * Predict leaf instances with the given test set (represented as RDD)
    *
@@ -149,15 +140,6 @@ class XGBoostModel(_booster: Booster) extends Serializable {
     }
   }
 
-  /**
-   * Predict leaf instances with the given test set (represented as DMatrix)
-   *
-   * @param testSet test set represented as DMatrix
-   */
-  def predictLeaves(testSet: DMatrix): Array[Array[Float]] = {
-    _booster.predictLeaf(testSet, 0)
-  }
-
   /**
    * Save the model as to HDFS-compatible file system.
    *
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
index ba16e2f41..2b6131546 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
@@ -125,7 +125,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
     val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
     val boosterRDD = XGBoost.buildDistributedBoosters(
       trainingRDD,
-      List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
+      List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
         "objective" -> "binary:logistic").toMap,
       new scala.collection.mutable.HashMap[String, String],
       numWorkers = 2, round = 5, null, null, useExternalMemory = false)
@@ -134,8 +134,9 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
     val boosters = boosterRDD.collect()
     val eval = new EvalError()
     for (booster <- boosters) {
+      // the threshold is 0.11 because it does not sync boosters with AllReduce
       val predicts = booster.predict(testSetDMatrix, outPutMargin = true)
-      assert(eval.eval(predicts, testSetDMatrix) < 0.17)
+      assert(eval.eval(predicts, testSetDMatrix) < 0.11)
     }
   }
 
@@ -211,7 +212,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
     val predRDD = xgBoostModel.predict(testRDD)
     val predResult1 = predRDD.collect()(0)
     import DataUtils._
-    val predResult2 = xgBoostModel.predict(new DMatrix(testSet.iterator))
+    val predResult2 = xgBoostModel.booster.predict(new DMatrix(testSet.iterator))
     for (i <- predResult1.indices; j <- predResult1(i).indices) {
       assert(predResult1(i)(j) === predResult2(i)(j))
     }
@@ -222,7 +223,6 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
       val sampleList = new ListBuffer[SparkVector]
      sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
     }
 
-    val trainingRDD = buildTrainingRDD()
     val testRDD = buildEmptyRDD()
     import DataUtils._
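With the DMatrix overloads of predict and predictLeaves removed from XGBoostModel, local (non-RDD) prediction goes through the Booster the model exposes, as the updated call sites above show. A minimal migration sketch, assuming only the signatures visible in this patch; xgboostModel and testSet are placeholder names, not identifiers from this repository:

import ml.dmlc.xgboost4j.scala.DMatrix

// Before this change the model wrapped the booster calls itself:
//   val preds  = xgboostModel.predict(new DMatrix(testSet))
//   val leaves = xgboostModel.predictLeaves(new DMatrix(testSet))

// After this change, delegate to the exposed Booster directly. The second
// argument to predictLeaf is the tree limit (0 = use all trees), matching
// the body of the removed predictLeaves method.
val preds: Array[Array[Float]] = xgboostModel.booster.predict(new DMatrix(testSet))
val leaves: Array[Array[Float]] = xgboostModel.booster.predictLeaf(new DMatrix(testSet), 0)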