[jvm-packages] Accept groupData in spark model eval (#2244)
* Support model evaluation for ranking tasks by accepting groupData in XGBoostModel.eval
This commit is contained in:
@@ -87,10 +87,13 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
||||
* @param evalFunc the customized evaluation function, null by default to use the default metric
|
||||
* of the model
|
||||
* @param iter the current iteration, or -1 when a customized evaluation function is used
|
||||
* @param groupData the size of each group for ranking tasks. Top level corresponds
|
||||
* to partition id, second level is the group sizes.
|
||||
* @return the average metric over all partitions
|
||||
*/
|
||||
def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait = null,
|
||||
iter: Int = -1, useExternalCache: Boolean = false): String = {
|
||||
iter: Int = -1, useExternalCache: Boolean = false,
|
||||
groupData: Seq[Seq[Int]] = null): String = {
|
||||
require(evalFunc != null || iter != -1, "you have to specify the value of either eval or iter")
|
||||
val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
|
||||
val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory))
|
||||
@@ -110,6 +113,9 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
||||
}
|
||||
import DataUtils._
|
||||
val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
|
||||
if (groupData != null) {
|
||||
dMatrix.setGroup(groupData(TaskContext.getPartitionId()).toArray)
|
||||
}
|
||||
(evalFunc, iter) match {
|
||||
case (null, _) => {
|
||||
val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
|
||||
|
||||
Reference in New Issue
Block a user