refactor duplicate evaluation implementation (#1852)
This commit is contained in:
parent
2b6aa7736f
commit
d9584ab82e
@ -82,15 +82,6 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
|||||||
def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait = null,
|
def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait = null,
|
||||||
iter: Int = -1, useExternalCache: Boolean = false): String = {
|
iter: Int = -1, useExternalCache: Boolean = false): String = {
|
||||||
require(evalFunc != null || iter != -1, "you have to specify the value of either eval or iter")
|
require(evalFunc != null || iter != -1, "you have to specify the value of either eval or iter")
|
||||||
if (evalFunc == null) {
|
|
||||||
eval(evalDataset, evalName, iter)
|
|
||||||
} else {
|
|
||||||
eval(evalDataset, evalName, evalFunc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: refactor to remove duplicate code in two variations of eval()
|
|
||||||
private def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, iter: Int): String = {
|
|
||||||
val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
|
val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
|
||||||
val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory))
|
val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory))
|
||||||
val appName = evalDataset.context.appName
|
val appName = evalDataset.context.appName
|
||||||
@ -109,43 +100,19 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
|||||||
}
|
}
|
||||||
import DataUtils._
|
import DataUtils._
|
||||||
val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
|
val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
|
||||||
val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
|
(evalFunc, iter) match {
|
||||||
val Array(evName, predNumeric) = predStr.split(":")
|
case (null, _) => {
|
||||||
Rabit.shutdown()
|
val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
|
||||||
Iterator(Some(evName, predNumeric.toFloat))
|
val Array(evName, predNumeric) = predStr.split(":")
|
||||||
} else {
|
Rabit.shutdown()
|
||||||
Iterator(None)
|
Iterator(Some(evName, predNumeric.toFloat))
|
||||||
}
|
}
|
||||||
}.filter(_.isDefined).collect()
|
case _ => {
|
||||||
val evalPrefix = allEvalMetrics.map(_.get._1).head
|
val predictions = broadcastBooster.value.predict(dMatrix)
|
||||||
val evalMetricMean = allEvalMetrics.map(_.get._2).sum / allEvalMetrics.length
|
Rabit.shutdown()
|
||||||
s"$evalPrefix = $evalMetricMean"
|
Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
|
||||||
}
|
|
||||||
|
|
||||||
private def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait):
|
|
||||||
String = {
|
|
||||||
require(evalFunc != null, "you have to specify the value of either eval or iter")
|
|
||||||
val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
|
|
||||||
val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory))
|
|
||||||
val appName = evalDataset.context.appName
|
|
||||||
val allEvalMetrics = evalDataset.mapPartitions {
|
|
||||||
labeledPointsPartition =>
|
|
||||||
if (labeledPointsPartition.hasNext) {
|
|
||||||
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
|
|
||||||
Rabit.init(rabitEnv.asJava)
|
|
||||||
val cacheFileName = {
|
|
||||||
if (broadcastUseExternalCache.value) {
|
|
||||||
s"$appName-${TaskContext.get().stageId()}-$evalName" +
|
|
||||||
s"-deval_cache-${TaskContext.getPartitionId()}"
|
|
||||||
} else {
|
|
||||||
null
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
import DataUtils._
|
|
||||||
val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
|
|
||||||
val predictions = broadcastBooster.value.predict(dMatrix)
|
|
||||||
Rabit.shutdown()
|
|
||||||
Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
|
|
||||||
} else {
|
} else {
|
||||||
Iterator(None)
|
Iterator(None)
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user