default eval func (#1574)

This commit is contained in:
Nan Zhu
2016-09-14 13:26:16 -04:00
committed by GitHub
parent 4733357278
commit bb388cbb31
6 changed files with 114 additions and 68 deletions

View File

@@ -0,0 +1,63 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
import org.apache.commons.logging.LogFactory
class EvalError extends EvalTrait {
val logger = LogFactory.getLog(classOf[EvalError])
private[xgboost4j] var evalMetric: String = "custom_error"
/**
* get evaluate metric
*
* @return evalMetric
*/
override def getMetric: String = evalMetric
/**
* evaluate with predicts and data
*
* @param predicts predictions as array
* @param dmat data matrix to evaluate
* @return result of the metric
*/
override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
var error: Float = 0f
var labels: Array[Float] = null
try {
labels = dmat.getLabel
} catch {
case ex: XGBoostError =>
logger.error(ex)
return -1f
}
val nrow: Int = predicts.length
for (i <- 0 until nrow) {
if (labels(i) == 0.0 && predicts(i)(0) > 0) {
error += 1
} else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
error += 1
}
}
error / labels.length
}
}

View File

@@ -19,9 +19,9 @@ package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfter, FunSuite}
trait SharedSparkContext extends FunSuite with BeforeAndAfter {
class SharedSparkContext extends FunSuite with BeforeAndAfter with Serializable {
protected implicit var sc: SparkContext = null
@transient protected implicit var sc: SparkContext = null
before {
// build SparkContext

View File

@@ -29,51 +29,9 @@ import org.apache.spark.mllib.linalg.{DenseVector, Vector => SparkVector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
trait Utils extends SharedSparkContext {
trait Utils extends Serializable {
protected val numWorkers = Runtime.getRuntime().availableProcessors()
protected class EvalError extends EvalTrait {
val logger = LogFactory.getLog(classOf[EvalError])
private[xgboost4j] var evalMetric: String = "custom_error"
/**
* get evaluate metric
*
* @return evalMetric
*/
override def getMetric: String = evalMetric
/**
* evaluate with predicts and data
*
* @param predicts predictions as array
* @param dmat data matrix to evaluate
* @return result of the metric
*/
override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
var error: Float = 0f
var labels: Array[Float] = null
try {
labels = dmat.getLabel
} catch {
case ex: XGBoostError =>
logger.error(ex)
return -1f
}
val nrow: Int = predicts.length
for (i <- 0 until nrow) {
if (labels(i) == 0.0 && predicts(i)(0) > 0) {
error += 1
} else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
error += 1
}
}
error / labels.length
}
}
protected def loadLabelPoints(filePath: String): List[LabeledPoint] = {
val file = Source.fromFile(new File(filePath))
val sampleList = new ListBuffer[LabeledPoint]
@@ -100,8 +58,8 @@ trait Utils extends SharedSparkContext {
LabeledPoint(label, sv)
}
protected def buildTrainingRDD(sparkContext: Option[SparkContext] = None): RDD[LabeledPoint] = {
protected def buildTrainingRDD(sparkContext: SparkContext): RDD[LabeledPoint] = {
val sampleList = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile)
sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
sparkContext.parallelize(sampleList, numWorkers)
}
}

View File

@@ -30,7 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql._
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}
class XGBoostDFSuite extends Utils {
class XGBoostDFSuite extends SharedSparkContext with Utils {
private def loadRow(filePath: String): List[Row] = {
val file = Source.fromFile(new File(filePath))
@@ -58,7 +58,7 @@ class XGBoostDFSuite extends Utils {
test("test consistency between training with dataframe and RDD") {
val trainingDF = buildTrainingDataframe()
val trainingRDD = buildTrainingRDD()
val trainingRDD = buildTrainingRDD(sc)
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
"objective" -> "binary:logistic").toMap
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,

View File

@@ -29,10 +29,10 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
class XGBoostGeneralSuite extends Utils {
class XGBoostGeneralSuite extends SharedSparkContext with Utils {
test("build RDD containing boosters with the specified worker number") {
val trainingRDD = buildTrainingRDD()
val trainingRDD = buildTrainingRDD(sc)
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
@@ -60,7 +60,7 @@ class XGBoostGeneralSuite extends Utils {
val customSparkContext = new SparkContext(sparkConf)
customSparkContext.setLogLevel("ERROR")
val eval = new EvalError()
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
val trainingRDD = buildTrainingRDD(customSparkContext)
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
@@ -113,7 +113,7 @@ class XGBoostGeneralSuite extends Utils {
}
test("test consistency of prediction functions with RDD") {
val trainingRDD = buildTrainingRDD()
val trainingRDD = buildTrainingRDD(sc)
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile)
val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)
val testCollection = testRDD.collect()
@@ -132,12 +132,22 @@ class XGBoostGeneralSuite extends Utils {
}
}
test("test eval functions with RDD") {
val trainingRDD = buildTrainingRDD(sc)
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
val evalFunc = new EvalError
xgBoostModel.eval(trainingRDD, "eval1", iter = 5, useExternalCache = false)
xgBoostModel.eval(trainingRDD, "eval2", evalFunc = evalFunc, useExternalCache = false)
}
test("test prediction functionality with empty partition") {
def buildEmptyRDD(sparkContext: Option[SparkContext] = None): RDD[SparkVector] = {
val sampleList = new ListBuffer[SparkVector]
sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
}
val trainingRDD = buildTrainingRDD()
val trainingRDD = buildTrainingRDD(sc)
val testRDD = buildEmptyRDD()
val tempDir = Files.createTempDirectory("xgboosttest-")
val tempFile = Files.createTempFile(tempDir, "", "")
@@ -149,7 +159,7 @@ class XGBoostGeneralSuite extends Utils {
test("test model consistency after save and load") {
val eval = new EvalError()
val trainingRDD = buildTrainingRDD()
val trainingRDD = buildTrainingRDD(sc)
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
@@ -176,7 +186,7 @@ class XGBoostGeneralSuite extends Utils {
val customSparkContext = new SparkContext(sparkConf)
customSparkContext.setLogLevel("ERROR")
// start another app
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
val trainingRDD = buildTrainingRDD(customSparkContext)
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic", "nthread" -> 6).toMap
intercept[IllegalArgumentException] {
@@ -194,7 +204,7 @@ class XGBoostGeneralSuite extends Utils {
sparkConf.registerKryoClasses(Array(classOf[Booster]))
val customSparkContext = new SparkContext(sparkConf)
customSparkContext.setLogLevel("ERROR")
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
val trainingRDD = buildTrainingRDD(customSparkContext)
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))