refactor duplicate evaluation implementation (#1852)
This commit is contained in:
parent
2b6aa7736f
commit
d9584ab82e
@ -82,15 +82,6 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
||||
/**
 * Evaluates `evalDataset` and returns the result as a String by delegating to one of two
 * narrower `eval` overloads.
 *
 * When `evalFunc` is null the call is forwarded with `iter`; otherwise it is forwarded with
 * `evalFunc` and `iter` is not passed on.
 *
 * @param evalDataset the labeled points to evaluate
 * @param evalName name passed through unchanged to the delegated overload
 * @param evalFunc custom evaluation; defaults to null, which selects the `iter`-based overload
 * @param iter defaults to -1, which the precondition below treats as "not specified"
 * @param useExternalCache not referenced anywhere in this method's body
 *                         NOTE(review): confirm whether it was meant to be forwarded
 * @return the String produced by the delegated overload
 * @throws IllegalArgumentException (via `require`) when `evalFunc` is null and `iter` is -1
 */
def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait = null,
|
||||
iter: Int = -1, useExternalCache: Boolean = false): String = {
|
||||
// At least one of the two evaluation modes has to be selected by the caller.
require(evalFunc != null || iter != -1, "you have to specify the value of either eval or iter")
|
||||
// evalFunc takes precedence: iter is only consulted when no evalFunc was supplied.
if (evalFunc == null) {
|
||||
eval(evalDataset, evalName, iter)
|
||||
} else {
|
||||
eval(evalDataset, evalName, evalFunc)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: refactor to remove duplicate code in two variations of eval()
|
||||
private def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, iter: Int): String = {
|
||||
val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
|
||||
val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory))
|
||||
val appName = evalDataset.context.appName
|
||||
@ -109,43 +100,19 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
||||
}
|
||||
import DataUtils._
|
||||
val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
|
||||
val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
|
||||
val Array(evName, predNumeric) = predStr.split(":")
|
||||
Rabit.shutdown()
|
||||
Iterator(Some(evName, predNumeric.toFloat))
|
||||
} else {
|
||||
Iterator(None)
|
||||
}
|
||||
}.filter(_.isDefined).collect()
|
||||
val evalPrefix = allEvalMetrics.map(_.get._1).head
|
||||
val evalMetricMean = allEvalMetrics.map(_.get._2).sum / allEvalMetrics.length
|
||||
s"$evalPrefix = $evalMetricMean"
|
||||
}
|
||||
|
||||
private def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait):
|
||||
String = {
|
||||
require(evalFunc != null, "you have to specify the value of either eval or iter")
|
||||
val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
|
||||
val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory))
|
||||
val appName = evalDataset.context.appName
|
||||
val allEvalMetrics = evalDataset.mapPartitions {
|
||||
labeledPointsPartition =>
|
||||
if (labeledPointsPartition.hasNext) {
|
||||
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
val cacheFileName = {
|
||||
if (broadcastUseExternalCache.value) {
|
||||
s"$appName-${TaskContext.get().stageId()}-$evalName" +
|
||||
s"-deval_cache-${TaskContext.getPartitionId()}"
|
||||
} else {
|
||||
null
|
||||
(evalFunc, iter) match {
|
||||
case (null, _) => {
|
||||
val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
|
||||
val Array(evName, predNumeric) = predStr.split(":")
|
||||
Rabit.shutdown()
|
||||
Iterator(Some(evName, predNumeric.toFloat))
|
||||
}
|
||||
case _ => {
|
||||
val predictions = broadcastBooster.value.predict(dMatrix)
|
||||
Rabit.shutdown()
|
||||
Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
|
||||
}
|
||||
}
|
||||
import DataUtils._
|
||||
val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
|
||||
val predictions = broadcastBooster.value.predict(dMatrix)
|
||||
Rabit.shutdown()
|
||||
Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
|
||||
} else {
|
||||
Iterator(None)
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user