[jvm-packages] remove APIs with DMatrix from xgboost-spark (#1519)
* test consistency of prediction functions between DMatrix and RDD * remove APIs with DMatrix from xgboost-spark * fix compilation error in xgboost4j-example * fix test cases
This commit is contained in:
parent
6d65aae091
commit
74db1e8867
@ -47,7 +47,7 @@ object DistTrainWithSpark {
|
|||||||
"objective" -> "binary:logistic").toMap
|
"objective" -> "binary:logistic").toMap
|
||||||
val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
|
val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
|
||||||
useExternalMemory = true)
|
useExternalMemory = true)
|
||||||
xgboostModel.predict(new DMatrix(testSet))
|
xgboostModel.booster.predict(new DMatrix(testSet))
|
||||||
// save model to HDFS path
|
// save model to HDFS path
|
||||||
xgboostModel.saveModelAsHadoopFile(outputModelPath)
|
xgboostModel.saveModelAsHadoopFile(outputModelPath)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -122,15 +122,6 @@ class XGBoostModel(_booster: Booster) extends Serializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Predict result with the given test set (represented as DMatrix)
|
|
||||||
*
|
|
||||||
* @param testSet test set represented as DMatrix
|
|
||||||
*/
|
|
||||||
def predict(testSet: DMatrix): Array[Array[Float]] = {
|
|
||||||
_booster.predict(testSet)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Predict leaf instances with the given test set (represented as RDD)
|
* Predict leaf instances with the given test set (represented as RDD)
|
||||||
*
|
*
|
||||||
@ -149,15 +140,6 @@ class XGBoostModel(_booster: Booster) extends Serializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Predict leaf instances with the given test set (represented as DMatrix)
|
|
||||||
*
|
|
||||||
* @param testSet test set represented as DMatrix
|
|
||||||
*/
|
|
||||||
def predictLeaves(testSet: DMatrix): Array[Array[Float]] = {
|
|
||||||
_booster.predictLeaf(testSet, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Save the model as to HDFS-compatible file system.
|
* Save the model as to HDFS-compatible file system.
|
||||||
*
|
*
|
||||||
|
|||||||
@ -125,7 +125,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
|
|||||||
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||||
val boosterRDD = XGBoost.buildDistributedBoosters(
|
val boosterRDD = XGBoost.buildDistributedBoosters(
|
||||||
trainingRDD,
|
trainingRDD,
|
||||||
List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
|
||||||
"objective" -> "binary:logistic").toMap,
|
"objective" -> "binary:logistic").toMap,
|
||||||
new scala.collection.mutable.HashMap[String, String],
|
new scala.collection.mutable.HashMap[String, String],
|
||||||
numWorkers = 2, round = 5, null, null, useExternalMemory = false)
|
numWorkers = 2, round = 5, null, null, useExternalMemory = false)
|
||||||
@ -134,8 +134,9 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
|
|||||||
val boosters = boosterRDD.collect()
|
val boosters = boosterRDD.collect()
|
||||||
val eval = new EvalError()
|
val eval = new EvalError()
|
||||||
for (booster <- boosters) {
|
for (booster <- boosters) {
|
||||||
|
// the threshold is 0.11 because it does not sync boosters with AllReduce
|
||||||
val predicts = booster.predict(testSetDMatrix, outPutMargin = true)
|
val predicts = booster.predict(testSetDMatrix, outPutMargin = true)
|
||||||
assert(eval.eval(predicts, testSetDMatrix) < 0.17)
|
assert(eval.eval(predicts, testSetDMatrix) < 0.11)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -211,7 +212,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
|
|||||||
val predRDD = xgBoostModel.predict(testRDD)
|
val predRDD = xgBoostModel.predict(testRDD)
|
||||||
val predResult1 = predRDD.collect()(0)
|
val predResult1 = predRDD.collect()(0)
|
||||||
import DataUtils._
|
import DataUtils._
|
||||||
val predResult2 = xgBoostModel.predict(new DMatrix(testSet.iterator))
|
val predResult2 = xgBoostModel.booster.predict(new DMatrix(testSet.iterator))
|
||||||
for (i <- predResult1.indices; j <- predResult1(i).indices) {
|
for (i <- predResult1.indices; j <- predResult1(i).indices) {
|
||||||
assert(predResult1(i)(j) === predResult2(i)(j))
|
assert(predResult1(i)(j) === predResult2(i)(j))
|
||||||
}
|
}
|
||||||
@ -222,7 +223,6 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
|
|||||||
val sampleList = new ListBuffer[SparkVector]
|
val sampleList = new ListBuffer[SparkVector]
|
||||||
sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
|
sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
|
||||||
}
|
}
|
||||||
|
|
||||||
val trainingRDD = buildTrainingRDD()
|
val trainingRDD = buildTrainingRDD()
|
||||||
val testRDD = buildEmptyRDD()
|
val testRDD = buildEmptyRDD()
|
||||||
import DataUtils._
|
import DataUtils._
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user