evaluation with RDD data (#1492)
This commit is contained in:
parent
582ee63e34
commit
dc1125eb56
@ -17,15 +17,54 @@
|
|||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
import org.apache.hadoop.fs.{Path, FileSystem}
|
import org.apache.hadoop.fs.{Path, FileSystem}
|
||||||
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.{TaskContext, SparkContext}
|
import org.apache.spark.{TaskContext, SparkContext}
|
||||||
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
|
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
|
import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
|
||||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
import ml.dmlc.xgboost4j.scala.{EvalTrait, Booster, DMatrix}
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
class XGBoostModel(_booster: Booster) extends Serializable {
|
class XGBoostModel(_booster: Booster) extends Serializable {
|
||||||
|
|
||||||
|
/**
 * Evaluates this model against an RDD-wrapped dataset.
 *
 * Every non-empty partition is turned into a DMatrix, scored with the wrapped booster and
 * handed to `eval`; the per-partition metric values are collected on the driver and averaged.
 * The average is unweighted: each non-empty partition counts once, whatever its size.
 *
 * @param evalDataset the dataset used for evaluation
 * @param eval the customized evaluation function; must not be null, because both its `eval`
 *             and `getMetric` members are always invoked
 * @param evalName the label placed in front of the metric name in the returned string
 * @param useExternalCache if true, a per-partition cache file name is passed to the DMatrix;
 *                         otherwise null is passed
 * @return a string of the form `evalName-metric = value`, where `value` is the average of the
 *         metric over all non-empty partitions
 */
def eval(
    evalDataset: RDD[LabeledPoint],
    eval: EvalTrait,
    evalName: String,
    useExternalCache: Boolean = false): String = {
  // Fail fast on the driver instead of with a NullPointerException inside a Spark task.
  require(eval != null, "eval must not be null")
  val appName = evalDataset.context.appName
  val allEvalMetrics = evalDataset.mapPartitions { labeledPointsPartition =>
    if (labeledPointsPartition.hasNext) {
      val partitionId = TaskContext.getPartitionId()
      val rabitEnv = Map("DMLC_TASK_ID" -> partitionId.toString)
      Rabit.init(rabitEnv.asJava)
      // NOTE(review): presumably supplies the implicit conversion that lets an
      // Iterator[LabeledPoint] be passed to the DMatrix constructor - confirm in DataUtils.
      import DataUtils._
      // null is what the original code passes to DMatrix when no external cache is requested.
      val cacheFileName = if (useExternalCache) s"$appName-deval_cache-$partitionId" else null
      // Shut Rabit down even when building the matrix or predicting throws, so a failed
      // task does not leave an initialised Rabit session behind.
      val (dMatrix, predictions) =
        try {
          val matrix = new DMatrix(labeledPointsPartition, cacheFileName)
          (matrix, _booster.predict(matrix))
        } finally {
          Rabit.shutdown()
        }
      Iterator(eval.eval(predictions, dMatrix))
    } else {
      // Empty partitions contribute no metric at all.
      Iterator.empty
    }
  }.collect()
  // NOTE(review): if no partition produced a metric this divides by zero; with a
  // floating-point metric that presumably yields NaN rather than an exception - confirm.
  s"$evalName-${eval.getMetric} = ${allEvalMetrics.sum / allEvalMetrics.length}"
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Predict result with the given test set (represented as RDD)
|
* Predict result with the given test set (represented as RDD)
|
||||||
*
|
*
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user