[jvm-packages] Accept groupData in spark model eval (#2244)

* Support model evaluation for ranking tasks by accepting
 groupData in XGBoostModel.eval
This commit is contained in:
ebernhardson 2017-05-02 10:03:20 -07:00 committed by Nan Zhu
parent a375ad2822
commit ccccf8a015
2 changed files with 11 additions and 2 deletions

View File

@@ -87,10 +87,13 @@ abstract class XGBoostModel(protected var _booster: Booster)
* @param evalFunc the customized evaluation function, null by default to use the default metric * @param evalFunc the customized evaluation function, null by default to use the default metric
* of model * of model
* @param iter the current iteration, -1 to be null to use customized evaluation functions * @param iter the current iteration, -1 to be null to use customized evaluation functions
* @param groupData group data specifying the size of each group for ranking tasks. The top level
*                  corresponds to the partition id; the second level holds the group sizes.
* @return the average metric over all partitions * @return the average metric over all partitions
*/ */
def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait = null, def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait = null,
iter: Int = -1, useExternalCache: Boolean = false): String = { iter: Int = -1, useExternalCache: Boolean = false,
groupData: Seq[Seq[Int]] = null): String = {
require(evalFunc != null || iter != -1, "you have to specify the value of either eval or iter") require(evalFunc != null || iter != -1, "you have to specify the value of either eval or iter")
val broadcastBooster = evalDataset.sparkContext.broadcast(_booster) val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory)) val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory))
@@ -110,6 +113,9 @@ abstract class XGBoostModel(protected var _booster: Booster)
} }
import DataUtils._ import DataUtils._
val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName) val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
if (groupData != null) {
dMatrix.setGroup(groupData(TaskContext.getPartitionId()).toArray)
}
(evalFunc, iter) match { (evalFunc, iter) match {
case (null, _) => { case (null, _) => {
val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter) val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)

View File

@@ -352,12 +352,15 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features) val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "rank:pairwise", "groupData" -> trainGroupData) "objective" -> "rank:pairwise", "eval_metric" -> "ndcg", "groupData" -> trainGroupData)
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 1) val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 1)
val predRDD = xgBoostModel.predict(testRDD) val predRDD = xgBoostModel.predict(testRDD)
val predResult1: Array[Array[Float]] = predRDD.collect()(0) val predResult1: Array[Array[Float]] = predRDD.collect()(0)
assert(testRDD.count() === predResult1.length) assert(testRDD.count() === predResult1.length)
val avgMetric = xgBoostModel.eval(trainingRDD, "test", iter = 0, groupData = trainGroupData)
assert(avgMetric contains "ndcg")
} }
test("test use nested groupData") { test("test use nested groupData") {