default eval func (#1574)

commit bb388cbb31 (parent 4733357278)
XGBoostModel.scala:

@@ -42,16 +42,20 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializable
   /**
    * evaluate XGBoostModel with a RDD-wrapped dataset
    *
+   * NOTE: you have to specify value of either eval or iter; when you specify both, this adopts
+   * the default eval metric of model
+   *
    * @param evalDataset the dataset used for evaluation
-   * @param eval the customized evaluation function, can be null for using default in the model
+   * @param evalName the name of evaluation
+   * @param evalFunc the customized evaluation function, null by default to use the default metric
+   *                 of model
+   * @param iter the current iteration, -1 to be null to use customized evaluation functions
    * @param useExternalCache if use external cache
    * @return the average metric over all partitions
    */
-  def eval(
-      evalDataset: RDD[LabeledPoint],
-      eval: EvalTrait,
-      evalName: String,
-      useExternalCache: Boolean = false): String = {
+  def eval(evalDataset: RDD[LabeledPoint], evalName: String, evalFunc: EvalTrait = null,
+      iter: Int = -1, useExternalCache: Boolean = false): String = {
+    require(evalFunc != null || iter != -1, "you have to specify value of either eval or iter")
     val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
     val appName = evalDataset.context.appName
     val allEvalMetrics = evalDataset.mapPartitions {
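For orientation (not part of the diff itself): the reworked signature supports two call styles, mirrored by the new test added further down. A minimal sketch, assuming a trained `xgBoostModel` and an `evalDataset: RDD[LabeledPoint]`:

```scala
// Built-in metric path: pass the current iteration; the model's default
// eval metric is used via Booster.evalSet.
val builtIn = xgBoostModel.eval(evalDataset, "eval1", iter = 5)

// Custom metric path: pass an EvalTrait implementation; iter stays at its default -1.
val custom = xgBoostModel.eval(evalDataset, "eval2", evalFunc = new EvalError)
```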
@@ -62,20 +66,29 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializable
       import DataUtils._
       val cacheFileName = {
         if (useExternalCache) {
-          s"$appName-deval_cache-${TaskContext.getPartitionId()}"
+          s"$appName-${TaskContext.get().stageId()}-deval_cache-${TaskContext.getPartitionId()}"
         } else {
           null
         }
       }
       val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
-      val predictions = broadcastBooster.value.predict(dMatrix)
-      Rabit.shutdown()
-      Iterator(Some(eval.eval(predictions, dMatrix)))
+      if (iter == -1) {
+        val predictions = broadcastBooster.value.predict(dMatrix)
+        Rabit.shutdown()
+        Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
+      } else {
+        val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
+        val Array(evName, predNumeric) = predStr.split(":")
+        Rabit.shutdown()
+        Iterator(Some(evName, predNumeric.toFloat))
+      }
     } else {
       Iterator(None)
     }
   }.filter(_.isDefined).collect()
-    s"$evalName-${eval.getMetric} = ${allEvalMetrics.map(_.get).sum / allEvalMetrics.length}"
+    val evalPrefix = allEvalMetrics.map(_.get._1).head
+    val evalMetricMean = allEvalMetrics.map(_.get._2).sum / allEvalMetrics.length
+    s"$evalPrefix = $evalMetricMean"
   }

   /**
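A note on the built-in path above: `Booster.evalSet` returns a log-style string, and the `split(":")` pulls it apart into a name part and a numeric part. A hedged illustration with an assumed sample string (the exact output format comes from xgboost's eval logging, not from this diff):

```scala
// Assumed shape of an evalSet result for iteration 5, set name "eval1",
// and the model's default "error" metric:
val predStr = "[5]\teval1-error:0.022263"
val Array(evName, predNumeric) = predStr.split(":")
// evName == "[5]\teval1-error"      -> becomes the prefix of the result string
// predNumeric.toFloat == 0.022263f  -> averaged across partitions
```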
@@ -176,6 +189,7 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializable
   /**
    * produces the prediction results and append as an additional column in the original dataset
    * NOTE: the prediction results is kept as the original format of xgboost
+   *
    * @return the original dataframe with an additional column containing prediction results
    */
   override def transform(testSet: Dataset[_]): DataFrame = {
@@ -186,6 +200,7 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializable
    * produces the prediction results and append as an additional column in the original dataset
    * NOTE: the prediction results is transformed by applying the transformation function
    * predictResultTrans to the original xgboost output
+   *
    * @param predictResultTrans the function to transform xgboost output to the expected format
    * @return the original dataframe with an additional column containing prediction results
    */
EvalError.scala (new file):

@@ -0,0 +1,63 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import ml.dmlc.xgboost4j.java.XGBoostError
+import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
+import org.apache.commons.logging.LogFactory
+
+class EvalError extends EvalTrait {
+
+  val logger = LogFactory.getLog(classOf[EvalError])
+
+  private[xgboost4j] var evalMetric: String = "custom_error"
+
+  /**
+   * get evaluate metric
+   *
+   * @return evalMetric
+   */
+  override def getMetric: String = evalMetric
+
+  /**
+   * evaluate with predicts and data
+   *
+   * @param predicts predictions as array
+   * @param dmat data matrix to evaluate
+   * @return result of the metric
+   */
+  override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
+    var error: Float = 0f
+    var labels: Array[Float] = null
+    try {
+      labels = dmat.getLabel
+    } catch {
+      case ex: XGBoostError =>
+        logger.error(ex)
+        return -1f
+    }
+    val nrow: Int = predicts.length
+    for (i <- 0 until nrow) {
+      if (labels(i) == 0.0 && predicts(i)(0) > 0) {
+        error += 1
+      } else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
+        error += 1
+      }
+    }
+    error / labels.length
+  }
+}
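To make the metric concrete, here is a small standalone sketch that mirrors `EvalError`'s counting rule on plain arrays (the helper name and sample values are illustrative, assuming margin outputs where a prediction above 0 means class 1):

```scala
// Mirrors the loop in EvalError.eval, without the DMatrix plumbing.
def customError(predicts: Array[Array[Float]], labels: Array[Float]): Float = {
  var error = 0f
  for (i <- predicts.indices) {
    val predictedPositive = predicts(i)(0) > 0
    // count a row as wrong when the margin disagrees with the label
    if ((labels(i) == 0.0f && predictedPositive) ||
        (labels(i) == 1.0f && !predictedPositive)) {
      error += 1
    }
  }
  error / labels.length
}

// labels = [0, 1, 1], margins = [[0.5], [-0.2], [0.3]]
// row 0: label 0, margin > 0  -> wrong
// row 1: label 1, margin <= 0 -> wrong
// row 2: label 1, margin > 0  -> correct
// customError(...) == 2f / 3   (~0.667)
```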
SharedSparkContext.scala:

@@ -19,9 +19,9 @@ package ml.dmlc.xgboost4j.scala.spark
 import org.apache.spark.{SparkConf, SparkContext}
 import org.scalatest.{BeforeAndAfter, FunSuite}

-trait SharedSparkContext extends FunSuite with BeforeAndAfter {
+class SharedSparkContext extends FunSuite with BeforeAndAfter with Serializable {

-  protected implicit var sc: SparkContext = null
+  @transient protected implicit var sc: SparkContext = null

   before {
     // build SparkContext
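The trait-to-class change plus `Serializable` and `@transient` is about closure capture: test bodies reference suite members, so Spark may serialize the suite instance into task closures, and the SparkContext must be kept out of that serialized state. A minimal sketch of the pattern (the class name is illustrative, not from the diff):

```scala
import org.apache.spark.SparkContext

// The suite can be captured by an RDD closure because it is Serializable,
// while the non-serializable SparkContext is excluded via @transient.
class ClosureSafeSuite extends Serializable {
  @transient protected implicit var sc: SparkContext = _
}
```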
Utils.scala:

@@ -29,51 +29,9 @@ import org.apache.spark.mllib.linalg.{DenseVector, Vector => SparkVector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD

-trait Utils extends SharedSparkContext {
+trait Utils extends Serializable {
   protected val numWorkers = Runtime.getRuntime().availableProcessors()

-  protected class EvalError extends EvalTrait {
-
-    val logger = LogFactory.getLog(classOf[EvalError])
-
-    private[xgboost4j] var evalMetric: String = "custom_error"
-
-    /**
-     * get evaluate metric
-     *
-     * @return evalMetric
-     */
-    override def getMetric: String = evalMetric
-
-    /**
-     * evaluate with predicts and data
-     *
-     * @param predicts predictions as array
-     * @param dmat data matrix to evaluate
-     * @return result of the metric
-     */
-    override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
-      var error: Float = 0f
-      var labels: Array[Float] = null
-      try {
-        labels = dmat.getLabel
-      } catch {
-        case ex: XGBoostError =>
-          logger.error(ex)
-          return -1f
-      }
-      val nrow: Int = predicts.length
-      for (i <- 0 until nrow) {
-        if (labels(i) == 0.0 && predicts(i)(0) > 0) {
-          error += 1
-        } else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
-          error += 1
-        }
-      }
-      error / labels.length
-    }
-  }
-
   protected def loadLabelPoints(filePath: String): List[LabeledPoint] = {
     val file = Source.fromFile(new File(filePath))
     val sampleList = new ListBuffer[LabeledPoint]
@@ -100,8 +58,8 @@ trait Utils extends SharedSparkContext {
     LabeledPoint(label, sv)
   }

-  protected def buildTrainingRDD(sparkContext: Option[SparkContext] = None): RDD[LabeledPoint] = {
+  protected def buildTrainingRDD(sparkContext: SparkContext): RDD[LabeledPoint] = {
     val sampleList = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile)
-    sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
+    sparkContext.parallelize(sampleList, numWorkers)
   }
 }
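The helper drops the `Option[SparkContext]` fallback in favor of an explicit context argument. Both call styles, side by side, as a sketch assuming the shared `sc` and a `customSparkContext` as in the test hunks below:

```scala
// Before: buildTrainingRDD() fell back to the shared sc via Option.getOrElse,
//         buildTrainingRDD(Some(customSparkContext)) overrode it.
// After: the context is always passed explicitly.
val fromShared = buildTrainingRDD(sc)
val fromCustom = buildTrainingRDD(customSparkContext)
```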
XGBoostDFSuite.scala:

@@ -30,7 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.sql._
 import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}

-class XGBoostDFSuite extends Utils {
+class XGBoostDFSuite extends SharedSparkContext with Utils {

   private def loadRow(filePath: String): List[Row] = {
     val file = Source.fromFile(new File(filePath))
@@ -58,7 +58,7 @@ class XGBoostDFSuite {

   test("test consistency between training with dataframe and RDD") {
     val trainingDF = buildTrainingDataframe()
-    val trainingRDD = buildTrainingRDD()
+    val trainingRDD = buildTrainingRDD(sc)
     val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
       "objective" -> "binary:logistic").toMap
     val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
XGBoostGeneralSuite.scala:

@@ -29,10 +29,10 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
 import org.apache.spark.{SparkConf, SparkContext}

-class XGBoostGeneralSuite extends Utils {
+class XGBoostGeneralSuite extends SharedSparkContext with Utils {

   test("build RDD containing boosters with the specified worker number") {
-    val trainingRDD = buildTrainingRDD()
+    val trainingRDD = buildTrainingRDD(sc)
     val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
     import DataUtils._
     val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
@@ -60,7 +60,7 @@ class XGBoostGeneralSuite {
     val customSparkContext = new SparkContext(sparkConf)
     customSparkContext.setLogLevel("ERROR")
     val eval = new EvalError()
-    val trainingRDD = buildTrainingRDD(Some(customSparkContext))
+    val trainingRDD = buildTrainingRDD(customSparkContext)
     val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
     import DataUtils._
     val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
@@ -113,7 +113,7 @@ class XGBoostGeneralSuite {
   }

   test("test consistency of prediction functions with RDD") {
-    val trainingRDD = buildTrainingRDD()
+    val trainingRDD = buildTrainingRDD(sc)
     val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile)
     val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)
     val testCollection = testRDD.collect()
@@ -132,12 +132,22 @@ class XGBoostGeneralSuite {
     }
   }

+  test("test eval functions with RDD") {
+    val trainingRDD = buildTrainingRDD(sc)
+    val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
+      "objective" -> "binary:logistic").toMap
+    val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
+    val evalFunc = new EvalError
+    xgBoostModel.eval(trainingRDD, "eval1", iter = 5, useExternalCache = false)
+    xgBoostModel.eval(trainingRDD, "eval2", evalFunc = evalFunc, useExternalCache = false)
+  }
+
   test("test prediction functionality with empty partition") {
     def buildEmptyRDD(sparkContext: Option[SparkContext] = None): RDD[SparkVector] = {
       val sampleList = new ListBuffer[SparkVector]
       sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
     }
-    val trainingRDD = buildTrainingRDD()
+    val trainingRDD = buildTrainingRDD(sc)
     val testRDD = buildEmptyRDD()
     val tempDir = Files.createTempDirectory("xgboosttest-")
     val tempFile = Files.createTempFile(tempDir, "", "")
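For reference, hedged sketches of what the two `eval` calls in the new test return, given the string-building code in XGBoostModel above (the numeric values are invented):

```scala
xgBoostModel.eval(trainingRDD, "eval1", iter = 5, useExternalCache = false)
// -> something like "[5]\teval1-error = 0.0225": the prefix is whatever name
//    Booster.evalSet printed, and the value is the mean over partitions.

xgBoostModel.eval(trainingRDD, "eval2", evalFunc = evalFunc, useExternalCache = false)
// -> something like "eval2 = 0.0225": the prefix is just evalName.
```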
@@ -149,7 +159,7 @@ class XGBoostGeneralSuite {

   test("test model consistency after save and load") {
     val eval = new EvalError()
-    val trainingRDD = buildTrainingRDD()
+    val trainingRDD = buildTrainingRDD(sc)
     val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
     import DataUtils._
     val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
@@ -176,7 +186,7 @@ class XGBoostGeneralSuite {
     val customSparkContext = new SparkContext(sparkConf)
     customSparkContext.setLogLevel("ERROR")
     // start another app
-    val trainingRDD = buildTrainingRDD(Some(customSparkContext))
+    val trainingRDD = buildTrainingRDD(customSparkContext)
     val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
       "objective" -> "binary:logistic", "nthread" -> 6).toMap
     intercept[IllegalArgumentException] {
@@ -194,7 +204,7 @@ class XGBoostGeneralSuite {
     sparkConf.registerKryoClasses(Array(classOf[Booster]))
     val customSparkContext = new SparkContext(sparkConf)
     customSparkContext.setLogLevel("ERROR")
-    val trainingRDD = buildTrainingRDD(Some(customSparkContext))
+    val trainingRDD = buildTrainingRDD(customSparkContext)
     val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
     import DataUtils._
     val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))