default eval func (#1574)
This commit is contained in:
parent
4733357278
commit
bb388cbb31
@ -42,16 +42,20 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
|
||||
/**
|
||||
* evaluate XGBoostModel with a RDD-wrapped dataset
|
||||
*
|
||||
* NOTE: you have to specify value of either eval or iter; when you specify both, this adopts
|
||||
* the default eval metric of model
|
||||
*
|
||||
* @param evalDataset the dataset used for evaluation
|
||||
* @param eval the customized evaluation function, can be null for using default in the model
|
||||
* @param evalName the name of evaluation
|
||||
* @param evalFunc the customized evaluation function, null by default to use the default metric
|
||||
* of model
|
||||
* @param iter the current iteration, -1 to be null to use customized evaluation functions
|
||||
* @param useExternalCache if use external cache
|
||||
* @return the average metric over all partitions
|
||||
*/
|
||||
def eval(
|
||||
evalDataset: RDD[LabeledPoint],
|
||||
eval: EvalTrait,
|
||||
evalName: String,
|
||||
useExternalCache: Boolean = false): String = {
|
||||
def eval(evalDataset: RDD[LabeledPoint], evalName: String, evalFunc: EvalTrait = null,
|
||||
iter: Int = -1, useExternalCache: Boolean = false): String = {
|
||||
require(evalFunc != null || iter != -1, "you have to specify value of either eval or iter")
|
||||
val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
|
||||
val appName = evalDataset.context.appName
|
||||
val allEvalMetrics = evalDataset.mapPartitions {
|
||||
@ -62,20 +66,29 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
|
||||
import DataUtils._
|
||||
val cacheFileName = {
|
||||
if (useExternalCache) {
|
||||
s"$appName-deval_cache-${TaskContext.getPartitionId()}"
|
||||
s"$appName-${TaskContext.get().stageId()}-deval_cache-${TaskContext.getPartitionId()}"
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
|
||||
val predictions = broadcastBooster.value.predict(dMatrix)
|
||||
Rabit.shutdown()
|
||||
Iterator(Some(eval.eval(predictions, dMatrix)))
|
||||
if (iter == -1) {
|
||||
val predictions = broadcastBooster.value.predict(dMatrix)
|
||||
Rabit.shutdown()
|
||||
Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
|
||||
} else {
|
||||
val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
|
||||
val Array(evName, predNumeric) = predStr.split(":")
|
||||
Rabit.shutdown()
|
||||
Iterator(Some(evName, predNumeric.toFloat))
|
||||
}
|
||||
} else {
|
||||
Iterator(None)
|
||||
}
|
||||
}.filter(_.isDefined).collect()
|
||||
s"$evalName-${eval.getMetric} = ${allEvalMetrics.map(_.get).sum / allEvalMetrics.length}"
|
||||
val evalPrefix = allEvalMetrics.map(_.get._1).head
|
||||
val evalMetricMean = allEvalMetrics.map(_.get._2).sum / allEvalMetrics.length
|
||||
s"$evalPrefix = $evalMetricMean"
|
||||
}
|
||||
|
||||
/**
|
||||
@ -176,6 +189,7 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
|
||||
/**
|
||||
* produces the prediction results and append as an additional column in the original dataset
|
||||
* NOTE: the prediction results is kept as the original format of xgboost
|
||||
*
|
||||
* @return the original dataframe with an additional column containing prediction results
|
||||
*/
|
||||
override def transform(testSet: Dataset[_]): DataFrame = {
|
||||
@ -186,6 +200,7 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
|
||||
* produces the prediction results and append as an additional column in the original dataset
|
||||
* NOTE: the prediction results is transformed by applying the transformation function
|
||||
* predictResultTrans to the original xgboost output
|
||||
*
|
||||
* @param predictResultTrans the function to transform xgboost output to the expected format
|
||||
* @return the original dataframe with an additional column containing prediction results
|
||||
*/
|
||||
|
||||
@ -0,0 +1,63 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
|
||||
import org.apache.commons.logging.LogFactory
|
||||
|
||||
class EvalError extends EvalTrait {
|
||||
|
||||
val logger = LogFactory.getLog(classOf[EvalError])
|
||||
|
||||
private[xgboost4j] var evalMetric: String = "custom_error"
|
||||
|
||||
/**
|
||||
* get evaluate metric
|
||||
*
|
||||
* @return evalMetric
|
||||
*/
|
||||
override def getMetric: String = evalMetric
|
||||
|
||||
/**
|
||||
* evaluate with predicts and data
|
||||
*
|
||||
* @param predicts predictions as array
|
||||
* @param dmat data matrix to evaluate
|
||||
* @return result of the metric
|
||||
*/
|
||||
override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
|
||||
var error: Float = 0f
|
||||
var labels: Array[Float] = null
|
||||
try {
|
||||
labels = dmat.getLabel
|
||||
} catch {
|
||||
case ex: XGBoostError =>
|
||||
logger.error(ex)
|
||||
return -1f
|
||||
}
|
||||
val nrow: Int = predicts.length
|
||||
for (i <- 0 until nrow) {
|
||||
if (labels(i) == 0.0 && predicts(i)(0) > 0) {
|
||||
error += 1
|
||||
} else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
|
||||
error += 1
|
||||
}
|
||||
}
|
||||
error / labels.length
|
||||
}
|
||||
}
|
||||
@ -19,9 +19,9 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
import org.scalatest.{BeforeAndAfter, FunSuite}
|
||||
|
||||
trait SharedSparkContext extends FunSuite with BeforeAndAfter {
|
||||
class SharedSparkContext extends FunSuite with BeforeAndAfter with Serializable {
|
||||
|
||||
protected implicit var sc: SparkContext = null
|
||||
@transient protected implicit var sc: SparkContext = null
|
||||
|
||||
before {
|
||||
// build SparkContext
|
||||
|
||||
@ -29,51 +29,9 @@ import org.apache.spark.mllib.linalg.{DenseVector, Vector => SparkVector}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
trait Utils extends SharedSparkContext {
|
||||
trait Utils extends Serializable {
|
||||
protected val numWorkers = Runtime.getRuntime().availableProcessors()
|
||||
|
||||
protected class EvalError extends EvalTrait {
|
||||
|
||||
val logger = LogFactory.getLog(classOf[EvalError])
|
||||
|
||||
private[xgboost4j] var evalMetric: String = "custom_error"
|
||||
|
||||
/**
|
||||
* get evaluate metric
|
||||
*
|
||||
* @return evalMetric
|
||||
*/
|
||||
override def getMetric: String = evalMetric
|
||||
|
||||
/**
|
||||
* evaluate with predicts and data
|
||||
*
|
||||
* @param predicts predictions as array
|
||||
* @param dmat data matrix to evaluate
|
||||
* @return result of the metric
|
||||
*/
|
||||
override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
|
||||
var error: Float = 0f
|
||||
var labels: Array[Float] = null
|
||||
try {
|
||||
labels = dmat.getLabel
|
||||
} catch {
|
||||
case ex: XGBoostError =>
|
||||
logger.error(ex)
|
||||
return -1f
|
||||
}
|
||||
val nrow: Int = predicts.length
|
||||
for (i <- 0 until nrow) {
|
||||
if (labels(i) == 0.0 && predicts(i)(0) > 0) {
|
||||
error += 1
|
||||
} else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
|
||||
error += 1
|
||||
}
|
||||
}
|
||||
error / labels.length
|
||||
}
|
||||
}
|
||||
|
||||
protected def loadLabelPoints(filePath: String): List[LabeledPoint] = {
|
||||
val file = Source.fromFile(new File(filePath))
|
||||
val sampleList = new ListBuffer[LabeledPoint]
|
||||
@ -100,8 +58,8 @@ trait Utils extends SharedSparkContext {
|
||||
LabeledPoint(label, sv)
|
||||
}
|
||||
|
||||
protected def buildTrainingRDD(sparkContext: Option[SparkContext] = None): RDD[LabeledPoint] = {
|
||||
protected def buildTrainingRDD(sparkContext: SparkContext): RDD[LabeledPoint] = {
|
||||
val sampleList = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile)
|
||||
sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
|
||||
sparkContext.parallelize(sampleList, numWorkers)
|
||||
}
|
||||
}
|
||||
|
||||
@ -30,7 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}
|
||||
|
||||
class XGBoostDFSuite extends Utils {
|
||||
class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
|
||||
private def loadRow(filePath: String): List[Row] = {
|
||||
val file = Source.fromFile(new File(filePath))
|
||||
@ -58,7 +58,7 @@ class XGBoostDFSuite extends Utils {
|
||||
|
||||
test("test consistency between training with dataframe and RDD") {
|
||||
val trainingDF = buildTrainingDataframe()
|
||||
val trainingRDD = buildTrainingRDD()
|
||||
val trainingRDD = buildTrainingRDD(sc)
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
|
||||
"objective" -> "binary:logistic").toMap
|
||||
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||
|
||||
@ -29,10 +29,10 @@ import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
|
||||
class XGBoostGeneralSuite extends Utils {
|
||||
class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
||||
|
||||
test("build RDD containing boosters with the specified worker number") {
|
||||
val trainingRDD = buildTrainingRDD()
|
||||
val trainingRDD = buildTrainingRDD(sc)
|
||||
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||
import DataUtils._
|
||||
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||
@ -60,7 +60,7 @@ class XGBoostGeneralSuite extends Utils {
|
||||
val customSparkContext = new SparkContext(sparkConf)
|
||||
customSparkContext.setLogLevel("ERROR")
|
||||
val eval = new EvalError()
|
||||
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
|
||||
val trainingRDD = buildTrainingRDD(customSparkContext)
|
||||
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||
import DataUtils._
|
||||
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||
@ -113,7 +113,7 @@ class XGBoostGeneralSuite extends Utils {
|
||||
}
|
||||
|
||||
test("test consistency of prediction functions with RDD") {
|
||||
val trainingRDD = buildTrainingRDD()
|
||||
val trainingRDD = buildTrainingRDD(sc)
|
||||
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile)
|
||||
val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)
|
||||
val testCollection = testRDD.collect()
|
||||
@ -132,12 +132,22 @@ class XGBoostGeneralSuite extends Utils {
|
||||
}
|
||||
}
|
||||
|
||||
test("test eval functions with RDD") {
|
||||
val trainingRDD = buildTrainingRDD(sc)
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
||||
"objective" -> "binary:logistic").toMap
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
||||
val evalFunc = new EvalError
|
||||
xgBoostModel.eval(trainingRDD, "eval1", iter = 5, useExternalCache = false)
|
||||
xgBoostModel.eval(trainingRDD, "eval2", evalFunc = evalFunc, useExternalCache = false)
|
||||
}
|
||||
|
||||
test("test prediction functionality with empty partition") {
|
||||
def buildEmptyRDD(sparkContext: Option[SparkContext] = None): RDD[SparkVector] = {
|
||||
val sampleList = new ListBuffer[SparkVector]
|
||||
sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
|
||||
}
|
||||
val trainingRDD = buildTrainingRDD()
|
||||
val trainingRDD = buildTrainingRDD(sc)
|
||||
val testRDD = buildEmptyRDD()
|
||||
val tempDir = Files.createTempDirectory("xgboosttest-")
|
||||
val tempFile = Files.createTempFile(tempDir, "", "")
|
||||
@ -149,7 +159,7 @@ class XGBoostGeneralSuite extends Utils {
|
||||
|
||||
test("test model consistency after save and load") {
|
||||
val eval = new EvalError()
|
||||
val trainingRDD = buildTrainingRDD()
|
||||
val trainingRDD = buildTrainingRDD(sc)
|
||||
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||
import DataUtils._
|
||||
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||
@ -176,7 +186,7 @@ class XGBoostGeneralSuite extends Utils {
|
||||
val customSparkContext = new SparkContext(sparkConf)
|
||||
customSparkContext.setLogLevel("ERROR")
|
||||
// start another app
|
||||
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
|
||||
val trainingRDD = buildTrainingRDD(customSparkContext)
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
||||
"objective" -> "binary:logistic", "nthread" -> 6).toMap
|
||||
intercept[IllegalArgumentException] {
|
||||
@ -194,7 +204,7 @@ class XGBoostGeneralSuite extends Utils {
|
||||
sparkConf.registerKryoClasses(Array(classOf[Booster]))
|
||||
val customSparkContext = new SparkContext(sparkConf)
|
||||
customSparkContext.setLogLevel("ERROR")
|
||||
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
|
||||
val trainingRDD = buildTrainingRDD(customSparkContext)
|
||||
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||
import DataUtils._
|
||||
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user