From dc1125eb562d5711e6e75ab4b75cd6b22c24c398 Mon Sep 17 00:00:00 2001 From: Nan Zhu Date: Sat, 20 Aug 2016 18:31:10 -0400 Subject: [PATCH] evaluation with RDD data (#1492) --- .../xgboost4j/scala/spark/XGBoostModel.scala | 41 ++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala index 8df7ef389..4cebdd900 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala @@ -17,15 +17,54 @@ package ml.dmlc.xgboost4j.scala.spark import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.{TaskContext, SparkContext} import org.apache.spark.mllib.linalg.{DenseVector, Vector} import org.apache.spark.rdd.RDD import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix} -import ml.dmlc.xgboost4j.scala.{Booster, DMatrix} +import ml.dmlc.xgboost4j.scala.{EvalTrait, Booster, DMatrix} import scala.collection.JavaConverters._ class XGBoostModel(_booster: Booster) extends Serializable { + /** + * evaluate XGBoostModel with a RDD-wrapped dataset + * + * @param evalDataset the dataset used for evaluation + * @param eval the customized evaluation function, can be null for using default in the model + * @param useExternalCache if use external cache + * @return the average metric over all partitions + */ + def eval( + evalDataset: RDD[LabeledPoint], + eval: EvalTrait, + evalName: String, + useExternalCache: Boolean = false): String = { + val appName = evalDataset.context.appName + val allEvalMetrics = evalDataset.mapPartitions { + labeledPointsPartition => + if (labeledPointsPartition.hasNext) { + val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap + Rabit.init(rabitEnv.asJava) + import DataUtils._ + val cacheFileName = { + if (useExternalCache) { + s"$appName-deval_cache-${TaskContext.getPartitionId()}" + } else { + null + } + } + val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName) + val predictions = _booster.predict(dMatrix) + Rabit.shutdown() + Iterator(Some(eval.eval(predictions, dMatrix))) + } else { + Iterator(None) + } + }.filter(_.isDefined).collect() + s"$evalName-${eval.getMetric} = ${allEvalMetrics.map(_.get).sum / allEvalMetrics.length}" + } + /** * Predict result with the given test set (represented as RDD) *