From bb388cbb31ddc0284ab0469bfdd6913aed5ad23b Mon Sep 17 00:00:00 2001
From: Nan Zhu
Date: Wed, 14 Sep 2016 13:26:16 -0400
Subject: [PATCH] default eval func (#1574)

---
 .../xgboost4j/scala/spark/XGBoostModel.scala  | 37 +++++++----
 .../xgboost4j/scala/spark/EvalError.scala     | 63 +++++++++++++++++++
 .../scala/spark/SharedSparkContext.scala      |  4 +-
 .../ml/dmlc/xgboost4j/scala/spark/Utils.scala | 48 +-------------
 .../scala/spark/XGBoostDFSuite.scala          |  4 +-
 .../scala/spark/XGBoostGeneralSuite.scala     | 26 +++++---
 6 files changed, 114 insertions(+), 68 deletions(-)
 create mode 100644 jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
index b33bfd33e..7070109b8 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
@@ -42,16 +42,20 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
   /**
    * evaluate XGBoostModel with a RDD-wrapped dataset
    *
+   * NOTE: you have to specify either evalFunc or iter; when both are specified, this method
+   * adopts the default eval metric of the model
+   *
    * @param evalDataset the dataset used for evaluation
-   * @param eval the customized evaluation function, can be null for using default in the model
+   * @param evalName the name of the evaluation
+   * @param evalFunc the customized evaluation function, null by default to use the default
+   *                 metric of the model
+   * @param iter the current iteration; leave it as -1 to use the customized evaluation function
    * @param useExternalCache if use external cache
    * @return the average metric over all partitions
    */
-  def eval(
-      evalDataset: RDD[LabeledPoint],
-      eval: EvalTrait,
-      evalName: String,
-      useExternalCache: Boolean = false): String = {
+  def eval(evalDataset: RDD[LabeledPoint], evalName: String, evalFunc: EvalTrait = null,
+      iter: Int = -1, useExternalCache: Boolean = false): String = {
+    require(evalFunc != null || iter != -1, "you must specify either evalFunc or iter")
     val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
     val appName = evalDataset.context.appName
     val allEvalMetrics = evalDataset.mapPartitions {
@@ -62,20 +66,29 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
           import DataUtils._
           val cacheFileName = {
             if (useExternalCache) {
-              s"$appName-deval_cache-${TaskContext.getPartitionId()}"
+              s"$appName-${TaskContext.get().stageId()}-deval_cache-${TaskContext.getPartitionId()}"
             } else {
               null
             }
           }
           val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
-          val predictions = broadcastBooster.value.predict(dMatrix)
-          Rabit.shutdown()
-          Iterator(Some(eval.eval(predictions, dMatrix)))
+          if (iter == -1) {
+            val predictions = broadcastBooster.value.predict(dMatrix)
+            Rabit.shutdown()
+            Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
+          } else {
+            val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
+            val Array(evName, predNumeric) = predStr.split(":")
+            Rabit.shutdown()
+            Iterator(Some((evName, predNumeric.toFloat)))
+          }
         } else {
           Iterator(None)
         }
     }.filter(_.isDefined).collect()
-    s"$evalName-${eval.getMetric} = ${allEvalMetrics.map(_.get).sum / allEvalMetrics.length}"
+    val evalPrefix = allEvalMetrics.map(_.get._1).head
+    val evalMetricMean = allEvalMetrics.map(_.get._2).sum / allEvalMetrics.length
+    s"$evalPrefix = $evalMetricMean"
   }
 
   /**
@@ -176,6 +189,7 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
   /**
    * produces the prediction results and append as an additional column in the original dataset
    * NOTE: the prediction results is kept as the original format of xgboost
+   *
    * @return the original dataframe with an additional column containing prediction results
    */
   override def transform(testSet: Dataset[_]): DataFrame = {
@@ -186,6 +200,7 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
    * produces the prediction results and append as an additional column in the original dataset
    * NOTE: the prediction results is transformed by applying the transformation function
    * predictResultTrans to the original xgboost output
+   *
    * @param predictResultTrans the function to transform xgboost output to the expected format
    * @return the original dataframe with an additional column containing prediction results
    */
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala
new file mode 100644
index 000000000..ec37ec0cd
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala
@@ -0,0 +1,63 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import ml.dmlc.xgboost4j.java.XGBoostError
+import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
+import org.apache.commons.logging.LogFactory
+
+class EvalError extends EvalTrait {
+
+  val logger = LogFactory.getLog(classOf[EvalError])
+
+  private[xgboost4j] var evalMetric: String = "custom_error"
+
+  /**
+   * get the evaluation metric
+   *
+   * @return evalMetric
+   */
+  override def getMetric: String = evalMetric
+
+  /**
+   * evaluate with predictions and data
+   *
+   * @param predicts predictions as an array
+   * @param dmat the data matrix to evaluate
+   * @return the result of the metric
+   */
+  override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
+    var error: Float = 0f
+    var labels: Array[Float] = null
+    try {
+      labels = dmat.getLabel
+    } catch {
+      case ex: XGBoostError =>
+        logger.error(ex)
+        return -1f
+    }
+    val nrow: Int = predicts.length
+    for (i <- 0 until nrow) {
+      if (labels(i) == 0.0 && predicts(i)(0) > 0) {
+        error += 1
+      } else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
+        error += 1
+      }
+    }
+    error / labels.length
+  }
+}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SharedSparkContext.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SharedSparkContext.scala
index a73cb9fac..0729cde0d 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SharedSparkContext.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SharedSparkContext.scala
@@ -19,9 +19,9 @@ package ml.dmlc.xgboost4j.scala.spark
 import org.apache.spark.{SparkConf, SparkContext}
 import org.scalatest.{BeforeAndAfter, FunSuite}
 
-trait SharedSparkContext extends FunSuite with BeforeAndAfter {
+class SharedSparkContext extends FunSuite with BeforeAndAfter with Serializable {
 
-  protected implicit var sc: SparkContext = null
+  @transient protected implicit var sc: SparkContext = null
 
   before {
     // build SparkContext
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala
index 7c8ac1744..83dbb3e1e 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala
@@ -29,51 +29,9 @@ import org.apache.spark.mllib.linalg.{DenseVector, Vector => SparkVector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
 
-trait Utils extends SharedSparkContext {
+trait Utils extends Serializable {
   protected val numWorkers = Runtime.getRuntime().availableProcessors()
 
-  protected class EvalError extends EvalTrait {
-
-    val logger = LogFactory.getLog(classOf[EvalError])
-
-    private[xgboost4j] var evalMetric: String = "custom_error"
-
-    /**
-     * get evaluate metric
-     *
-     * @return evalMetric
-     */
-    override def getMetric: String = evalMetric
-
-    /**
-     * evaluate with predicts and data
-     *
-     * @param predicts predictions as array
-     * @param dmat data matrix to evaluate
-     * @return result of the metric
-     */
-    override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
-      var error: Float = 0f
-      var labels: Array[Float] = null
-      try {
-        labels = dmat.getLabel
-      } catch {
-        case ex: XGBoostError =>
-          logger.error(ex)
-          return -1f
-      }
-      val nrow: Int = predicts.length
-      for (i <- 0 until nrow) {
-        if (labels(i) == 0.0 && predicts(i)(0) > 0) {
-          error += 1
-        } else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
-          error += 1
-        }
-      }
-      error / labels.length
-    }
-  }
-
   protected def loadLabelPoints(filePath: String): List[LabeledPoint] = {
     val file = Source.fromFile(new File(filePath))
     val sampleList = new ListBuffer[LabeledPoint]
@@ -100,8 +58,8 @@ trait Utils extends SharedSparkContext {
     LabeledPoint(label, sv)
   }
 
-  protected def buildTrainingRDD(sparkContext: Option[SparkContext] = None): RDD[LabeledPoint] = {
+  protected def buildTrainingRDD(sparkContext: SparkContext): RDD[LabeledPoint] = {
     val sampleList = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile)
-    sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
+    sparkContext.parallelize(sampleList, numWorkers)
   }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
index 527f5bf15..48b450e60 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.sql._
 import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}
 
-class XGBoostDFSuite extends Utils {
+class XGBoostDFSuite extends SharedSparkContext with Utils {
 
   private def loadRow(filePath: String): List[Row] = {
     val file = Source.fromFile(new File(filePath))
@@ -58,7 +58,7 @@ class XGBoostDFSuite extends Utils {
 
   test("test consistency between training with dataframe and RDD") {
     val trainingDF = buildTrainingDataframe()
-    val trainingRDD = buildTrainingRDD()
+    val trainingRDD = buildTrainingRDD(sc)
     val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
       "objective" -> "binary:logistic").toMap
     val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index a6877b096..f02496096 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -29,10 +29,10 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
 import org.apache.spark.{SparkConf, SparkContext}
 
-class XGBoostGeneralSuite extends Utils {
+class XGBoostGeneralSuite extends SharedSparkContext with Utils {
 
   test("build RDD containing boosters with the specified worker number") {
-    val trainingRDD = buildTrainingRDD()
+    val trainingRDD = buildTrainingRDD(sc)
     val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
     import DataUtils._
     val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
@@ -60,7 +60,7 @@
     val customSparkContext = new SparkContext(sparkConf)
     customSparkContext.setLogLevel("ERROR")
     val eval = new EvalError()
-    val trainingRDD = buildTrainingRDD(Some(customSparkContext))
+    val trainingRDD = buildTrainingRDD(customSparkContext)
     val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
     import DataUtils._
     val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
@@ -113,7 +113,7 @@ class XGBoostGeneralSuite extends Utils {
   }
 
   test("test consistency of prediction functions with RDD") {
-    val trainingRDD = buildTrainingRDD()
+    val trainingRDD = buildTrainingRDD(sc)
     val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile)
     val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)
     val testCollection = testRDD.collect()
@@ -132,12 +132,22 @@ class XGBoostGeneralSuite extends Utils {
     }
   }
 
+  test("test eval functions with RDD") {
+    val trainingRDD = buildTrainingRDD(sc)
+    val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
+      "objective" -> "binary:logistic").toMap
+    val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
+    val evalFunc = new EvalError
+    xgBoostModel.eval(trainingRDD, "eval1", iter = 5, useExternalCache = false)
+    xgBoostModel.eval(trainingRDD, "eval2", evalFunc = evalFunc, useExternalCache = false)
+  }
+
   test("test prediction functionality with empty partition") {
     def buildEmptyRDD(sparkContext: Option[SparkContext] = None): RDD[SparkVector] = {
       val sampleList = new ListBuffer[SparkVector]
       sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
     }
-    val trainingRDD = buildTrainingRDD()
+    val trainingRDD = buildTrainingRDD(sc)
     val testRDD = buildEmptyRDD()
     val tempDir = Files.createTempDirectory("xgboosttest-")
     val tempFile = Files.createTempFile(tempDir, "", "")
@@ -149,7 +159,7 @@ class XGBoostGeneralSuite extends Utils {
 
   test("test model consistency after save and load") {
     val eval = new EvalError()
-    val trainingRDD = buildTrainingRDD()
+    val trainingRDD = buildTrainingRDD(sc)
     val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
     import DataUtils._
     val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
@@ -176,7 +186,7 @@ class XGBoostGeneralSuite extends Utils {
     val customSparkContext = new SparkContext(sparkConf)
     customSparkContext.setLogLevel("ERROR")
     // start another app
-    val trainingRDD = buildTrainingRDD(Some(customSparkContext))
+    val trainingRDD = buildTrainingRDD(customSparkContext)
     val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
       "objective" -> "binary:logistic", "nthread" -> 6).toMap
     intercept[IllegalArgumentException] {
@@ -194,7 +204,7 @@ class XGBoostGeneralSuite extends Utils {
     sparkConf.registerKryoClasses(Array(classOf[Booster]))
     val customSparkContext = new SparkContext(sparkConf)
     customSparkContext.setLogLevel("ERROR")
-    val trainingRDD = buildTrainingRDD(Some(customSparkContext))
+    val trainingRDD = buildTrainingRDD(customSparkContext)
     val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
     import DataUtils._
     val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
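
A minimal usage sketch of the two evaluation modes this patch introduces, mirroring the added
test "test eval functions with RDD" (the SparkContext sc, buildTrainingRDD, and numWorkers come
from the test helpers; the metric names and iteration value are illustrative assumptions, not
part of the patch):

    val trainingRDD = buildTrainingRDD(sc)
    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "objective" -> "binary:logistic")
    val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)

    // Default-metric mode: passing iter != -1 selects the Booster.evalSet branch, so the
    // model's own eval metric (error, for binary:logistic) is averaged over all partitions.
    val defaultEval = xgBoostModel.eval(trainingRDD, "eval1", iter = 5)

    // Custom-metric mode: leaving iter at -1 and passing a non-null EvalTrait selects the
    // branch that calls Booster.predict and feeds the predictions to evalFunc.eval.
    val customEval = xgBoostModel.eval(trainingRDD, "eval2", evalFunc = new EvalError)

Both calls return a string of the form "<name> = <mean metric over partitions>".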