evaluation with RDD data (#1492)

Author: Nan Zhu, 2016-08-20 18:31:10 -04:00 (committed by GitHub)
Commit: dc1125eb56 (parent: 582ee63e34)


@@ -17,15 +17,54 @@
package ml.dmlc.xgboost4j.scala.spark

import org.apache.hadoop.fs.{Path, FileSystem}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.{TaskContext, SparkContext}
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
import org.apache.spark.rdd.RDD

import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{EvalTrait, Booster, DMatrix}

import scala.collection.JavaConverters._
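// A minimal sketch (hypothetical, not part of this commit) of a custom
// EvalTrait that the new eval method below can consume: getMetric names the
// metric and eval computes it from the raw predictions and the labels stored
// in the evaluated DMatrix.
class EvalError extends EvalTrait {
  override def getMetric: String = "custom-error"

  override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
    val labels = dmat.getLabel
    // count rows whose thresholded prediction disagrees with the label
    val wrongCount = predicts.zip(labels).count {
      case (pred, label) => (if (pred(0) > 0.5f) 1.0f else 0.0f) != label
    }
    wrongCount.toFloat / labels.length
  }
}
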
class XGBoostModel(_booster: Booster) extends Serializable {
  /**
   * Evaluate the XGBoostModel with an RDD-wrapped dataset
   *
   * @param evalDataset the dataset used for evaluation
   * @param eval the customized evaluation function used to compute the metric
   * @param evalName the name of the evaluation, used as a prefix of the reported metric
   * @param useExternalCache whether to use the external memory cache when constructing the DMatrix
   * @return the unweighted average of the metric over all non-empty partitions
   */
  def eval(
      evalDataset: RDD[LabeledPoint],
      eval: EvalTrait,
      evalName: String,
      useExternalCache: Boolean = false): String = {
    val appName = evalDataset.context.appName
    val allEvalMetrics = evalDataset.mapPartitions {
      labeledPointsPartition =>
        if (labeledPointsPartition.hasNext) {
          val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
          Rabit.init(rabitEnv.asJava)
          import DataUtils._
          val cacheFileName = {
            if (useExternalCache) {
              s"$appName-deval_cache-${TaskContext.getPartitionId()}"
            } else {
              null
            }
          }
          val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
          val predictions = _booster.predict(dMatrix)
          // compute the metric before tearing down Rabit, so that any collective
          // communication inside eval still has a live worker connection
          val metricValue = eval.eval(predictions, dMatrix)
          Rabit.shutdown()
          Iterator(Some(metricValue))
        } else {
          Iterator(None)
        }
    }.filter(_.isDefined).collect()
    s"$evalName-${eval.getMetric} = ${allEvalMetrics.map(_.get).sum / allEvalMetrics.length}"
  }
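
  // Example usage (a sketch, not part of this commit; `model` is a trained
  // XGBoostModel, `testSet: RDD[LabeledPoint]` is held-out data, and
  // EvalError is the hypothetical metric sketched above the class):
  //
  //   val summary = model.eval(testSet, new EvalError, "test")
  //   println(summary)  // prints "test-custom-error = <average over partitions>"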
  /**
   * Predict results for the given test set (represented as an RDD)
   *