[jvm-packages] remove APIs with DMatrix from xgboost-spark (#1519)
* test consistency of prediction functions between DMatrix and RDD
* remove APIs with DMatrix from xgboost-spark
* fix compilation error in xgboost4j-example
* fix test cases
This commit is contained in:
parent
6d65aae091
commit
74db1e8867
@ -47,7 +47,7 @@ object DistTrainWithSpark {
|
||||
"objective" -> "binary:logistic").toMap
|
||||
val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
|
||||
useExternalMemory = true)
|
||||
xgboostModel.predict(new DMatrix(testSet))
|
||||
xgboostModel.booster.predict(new DMatrix(testSet))
|
||||
// save model to HDFS path
|
||||
xgboostModel.saveModelAsHadoopFile(outputModelPath)
|
||||
}
|
||||
|
||||
@ -122,15 +122,6 @@ class XGBoostModel(_booster: Booster) extends Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict result with the given test set (represented as DMatrix)
|
||||
*
|
||||
* @param testSet test set represented as DMatrix
|
||||
*/
|
||||
def predict(testSet: DMatrix): Array[Array[Float]] = {
|
||||
_booster.predict(testSet)
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict leaf instances with the given test set (represented as RDD)
|
||||
*
|
||||
@ -149,15 +140,6 @@ class XGBoostModel(_booster: Booster) extends Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict leaf instances with the given test set (represented as DMatrix)
|
||||
*
|
||||
* @param testSet test set represented as DMatrix
|
||||
*/
|
||||
def predictLeaves(testSet: DMatrix): Array[Array[Float]] = {
|
||||
_booster.predictLeaf(testSet, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the model to an HDFS-compatible file system.
|
||||
*
|
||||
|
||||
@ -125,7 +125,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
|
||||
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||
val boosterRDD = XGBoost.buildDistributedBoosters(
|
||||
trainingRDD,
|
||||
List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
||||
List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
|
||||
"objective" -> "binary:logistic").toMap,
|
||||
new scala.collection.mutable.HashMap[String, String],
|
||||
numWorkers = 2, round = 5, null, null, useExternalMemory = false)
|
||||
@ -134,8 +134,9 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
|
||||
val boosters = boosterRDD.collect()
|
||||
val eval = new EvalError()
|
||||
for (booster <- boosters) {
|
||||
// the threshold is 0.11 because it does not sync boosters with AllReduce
|
||||
val predicts = booster.predict(testSetDMatrix, outPutMargin = true)
|
||||
assert(eval.eval(predicts, testSetDMatrix) < 0.17)
|
||||
assert(eval.eval(predicts, testSetDMatrix) < 0.11)
|
||||
}
|
||||
}
|
||||
|
||||
@ -211,7 +212,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
|
||||
val predRDD = xgBoostModel.predict(testRDD)
|
||||
val predResult1 = predRDD.collect()(0)
|
||||
import DataUtils._
|
||||
val predResult2 = xgBoostModel.predict(new DMatrix(testSet.iterator))
|
||||
val predResult2 = xgBoostModel.booster.predict(new DMatrix(testSet.iterator))
|
||||
for (i <- predResult1.indices; j <- predResult1(i).indices) {
|
||||
assert(predResult1(i)(j) === predResult2(i)(j))
|
||||
}
|
||||
@ -222,7 +223,6 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
|
||||
val sampleList = new ListBuffer[SparkVector]
|
||||
sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
|
||||
}
|
||||
|
||||
val trainingRDD = buildTrainingRDD()
|
||||
val testRDD = buildEmptyRDD()
|
||||
import DataUtils._
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user