default eval func (#1574)

Nan Zhu 2016-09-14 13:26:16 -04:00 committed by GitHub
parent 4733357278
commit bb388cbb31
6 changed files with 114 additions and 68 deletions

XGBoostModel.scala

@@ -42,16 +42,20 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
/**
* evaluate XGBoostModel with an RDD-wrapped dataset
*
* NOTE: you have to specify either evalFunc or iter; when both are specified, the
* default eval metric of the model is adopted
*
* @param evalDataset the dataset used for evaluation
* @param eval the customized evaluation function, can be null for using default in the model
* @param evalName the name of the evaluation
* @param evalFunc the customized evaluation function; null by default, in which case the
* default metric of the model is used
* @param iter the current iteration; -1 by default, in which case the customized evaluation
* function is used
* @param useExternalCache whether to use the external cache
* @return the average metric over all partitions
*/
def eval(
evalDataset: RDD[LabeledPoint],
eval: EvalTrait,
evalName: String,
useExternalCache: Boolean = false): String = {
def eval(evalDataset: RDD[LabeledPoint], evalName: String, evalFunc: EvalTrait = null,
iter: Int = -1, useExternalCache: Boolean = false): String = {
require(evalFunc != null || iter != -1, "you have to specify either evalFunc or iter")
val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
val appName = evalDataset.context.appName
val allEvalMetrics = evalDataset.mapPartitions {
@ -62,20 +66,29 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
import DataUtils._
val cacheFileName = {
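// when the external cache is used, name the file by app, stage, and partition so stages do not collide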
if (useExternalCache) {
s"$appName-deval_cache-${TaskContext.getPartitionId()}"
s"$appName-${TaskContext.get().stageId()}-deval_cache-${TaskContext.getPartitionId()}"
} else {
null
}
}
val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
val predictions = broadcastBooster.value.predict(dMatrix)
Rabit.shutdown()
Iterator(Some(eval.eval(predictions, dMatrix)))
if (iter == -1) {
val predictions = broadcastBooster.value.predict(dMatrix)
Rabit.shutdown()
Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
} else {
val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
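// split the evaluation string returned by evalSet ("name:value") into metric name and value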
val Array(evName, predNumeric) = predStr.split(":")
Rabit.shutdown()
Iterator(Some((evName, predNumeric.toFloat)))
}
} else {
Iterator(None)
}
}.filter(_.isDefined).collect()
s"$evalName-${eval.getMetric} = ${allEvalMetrics.map(_.get).sum / allEvalMetrics.length}"
val evalPrefix = allEvalMetrics.map(_.get._1).head
val evalMetricMean = allEvalMetrics.map(_.get._2).sum / allEvalMetrics.length
s"$evalPrefix = $evalMetricMean"
}
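A minimal usage sketch of the two evaluation modes (hypothetical names: model is a trained XGBoostModel, evalRDD is an RDD[LabeledPoint]):
// use the model's built-in eval metric at iteration 5
val builtInResult = model.eval(evalRDD, "test", iter = 5)
// use a customized metric through an EvalTrait implementation (iter stays -1)
val customResult = model.eval(evalRDD, "test", evalFunc = new EvalError)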
/**
@@ -176,6 +189,7 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
/**
* produces the prediction results and appends them as an additional column to the original dataset
* NOTE: the prediction results are kept in the original format of xgboost
*
* @return the original dataframe with an additional column containing prediction results
*/
override def transform(testSet: Dataset[_]): DataFrame = {
@@ -186,6 +200,7 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
* produces the prediction results and appends them as an additional column to the original dataset
* NOTE: the prediction results are transformed by applying the transformation function
* predictResultTrans to the original xgboost output
*
* @param predictResultTrans the function to transform xgboost output to the expected format
* @return the original dataframe with an additional column containing prediction results
*/
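A minimal sketch of calling transform (assuming testDF is a DataFrame whose feature column matches what the model was trained on):
// appends a column with the xgboost prediction results to testDF
val withPredictions = model.transform(testDF)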

EvalError.scala (new file)

@@ -0,0 +1,63 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
import org.apache.commons.logging.LogFactory
class EvalError extends EvalTrait {
val logger = LogFactory.getLog(classOf[EvalError])
private[xgboost4j] var evalMetric: String = "custom_error"
/**
* get the evaluation metric
*
* @return evalMetric
*/
override def getMetric: String = evalMetric
/**
* evaluate with predictions and data
*
* @param predicts the predictions as an array
* @param dmat the data matrix to evaluate
* @return result of the metric
*/
override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
var error: Float = 0f
var labels: Array[Float] = null
try {
labels = dmat.getLabel
} catch {
case ex: XGBoostError =>
logger.error(ex)
return -1f
}
val nrow: Int = predicts.length
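// a prediction counts as an error when its sign disagrees with the 0/1 label (scores above 0 are treated as positive)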
for (i <- 0 until nrow) {
if (labels(i) == 0.0 && predicts(i)(0) > 0) {
error += 1
} else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
error += 1
}
}
error / labels.length
}
}

SharedSparkContext.scala

@@ -19,9 +19,9 @@ package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfter, FunSuite}
trait SharedSparkContext extends FunSuite with BeforeAndAfter {
class SharedSparkContext extends FunSuite with BeforeAndAfter with Serializable {
protected implicit var sc: SparkContext = null
@transient protected implicit var sc: SparkContext = null
before {
// build SparkContext

Utils.scala

@@ -29,51 +29,9 @@ import org.apache.spark.mllib.linalg.{DenseVector, Vector => SparkVector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
trait Utils extends SharedSparkContext {
trait Utils extends Serializable {
protected val numWorkers = Runtime.getRuntime().availableProcessors()
protected class EvalError extends EvalTrait {
val logger = LogFactory.getLog(classOf[EvalError])
private[xgboost4j] var evalMetric: String = "custom_error"
/**
* get evaluate metric
*
* @return evalMetric
*/
override def getMetric: String = evalMetric
/**
* evaluate with predicts and data
*
* @param predicts predictions as array
* @param dmat data matrix to evaluate
* @return result of the metric
*/
override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
var error: Float = 0f
var labels: Array[Float] = null
try {
labels = dmat.getLabel
} catch {
case ex: XGBoostError =>
logger.error(ex)
return -1f
}
val nrow: Int = predicts.length
for (i <- 0 until nrow) {
if (labels(i) == 0.0 && predicts(i)(0) > 0) {
error += 1
} else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
error += 1
}
}
error / labels.length
}
}
protected def loadLabelPoints(filePath: String): List[LabeledPoint] = {
val file = Source.fromFile(new File(filePath))
val sampleList = new ListBuffer[LabeledPoint]
@@ -100,8 +58,8 @@ trait Utils extends SharedSparkContext {
LabeledPoint(label, sv)
}
protected def buildTrainingRDD(sparkContext: Option[SparkContext] = None): RDD[LabeledPoint] = {
protected def buildTrainingRDD(sparkContext: SparkContext): RDD[LabeledPoint] = {
val sampleList = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile)
sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
sparkContext.parallelize(sampleList, numWorkers)
}
}

XGBoostDFSuite.scala

@@ -30,7 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql._
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}
class XGBoostDFSuite extends Utils {
class XGBoostDFSuite extends SharedSparkContext with Utils {
private def loadRow(filePath: String): List[Row] = {
val file = Source.fromFile(new File(filePath))
@@ -58,7 +58,7 @@ class XGBoostDFSuite extends Utils {
test("test consistency between training with dataframe and RDD") {
val trainingDF = buildTrainingDataframe()
val trainingRDD = buildTrainingRDD()
val trainingRDD = buildTrainingRDD(sc)
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
"objective" -> "binary:logistic").toMap
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,

XGBoostGeneralSuite.scala

@@ -29,10 +29,10 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
class XGBoostGeneralSuite extends Utils {
class XGBoostGeneralSuite extends SharedSparkContext with Utils {
test("build RDD containing boosters with the specified worker number") {
val trainingRDD = buildTrainingRDD()
val trainingRDD = buildTrainingRDD(sc)
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
@@ -60,7 +60,7 @@ class XGBoostGeneralSuite extends Utils {
val customSparkContext = new SparkContext(sparkConf)
customSparkContext.setLogLevel("ERROR")
val eval = new EvalError()
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
val trainingRDD = buildTrainingRDD(customSparkContext)
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
@@ -113,7 +113,7 @@ class XGBoostGeneralSuite extends Utils {
}
test("test consistency of prediction functions with RDD") {
val trainingRDD = buildTrainingRDD()
val trainingRDD = buildTrainingRDD(sc)
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile)
val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)
val testCollection = testRDD.collect()
@@ -132,12 +132,22 @@ class XGBoostGeneralSuite extends Utils {
}
}
test("test eval functions with RDD") {
val trainingRDD = buildTrainingRDD(sc)
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
val evalFunc = new EvalError
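// exercise both modes: the model's built-in metric at iteration 5, and the customized metric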
xgBoostModel.eval(trainingRDD, "eval1", iter = 5, useExternalCache = false)
xgBoostModel.eval(trainingRDD, "eval2", evalFunc = evalFunc, useExternalCache = false)
}
test("test prediction functionality with empty partition") {
def buildEmptyRDD(sparkContext: Option[SparkContext] = None): RDD[SparkVector] = {
val sampleList = new ListBuffer[SparkVector]
sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
}
val trainingRDD = buildTrainingRDD()
val trainingRDD = buildTrainingRDD(sc)
val testRDD = buildEmptyRDD()
val tempDir = Files.createTempDirectory("xgboosttest-")
val tempFile = Files.createTempFile(tempDir, "", "")
@@ -149,7 +159,7 @@ class XGBoostGeneralSuite extends Utils {
test("test model consistency after save and load") {
val eval = new EvalError()
val trainingRDD = buildTrainingRDD()
val trainingRDD = buildTrainingRDD(sc)
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
@@ -176,7 +186,7 @@ class XGBoostGeneralSuite extends Utils {
val customSparkContext = new SparkContext(sparkConf)
customSparkContext.setLogLevel("ERROR")
// start another app
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
val trainingRDD = buildTrainingRDD(customSparkContext)
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic", "nthread" -> 6).toMap
intercept[IllegalArgumentException] {
@@ -194,7 +204,7 @@ class XGBoostGeneralSuite extends Utils {
sparkConf.registerKryoClasses(Array(classOf[Booster]))
val customSparkContext = new SparkContext(sparkConf)
customSparkContext.setLogLevel("ERROR")
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
val trainingRDD = buildTrainingRDD(customSparkContext)
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))