default eval func (#1574)

This commit is contained in:
Nan Zhu
2016-09-14 13:26:16 -04:00
committed by GitHub
parent 4733357278
commit bb388cbb31
6 changed files with 114 additions and 68 deletions

View File

@@ -42,16 +42,20 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
/**
* evaluate XGBoostModel with a RDD-wrapped dataset
*
* NOTE: you have to specify value of either eval or iter; when you specify both, this adopts
* the default eval metric of model
*
* @param evalDataset the dataset used for evaluation
* @param eval the customized evaluation function, can be null for using default in the model
* @param evalName the name of evaluation
* @param evalFunc the customized evaluation function, null by default to use the default metric
* of model
* @param iter the current iteration, -1 to be null to use customized evaluation functions
* @param useExternalCache if use external cache
* @return the average metric over all partitions
*/
def eval(
evalDataset: RDD[LabeledPoint],
eval: EvalTrait,
evalName: String,
useExternalCache: Boolean = false): String = {
def eval(evalDataset: RDD[LabeledPoint], evalName: String, evalFunc: EvalTrait = null,
iter: Int = -1, useExternalCache: Boolean = false): String = {
require(evalFunc != null || iter != -1, "you have to specify value of either eval or iter")
val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
val appName = evalDataset.context.appName
val allEvalMetrics = evalDataset.mapPartitions {
@@ -62,20 +66,29 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
import DataUtils._
val cacheFileName = {
if (useExternalCache) {
s"$appName-deval_cache-${TaskContext.getPartitionId()}"
s"$appName-${TaskContext.get().stageId()}-deval_cache-${TaskContext.getPartitionId()}"
} else {
null
}
}
val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
val predictions = broadcastBooster.value.predict(dMatrix)
Rabit.shutdown()
Iterator(Some(eval.eval(predictions, dMatrix)))
if (iter == -1) {
val predictions = broadcastBooster.value.predict(dMatrix)
Rabit.shutdown()
Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
} else {
val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
val Array(evName, predNumeric) = predStr.split(":")
Rabit.shutdown()
Iterator(Some(evName, predNumeric.toFloat))
}
} else {
Iterator(None)
}
}.filter(_.isDefined).collect()
s"$evalName-${eval.getMetric} = ${allEvalMetrics.map(_.get).sum / allEvalMetrics.length}"
val evalPrefix = allEvalMetrics.map(_.get._1).head
val evalMetricMean = allEvalMetrics.map(_.get._2).sum / allEvalMetrics.length
s"$evalPrefix = $evalMetricMean"
}
/**
@@ -176,6 +189,7 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
/**
* produces the prediction results and append as an additional column in the original dataset
* NOTE: the prediction results is kept as the original format of xgboost
*
* @return the original dataframe with an additional column containing prediction results
*/
override def transform(testSet: Dataset[_]): DataFrame = {
@@ -186,6 +200,7 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
* produces the prediction results and append as an additional column in the original dataset
* NOTE: the prediction results is transformed by applying the transformation function
* predictResultTrans to the original xgboost output
*
* @param predictResultTrans the function to transform xgboost output to the expected format
* @return the original dataframe with an additional column containing prediction results
*/