From d9584ab82e888de78e1d33829d6de5689d46cd54 Mon Sep 17 00:00:00 2001 From: Ruimin Wang Date: Fri, 9 Dec 2016 12:33:40 +0800 Subject: [PATCH] refactor duplicate evaluation implementation (#1852) --- .../xgboost4j/scala/spark/XGBoostModel.scala | 55 ++++--------------- 1 file changed, 11 insertions(+), 44 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala index 26bff11ab..8f934229c 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala @@ -82,15 +82,6 @@ abstract class XGBoostModel(protected var _booster: Booster) def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait = null, iter: Int = -1, useExternalCache: Boolean = false): String = { require(evalFunc != null || iter != -1, "you have to specify the value of either eval or iter") - if (evalFunc == null) { - eval(evalDataset, evalName, iter) - } else { - eval(evalDataset, evalName, evalFunc) - } - } - - // TODO: refactor to remove duplicate code in two variations of eval() - private def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, iter: Int): String = { val broadcastBooster = evalDataset.sparkContext.broadcast(_booster) val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory)) val appName = evalDataset.context.appName @@ -109,43 +100,19 @@ abstract class XGBoostModel(protected var _booster: Booster) } import DataUtils._ val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName) - val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter) - val Array(evName, predNumeric) = predStr.split(":") - Rabit.shutdown() - Iterator(Some(evName, predNumeric.toFloat)) - } else { - Iterator(None) - } - }.filter(_.isDefined).collect() - val evalPrefix = allEvalMetrics.map(_.get._1).head - val evalMetricMean = allEvalMetrics.map(_.get._2).sum / allEvalMetrics.length - s"$evalPrefix = $evalMetricMean" - } - - private def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait): - String = { - require(evalFunc != null, "you have to specify the value of either eval or iter") - val broadcastBooster = evalDataset.sparkContext.broadcast(_booster) - val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory)) - val appName = evalDataset.context.appName - val allEvalMetrics = evalDataset.mapPartitions { - labeledPointsPartition => - if (labeledPointsPartition.hasNext) { - val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString) - Rabit.init(rabitEnv.asJava) - val cacheFileName = { - if (broadcastUseExternalCache.value) { - s"$appName-${TaskContext.get().stageId()}-$evalName" + - s"-deval_cache-${TaskContext.getPartitionId()}" - } else { - null + (evalFunc, iter) match { + case (null, _) => { + val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter) + val Array(evName, predNumeric) = predStr.split(":") + Rabit.shutdown() + Iterator(Some(evName, predNumeric.toFloat)) + } + case _ => { + val predictions = broadcastBooster.value.predict(dMatrix) + Rabit.shutdown() + Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix)))) } } - import DataUtils._ - val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName) - val predictions = broadcastBooster.value.predict(dMatrix) - Rabit.shutdown() - Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix)))) } else { Iterator(None) }