evaluation with RDD data (#1492)
This commit is contained in:
parent
582ee63e34
commit
dc1125eb56
@ -17,15 +17,54 @@
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import org.apache.hadoop.fs.{Path, FileSystem}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.{TaskContext, SparkContext}
|
||||
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{EvalTrait, Booster, DMatrix}
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
class XGBoostModel(_booster: Booster) extends Serializable {
|
||||
|
||||
/**
|
||||
* evaluate XGBoostModel with a RDD-wrapped dataset
|
||||
*
|
||||
* @param evalDataset the dataset used for evaluation
|
||||
* @param eval the customized evaluation function, can be null for using default in the model
|
||||
* @param useExternalCache if use external cache
|
||||
* @return the average metric over all partitions
|
||||
*/
|
||||
def eval(
|
||||
evalDataset: RDD[LabeledPoint],
|
||||
eval: EvalTrait,
|
||||
evalName: String,
|
||||
useExternalCache: Boolean = false): String = {
|
||||
val appName = evalDataset.context.appName
|
||||
val allEvalMetrics = evalDataset.mapPartitions {
|
||||
labeledPointsPartition =>
|
||||
if (labeledPointsPartition.hasNext) {
|
||||
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
import DataUtils._
|
||||
val cacheFileName = {
|
||||
if (useExternalCache) {
|
||||
s"$appName-deval_cache-${TaskContext.getPartitionId()}"
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
|
||||
val predictions = _booster.predict(dMatrix)
|
||||
Rabit.shutdown()
|
||||
Iterator(Some(eval.eval(predictions, dMatrix)))
|
||||
} else {
|
||||
Iterator(None)
|
||||
}
|
||||
}.filter(_.isDefined).collect()
|
||||
s"$evalName-${eval.getMetric} = ${allEvalMetrics.map(_.get).sum / allEvalMetrics.length}"
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict result with the given test set (represented as RDD)
|
||||
*
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user