default eval func (#1574)
This commit is contained in:
@@ -42,16 +42,20 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
|
||||
/**
|
||||
* evaluate XGBoostModel with a RDD-wrapped dataset
|
||||
*
|
||||
* NOTE: you have to specify value of either eval or iter; when you specify both, this adopts
|
||||
* the default eval metric of model
|
||||
*
|
||||
* @param evalDataset the dataset used for evaluation
|
||||
* @param eval the customized evaluation function, can be null for using default in the model
|
||||
* @param evalName the name of evaluation
|
||||
* @param evalFunc the customized evaluation function, null by default to use the default metric
|
||||
* of model
|
||||
* @param iter the current iteration, -1 to be null to use customized evaluation functions
|
||||
* @param useExternalCache if use external cache
|
||||
* @return the average metric over all partitions
|
||||
*/
|
||||
def eval(
|
||||
evalDataset: RDD[LabeledPoint],
|
||||
eval: EvalTrait,
|
||||
evalName: String,
|
||||
useExternalCache: Boolean = false): String = {
|
||||
def eval(evalDataset: RDD[LabeledPoint], evalName: String, evalFunc: EvalTrait = null,
|
||||
iter: Int = -1, useExternalCache: Boolean = false): String = {
|
||||
require(evalFunc != null || iter != -1, "you have to specify value of either eval or iter")
|
||||
val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
|
||||
val appName = evalDataset.context.appName
|
||||
val allEvalMetrics = evalDataset.mapPartitions {
|
||||
@@ -62,20 +66,29 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
|
||||
import DataUtils._
|
||||
val cacheFileName = {
|
||||
if (useExternalCache) {
|
||||
s"$appName-deval_cache-${TaskContext.getPartitionId()}"
|
||||
s"$appName-${TaskContext.get().stageId()}-deval_cache-${TaskContext.getPartitionId()}"
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
|
||||
val predictions = broadcastBooster.value.predict(dMatrix)
|
||||
Rabit.shutdown()
|
||||
Iterator(Some(eval.eval(predictions, dMatrix)))
|
||||
if (iter == -1) {
|
||||
val predictions = broadcastBooster.value.predict(dMatrix)
|
||||
Rabit.shutdown()
|
||||
Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
|
||||
} else {
|
||||
val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
|
||||
val Array(evName, predNumeric) = predStr.split(":")
|
||||
Rabit.shutdown()
|
||||
Iterator(Some(evName, predNumeric.toFloat))
|
||||
}
|
||||
} else {
|
||||
Iterator(None)
|
||||
}
|
||||
}.filter(_.isDefined).collect()
|
||||
s"$evalName-${eval.getMetric} = ${allEvalMetrics.map(_.get).sum / allEvalMetrics.length}"
|
||||
val evalPrefix = allEvalMetrics.map(_.get._1).head
|
||||
val evalMetricMean = allEvalMetrics.map(_.get._2).sum / allEvalMetrics.length
|
||||
s"$evalPrefix = $evalMetricMean"
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -176,6 +189,7 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
|
||||
/**
|
||||
* produces the prediction results and append as an additional column in the original dataset
|
||||
* NOTE: the prediction results is kept as the original format of xgboost
|
||||
*
|
||||
* @return the original dataframe with an additional column containing prediction results
|
||||
*/
|
||||
override def transform(testSet: Dataset[_]): DataFrame = {
|
||||
@@ -186,6 +200,7 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
|
||||
* produces the prediction results and append as an additional column in the original dataset
|
||||
* NOTE: the prediction results is transformed by applying the transformation function
|
||||
* predictResultTrans to the original xgboost output
|
||||
*
|
||||
* @param predictResultTrans the function to transform xgboost output to the expected format
|
||||
* @return the original dataframe with an additional column containing prediction results
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user