From ccccf8a015a495329d6544b58e3ddcdaa9d9ded4 Mon Sep 17 00:00:00 2001
From: ebernhardson
Date: Tue, 2 May 2017 10:03:20 -0700
Subject: [PATCH] [jvm-packages] Accept groupData in spark model eval (#2244)

* Support model evaluation for ranking tasks by accepting groupData in XGBoostModel.eval
---
 .../ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala     | 8 +++++++-
 .../dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala | 5 ++++-
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
index c1b615993..2731b9dd9 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
@@ -87,10 +87,13 @@ abstract class XGBoostModel(protected var _booster: Booster)
    * @param evalFunc the customized evaluation function, null by default to use the default metric
    *                 of model
    * @param iter the current iteration, -1 to be null to use customized evaluation functions
+   * @param groupData group data specify each group size for ranking task. Top level corresponds
+   *                  to partition id, second level is the group sizes.
    * @return the average metric over all partitions
    */
   def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait = null,
-      iter: Int = -1, useExternalCache: Boolean = false): String = {
+      iter: Int = -1, useExternalCache: Boolean = false,
+      groupData: Seq[Seq[Int]] = null): String = {
     require(evalFunc != null || iter != -1, "you have to specify the value of either eval or iter")
     val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
     val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory))
@@ -110,6 +113,9 @@ abstract class XGBoostModel(protected var _booster: Booster)
         }
         import DataUtils._
         val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
+        if (groupData != null) {
+          dMatrix.setGroup(groupData(TaskContext.getPartitionId()).toArray)
+        }
         (evalFunc, iter) match {
           case (null, _) => {
             val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index fb41becea..29cbf5c47 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -352,12 +352,15 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
     val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)
 
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "rank:pairwise", "groupData" -> trainGroupData)
+      "objective" -> "rank:pairwise", "eval_metric" -> "ndcg", "groupData" -> trainGroupData)
 
     val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 1)
     val predRDD = xgBoostModel.predict(testRDD)
     val predResult1: Array[Array[Float]] = predRDD.collect()(0)
     assert(testRDD.count() === predResult1.length)
+
+    val avgMetric = xgBoostModel.eval(trainingRDD, "test", iter = 0, groupData = trainGroupData)
+    assert(avgMetric contains "ndcg")
   }
 
   test("test use nested groupData") {
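
For context, a minimal usage sketch of the new parameter follows. It is illustrative only and not part of the patch: it assumes trainingRDD and trainGroupData are prepared as in the test above, with one inner Seq of group sizes per RDD partition, so that the outer dimension of the group data lines up with the partition ids.

    // Illustrative sketch, not part of the patch. Assumes trainingRDD: RDD[MLLabeledPoint]
    // and trainGroupData: Seq[Seq[Int]] are built as in the test above, one inner Seq of
    // group sizes per RDD partition.
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "rank:pairwise", "eval_metric" -> "ndcg", "groupData" -> trainGroupData)
    val model = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 1)

    // The new groupData argument lets eval() restore the per-partition group boundaries on the
    // evaluation DMatrix, so ranking metrics such as ndcg are computed over the correct groups.
    val metric = model.eval(trainingRDD, "test", iter = 0, groupData = trainGroupData)
    assert(metric contains "ndcg")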