refactor duplicate evaluation implementation (#1852)

Authored by Ruimin Wang on 2016-12-09 12:33:40 +08:00; committed by Nan Zhu
parent 2b6aa7736f
commit d9584ab82e
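The diff below removes the two private eval() variants and folds their bodies into the public eval(), which now dispatches on (evalFunc, iter) with a single match. The caller-facing signature is unchanged; as a hedged sketch (model, testSet, and ZeroOneError are illustrative names, not part of this commit), the two entry points look like:

// Built-in metric: evalFunc defaults to null, so the match takes case (null, _)
// and scores the dataset via Booster.evalSet at the given iteration.
val builtIn: String = model.eval(testSet, evalName = "test", iter = 0)

// Custom metric: any non-null EvalTrait routes to case _, which applies it to the
// raw predictions (see the ZeroOneError sketch after the diff).
val custom: String = model.eval(testSet, evalName = "test", evalFunc = new ZeroOneError)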


@@ -82,15 +82,6 @@ abstract class XGBoostModel(protected var _booster: Booster)
   def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait = null,
       iter: Int = -1, useExternalCache: Boolean = false): String = {
     require(evalFunc != null || iter != -1, "you have to specify the value of either eval or iter")
-    if (evalFunc == null) {
-      eval(evalDataset, evalName, iter)
-    } else {
-      eval(evalDataset, evalName, evalFunc)
-    }
-  }
-
-  // TODO: refactor to remove duplicate code in two variations of eval()
-  private def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, iter: Int): String = {
     val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
     val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory))
     val appName = evalDataset.context.appName
@@ -109,43 +100,19 @@ abstract class XGBoostModel(protected var _booster: Booster)
           }
           import DataUtils._
           val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
-          val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
-          val Array(evName, predNumeric) = predStr.split(":")
-          Rabit.shutdown()
-          Iterator(Some(evName, predNumeric.toFloat))
-        } else {
-          Iterator(None)
-        }
-    }.filter(_.isDefined).collect()
-    val evalPrefix = allEvalMetrics.map(_.get._1).head
-    val evalMetricMean = allEvalMetrics.map(_.get._2).sum / allEvalMetrics.length
-    s"$evalPrefix = $evalMetricMean"
-  }
-
-  private def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait):
-      String = {
-    require(evalFunc != null, "you have to specify the value of either eval or iter")
-    val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
-    val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory))
-    val appName = evalDataset.context.appName
-    val allEvalMetrics = evalDataset.mapPartitions {
-      labeledPointsPartition =>
-        if (labeledPointsPartition.hasNext) {
-          val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
-          Rabit.init(rabitEnv.asJava)
-          val cacheFileName = {
-            if (broadcastUseExternalCache.value) {
-              s"$appName-${TaskContext.get().stageId()}-$evalName" +
-                s"-deval_cache-${TaskContext.getPartitionId()}"
-            } else {
-              null
-            }
-          }
-          import DataUtils._
-          val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
-          val predictions = broadcastBooster.value.predict(dMatrix)
-          Rabit.shutdown()
-          Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
+          (evalFunc, iter) match {
+            case (null, _) => {
+              val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
+              val Array(evName, predNumeric) = predStr.split(":")
+              Rabit.shutdown()
+              Iterator(Some(evName, predNumeric.toFloat))
+            }
+            case _ => {
+              val predictions = broadcastBooster.value.predict(dMatrix)
+              Rabit.shutdown()
+              Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
+            }
+          }
         } else {
           Iterator(None)
         }
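The case _ branch above hands the booster's raw predictions to evalFunc.eval(predictions, dMatrix). A minimal sketch of a custom metric that could be passed as evalFunc, assuming EvalTrait exposes getMetric plus the eval(Array[Array[Float]], DMatrix) signature implied by that call (class name, threshold, and error definition are illustrative):

import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}

// Fraction of rows where the 0.5-thresholded prediction disagrees with the label.
// NOTE: the EvalTrait member signatures are assumed from how the diff invokes
// evalFunc.eval(predictions, dMatrix); check the xgboost4j-scala sources.
class ZeroOneError extends EvalTrait {
  override def getMetric: String = "zero-one-error"

  override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
    val labels = dmat.getLabel
    val wrong = predicts.map(_.head).zip(labels).count {
      case (pred, label) => (if (pred > 0.5f) 1.0f else 0.0f) != label
    }
    wrong.toFloat / labels.length
  }
}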