default eval func (#1574)
This commit is contained in:
@@ -0,0 +1,63 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
|
||||
import org.apache.commons.logging.LogFactory
|
||||
|
||||
/**
 * Custom evaluation metric: the binary classification error rate, i.e. the
 * fraction of instances whose raw prediction sign disagrees with the 0/1 label.
 */
class EvalError extends EvalTrait {

  val logger = LogFactory.getLog(classOf[EvalError])

  private[xgboost4j] var evalMetric: String = "custom_error"

  /**
   * get evaluate metric
   *
   * @return evalMetric
   */
  override def getMetric: String = evalMetric

  /**
   * Evaluate predictions against the labels stored in the data matrix.
   *
   * A prediction counts as an error when its raw score is positive for a
   * 0-label or non-positive for a 1-label.
   *
   * @param predicts predictions as array; one row per instance, column 0 is the raw score
   * @param dmat data matrix holding the ground-truth labels
   * @return error rate, or -1f if the labels cannot be read from the matrix
   */
  override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
    // Fetch labels up front; keep the original -1f sentinel on native errors
    // rather than letting the exception propagate.
    val labels: Array[Float] =
      try {
        dmat.getLabel
      } catch {
        case ex: XGBoostError =>
          logger.error(ex)
          return -1f
      }
    // Count sign/label disagreements with `count` instead of mutating a
    // `var` accumulator in a manual loop.
    val errors = predicts.indices.count { i =>
      (labels(i) == 0.0 && predicts(i)(0) > 0) ||
        (labels(i) == 1.0 && predicts(i)(0) <= 0)
    }
    // Float division: yields NaN on an empty label array, same as the
    // original implementation.
    errors.toFloat / labels.length
  }
}
|
||||
@@ -19,9 +19,9 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
import org.scalatest.{BeforeAndAfter, FunSuite}
|
||||
|
||||
trait SharedSparkContext extends FunSuite with BeforeAndAfter {
|
||||
class SharedSparkContext extends FunSuite with BeforeAndAfter with Serializable {
|
||||
|
||||
protected implicit var sc: SparkContext = null
|
||||
@transient protected implicit var sc: SparkContext = null
|
||||
|
||||
before {
|
||||
// build SparkContext
|
||||
|
||||
@@ -29,51 +29,9 @@ import org.apache.spark.mllib.linalg.{DenseVector, Vector => SparkVector}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
trait Utils extends SharedSparkContext {
|
||||
trait Utils extends Serializable {
|
||||
protected val numWorkers = Runtime.getRuntime().availableProcessors()
|
||||
|
||||
protected class EvalError extends EvalTrait {
|
||||
|
||||
val logger = LogFactory.getLog(classOf[EvalError])
|
||||
|
||||
private[xgboost4j] var evalMetric: String = "custom_error"
|
||||
|
||||
/**
|
||||
* get evaluate metric
|
||||
*
|
||||
* @return evalMetric
|
||||
*/
|
||||
override def getMetric: String = evalMetric
|
||||
|
||||
/**
|
||||
* evaluate with predicts and data
|
||||
*
|
||||
* @param predicts predictions as array
|
||||
* @param dmat data matrix to evaluate
|
||||
* @return result of the metric
|
||||
*/
|
||||
override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
|
||||
var error: Float = 0f
|
||||
var labels: Array[Float] = null
|
||||
try {
|
||||
labels = dmat.getLabel
|
||||
} catch {
|
||||
case ex: XGBoostError =>
|
||||
logger.error(ex)
|
||||
return -1f
|
||||
}
|
||||
val nrow: Int = predicts.length
|
||||
for (i <- 0 until nrow) {
|
||||
if (labels(i) == 0.0 && predicts(i)(0) > 0) {
|
||||
error += 1
|
||||
} else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
|
||||
error += 1
|
||||
}
|
||||
}
|
||||
error / labels.length
|
||||
}
|
||||
}
|
||||
|
||||
protected def loadLabelPoints(filePath: String): List[LabeledPoint] = {
|
||||
val file = Source.fromFile(new File(filePath))
|
||||
val sampleList = new ListBuffer[LabeledPoint]
|
||||
@@ -100,8 +58,8 @@ trait Utils extends SharedSparkContext {
|
||||
LabeledPoint(label, sv)
|
||||
}
|
||||
|
||||
protected def buildTrainingRDD(sparkContext: Option[SparkContext] = None): RDD[LabeledPoint] = {
|
||||
protected def buildTrainingRDD(sparkContext: SparkContext): RDD[LabeledPoint] = {
|
||||
val sampleList = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile)
|
||||
sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
|
||||
sparkContext.parallelize(sampleList, numWorkers)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}
|
||||
|
||||
class XGBoostDFSuite extends Utils {
|
||||
class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
|
||||
private def loadRow(filePath: String): List[Row] = {
|
||||
val file = Source.fromFile(new File(filePath))
|
||||
@@ -58,7 +58,7 @@ class XGBoostDFSuite extends Utils {
|
||||
|
||||
test("test consistency between training with dataframe and RDD") {
|
||||
val trainingDF = buildTrainingDataframe()
|
||||
val trainingRDD = buildTrainingRDD()
|
||||
val trainingRDD = buildTrainingRDD(sc)
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
|
||||
"objective" -> "binary:logistic").toMap
|
||||
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||
|
||||
@@ -29,10 +29,10 @@ import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
|
||||
class XGBoostGeneralSuite extends Utils {
|
||||
class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
||||
|
||||
test("build RDD containing boosters with the specified worker number") {
|
||||
val trainingRDD = buildTrainingRDD()
|
||||
val trainingRDD = buildTrainingRDD(sc)
|
||||
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||
import DataUtils._
|
||||
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||
@@ -60,7 +60,7 @@ class XGBoostGeneralSuite extends Utils {
|
||||
val customSparkContext = new SparkContext(sparkConf)
|
||||
customSparkContext.setLogLevel("ERROR")
|
||||
val eval = new EvalError()
|
||||
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
|
||||
val trainingRDD = buildTrainingRDD(customSparkContext)
|
||||
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||
import DataUtils._
|
||||
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||
@@ -113,7 +113,7 @@ class XGBoostGeneralSuite extends Utils {
|
||||
}
|
||||
|
||||
test("test consistency of prediction functions with RDD") {
|
||||
val trainingRDD = buildTrainingRDD()
|
||||
val trainingRDD = buildTrainingRDD(sc)
|
||||
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile)
|
||||
val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)
|
||||
val testCollection = testRDD.collect()
|
||||
@@ -132,12 +132,22 @@ class XGBoostGeneralSuite extends Utils {
|
||||
}
|
||||
}
|
||||
|
||||
test("test eval functions with RDD") {
|
||||
val trainingRDD = buildTrainingRDD(sc)
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
||||
"objective" -> "binary:logistic").toMap
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
||||
val evalFunc = new EvalError
|
||||
xgBoostModel.eval(trainingRDD, "eval1", iter = 5, useExternalCache = false)
|
||||
xgBoostModel.eval(trainingRDD, "eval2", evalFunc = evalFunc, useExternalCache = false)
|
||||
}
|
||||
|
||||
test("test prediction functionality with empty partition") {
|
||||
def buildEmptyRDD(sparkContext: Option[SparkContext] = None): RDD[SparkVector] = {
|
||||
val sampleList = new ListBuffer[SparkVector]
|
||||
sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
|
||||
}
|
||||
val trainingRDD = buildTrainingRDD()
|
||||
val trainingRDD = buildTrainingRDD(sc)
|
||||
val testRDD = buildEmptyRDD()
|
||||
val tempDir = Files.createTempDirectory("xgboosttest-")
|
||||
val tempFile = Files.createTempFile(tempDir, "", "")
|
||||
@@ -149,7 +159,7 @@ class XGBoostGeneralSuite extends Utils {
|
||||
|
||||
test("test model consistency after save and load") {
|
||||
val eval = new EvalError()
|
||||
val trainingRDD = buildTrainingRDD()
|
||||
val trainingRDD = buildTrainingRDD(sc)
|
||||
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||
import DataUtils._
|
||||
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||
@@ -176,7 +186,7 @@ class XGBoostGeneralSuite extends Utils {
|
||||
val customSparkContext = new SparkContext(sparkConf)
|
||||
customSparkContext.setLogLevel("ERROR")
|
||||
// start another app
|
||||
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
|
||||
val trainingRDD = buildTrainingRDD(customSparkContext)
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
||||
"objective" -> "binary:logistic", "nthread" -> 6).toMap
|
||||
intercept[IllegalArgumentException] {
|
||||
@@ -194,7 +204,7 @@ class XGBoostGeneralSuite extends Utils {
|
||||
sparkConf.registerKryoClasses(Array(classOf[Booster]))
|
||||
val customSparkContext = new SparkContext(sparkConf)
|
||||
customSparkContext.setLogLevel("ERROR")
|
||||
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
|
||||
val trainingRDD = buildTrainingRDD(customSparkContext)
|
||||
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||
import DataUtils._
|
||||
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||
|
||||
Reference in New Issue
Block a user