[jvm-packages] Exposed train-time evaluation metrics (#2836)

* [jvm-packages] Exposed train-time evaluation metrics

The metrics are accessible via 'XGBoostModel.summary'. The summary is not
serialized with the model and is only available after training (see the
usage sketch below this list).

* Addressed review comments

* Extracted model-related tests into 'XGBoostModelSuite'

* Added tests for copying the 'XGBoostModel'

* [jvm-packages] Fixed a subtle bug in train/test split

Scala's Iterator.partition (naturally) assumes that the predicate is
deterministic, but this is not the case for

    r.nextDouble() <= trainTestRatio

which may be re-evaluated on the same element as the two sides are
consumed. As a result, the DMatrix(...) call sometimes got a
NoSuchElementException and crashed the JVM, because
XGBoost4jCallbackDataIterNext did not handle the pending exception.
A minimal demonstration of the pitfall follows the list below.

* Make sure train/test objectives are different
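
For illustration, a hedged usage sketch (not part of the commit): assuming a
training DataFrame 'trainingDf' and a worker count 'numWorkers' set up as in
the XGBoostDFSuite tests below, the new summary can be inspected right after
training returns.

    // Sketch only: parameter values mirror the tests added in this commit.
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "binary:logistic", "trainTestRatio" -> "0.8")
    val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
      nWorkers = numWorkers)
    // One loss value per boosting round for the training watch...
    println(model.summary.trainObjectiveHistory.mkString(", "))
    // ...and, because trainTestRatio < 1.0, for the held-out watch too.
    model.summary.testObjectiveHistory.foreach(h => println(h.mkString(", ")))
    // The summary lives only in memory: a model loaded back from disk throws
    // IllegalStateException when 'summary' is accessed.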
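
And a minimal, self-contained demonstration of the Iterator.partition pitfall
(assumed names; plain Scala, not xgboost code):

    import scala.util.Random

    object NondeterministicPartitionDemo extends App {
      val r = new Random(42L)
      // The predicate is a fresh coin flip on every invocation, so the two
      // sides of partition() disagree about which elements they own.
      val (train, test) =
        Iterator.range(0, 1000).partition(_ => r.nextDouble() <= 0.5)
      // Depending on the Scala version, the predicate is re-evaluated for the
      // same element (once per side, or in both hasNext and next), so elements
      // are dropped or duplicated, and draining a side can throw
      // NoSuchElementException -- the crash this commit guards against.
      println(s"expected 1000 elements, saw ${train.size + test.size}")
    }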
Sergei Lebedev, 2017-11-20 22:21:54 +01:00, committed by GitHub
parent 88177691b8
commit 8e141427aa
9 changed files with 247 additions and 104 deletions


@@ -104,7 +104,7 @@ object XGBoost extends Serializable {
       obj: ObjectiveTrait,
       eval: EvalTrait,
       useExternalMemory: Boolean,
-      missing: Float): RDD[Booster] = {
+      missing: Float): RDD[(Booster, Map[String, Array[Float]])] = {
     val partitionedData = if (data.getNumPartitions != numWorkers) {
       logger.info(s"repartitioning training set to $numWorkers partitions")
       data.repartition(numWorkers)
@@ -136,11 +136,12 @@ object XGBoost extends Serializable {
       try {
         val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds")
-          .map(_.toString.toInt).getOrElse(0)
+            .map(_.toString.toInt).getOrElse(0)
+        val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
         val booster = SXGBoost.train(watches.train, params, round,
-          watches = watches.toMap, obj = obj, eval = eval,
+          watches.toMap, metrics, obj, eval,
           earlyStoppingRound = numEarlyStoppingRounds)
-        Iterator(booster)
+        Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
       } finally {
         Rabit.shutdown()
         watches.delete()
@@ -330,12 +331,12 @@ object XGBoost extends Serializable {
     val sc = trainingData.sparkContext
     val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
     val overriddenParams = overrideParamsAccordingToTaskCPUs(params, trainingData.sparkContext)
-    val boosters = buildDistributedBoosters(trainingData, overriddenParams,
+    val boostersAndMetrics = buildDistributedBoosters(trainingData, overriddenParams,
       tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
     val sparkJobThread = new Thread() {
       override def run() {
         // force the job
-        boosters.foreachPartition(() => _)
+        boostersAndMetrics.foreachPartition(() => _)
       }
     }
     sparkJobThread.setUncaughtExceptionHandler(tracker)
@@ -343,7 +344,8 @@ object XGBoost extends Serializable {
     val isClsTask = isClassificationTask(params)
     val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
     logger.info(s"Rabit returns with exit code $trackerReturnVal")
-    val model = postTrackerReturnProcessing(trackerReturnVal, boosters, sparkJobThread, isClsTask)
+    val model = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
+      sparkJobThread, isClsTask)
     if (isClsTask){
       model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
         params.getOrElse("num_class", "2").toString.toInt
@@ -356,15 +358,18 @@ object XGBoost extends Serializable {
   private def postTrackerReturnProcessing(
       trackerReturnVal: Int,
-      distributedBoosters: RDD[Booster],
+      distributedBoostersAndMetrics: RDD[(Booster, Map[String, Array[Float]])],
       sparkJobThread: Thread,
-      isClassificationTask: Boolean): XGBoostModel = {
+      isClassificationTask: Boolean
+  ): XGBoostModel = {
     if (trackerReturnVal == 0) {
-      // Copies of the finished model reside in each partition of the `distributedBoosters`.
-      // Any of them can be used to create the model. Here, just choose the first partition.
-      val xgboostModel = XGBoostModel(distributedBoosters.first(), isClassificationTask)
-      distributedBoosters.unpersist(false)
-      xgboostModel
+      // Copies of the final booster and the corresponding metrics
+      // reside in each partition of the `distributedBoostersAndMetrics`.
+      // Any of them can be used to create the model.
+      val (booster, metrics) = distributedBoostersAndMetrics.first()
+      val xgboostModel = XGBoostModel(booster, isClassificationTask)
+      distributedBoostersAndMetrics.unpersist(false)
+      xgboostModel.setSummary(XGBoostTrainingSummary(metrics))
     } else {
       try {
         if (sparkJobThread.isAlive) {
@@ -461,11 +466,17 @@ private object Watches {
     val trainTestRatio = params.get("trainTestRatio").map(_.toString.toDouble).getOrElse(1.0)
     val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
     val r = new Random(seed)
-    val (trainPoints, testPoints) = labeledPoints.partition(_ => r.nextDouble() <= trainTestRatio)
+    // In the worst-case this would store [[trainTestRatio]] of points
+    // buffered in memory.
+    val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
+    val trainPoints = labeledPoints.filter { labeledPoint =>
+      val accepted = r.nextDouble() <= trainTestRatio
+      if (!accepted) {
+        testPoints += labeledPoint
+      }
+      accepted
+    }
     val trainMatrix = new DMatrix(trainPoints, cacheFileName)
-    val testMatrix = new DMatrix(testPoints, cacheFileName)
+    val testMatrix = new DMatrix(testPoints.iterator, cacheFileName)
     r.setSeed(seed)
     for (baseMargins <- baseMarginsOpt) {
       val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)


@@ -171,9 +171,8 @@ class XGBoostClassificationModel private[spark](
   def numClasses: Int = numOfClasses

   override def copy(extra: ParamMap): XGBoostClassificationModel = {
-    val clsModel = defaultCopy(extra).asInstanceOf[XGBoostClassificationModel]
-    clsModel._booster = booster
-    clsModel
+    val newModel = copyValues(new XGBoostClassificationModel(booster), extra)
+    newModel.setSummary(summary)
   }

   override protected def predict(features: MLVector): Double = {


@@ -42,6 +42,21 @@ abstract class XGBoostModel(protected var _booster: Booster)
   extends PredictionModel[MLVector, XGBoostModel] with BoosterParams with Serializable
   with Params with MLWritable {

+  private var trainingSummary: Option[XGBoostTrainingSummary] = None
+
+  /**
+   * Returns summary (e.g. train/test objective history) of model on the
+   * training set. An exception is thrown if no summary is available.
+   */
+  def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
+    throw new IllegalStateException("No training summary available for this XGBoostModel")
+  }
+
+  private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = {
+    trainingSummary = Some(summary)
+    this
+  }
+
   def setLabelCol(name: String): XGBoostModel = set(labelCol, name)

   // scalastyle:off


@@ -55,8 +55,7 @@ class XGBoostRegressionModel private[spark](override val uid: String, booster: B
   }

   override def copy(extra: ParamMap): XGBoostRegressionModel = {
-    val regModel = defaultCopy(extra).asInstanceOf[XGBoostRegressionModel]
-    regModel._booster = booster
-    regModel
+    val newModel = copyValues(new XGBoostRegressionModel(booster), extra)
+    newModel.setSummary(summary)
   }
 }


@@ -0,0 +1,36 @@ (new file; every line below is an addition)
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

class XGBoostTrainingSummary private(
    val trainObjectiveHistory: Array[Float],
    val testObjectiveHistory: Option[Array[Float]]
) extends Serializable {
  override def toString: String = {
    val train = trainObjectiveHistory.toList
    val test = testObjectiveHistory.map(_.toList)
    s"XGBoostTrainingSummary(trainObjectiveHistory=$train, testObjectiveHistory=$test)"
  }
}

private[xgboost4j] object XGBoostTrainingSummary {
  def apply(metrics: Map[String, Array[Float]]): XGBoostTrainingSummary = {
    new XGBoostTrainingSummary(
      trainObjectiveHistory = metrics("train"),
      testObjectiveHistory = metrics.get("test"))
  }
}
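
A small sketch of how this factory consumes the per-watch metrics map built in
XGBoost.scala (made-up values; callable like this only from within the
ml.dmlc.xgboost4j package, since the companion object is private[xgboost4j]):

    val metrics = Map(
      "train" -> Array(0.10f, 0.08f, 0.07f),
      "test" -> Array(0.15f, 0.13f, 0.12f))
    val summary = XGBoostTrainingSummary(metrics)
    println(summary)
    // XGBoostTrainingSummary(trainObjectiveHistory=List(0.1, 0.08, 0.07),
    //                        testObjectiveHistory=Some(List(0.15, 0.13, 0.12)))

Note that a "train" entry is mandatory (metrics("train") would otherwise
throw), while "test" is optional and maps to None when absent.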


@@ -236,4 +236,28 @@ class XGBoostDFSuite extends FunSuite with PerTest {
     // The predictions heavily relies on the first training instance, and thus are very close.
     predictions.foreach(pred => assert(math.abs(pred - predictions.head) <= 0.01f))
   }
+
+  test("training summary") {
+    val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "binary:logistic").toMap
+
+    val trainingDf = buildDataFrame(Classification.train)
+    val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
+      nWorkers = numWorkers)
+
+    assert(model.summary.trainObjectiveHistory.length === 5)
+    assert(model.summary.testObjectiveHistory.isEmpty)
+  }
+
+  test("train/test split") {
+    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "binary:logistic", "trainTestRatio" -> "0.5")
+    val trainingDf = buildDataFrame(Classification.train)
+    val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
+      nWorkers = numWorkers)
+
+    val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
+    assert(testObjectiveHistory.length === 5)
+    assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
+  }
 }


@@ -16,7 +16,6 @@
 package ml.dmlc.xgboost4j.scala.spark

-import java.nio.file.Files
 import java.util.concurrent.LinkedBlockingDeque

 import scala.util.Random
@@ -24,7 +23,6 @@ import scala.util.Random
 import ml.dmlc.xgboost4j.java.Rabit
 import ml.dmlc.xgboost4j.scala.DMatrix
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
-import org.apache.spark.SparkContext
 import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.ml.linalg.{DenseVector, Vectors, Vector => SparkVector}
@@ -262,84 +260,6 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     println(xgBoostModel.predict(testRDD).collect().length === 0)
   }

-  test("test model consistency after save and load") {
-    import DataUtils._
-    val eval = new EvalError()
-    val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
-    val testSetDMatrix = new DMatrix(Classification.test.iterator)
-    val tempDir = Files.createTempDirectory("xgboosttest-")
-    val tempFile = Files.createTempFile(tempDir, "", "")
-    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
-      "objective" -> "binary:logistic")
-    val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
-    val evalResults = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
-      testSetDMatrix)
-    assert(evalResults < 0.1)
-    xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
-    val loadedXGBooostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
-    val predicts = loadedXGBooostModel.booster.predict(testSetDMatrix, outPutMargin = true)
-    val loadedEvalResults = eval.eval(predicts, testSetDMatrix)
-    assert(loadedEvalResults == evalResults)
-  }
-
-  test("test save and load of different types of models") {
-    import DataUtils._
-    val tempDir = Files.createTempDirectory("xgboosttest-")
-    val tempFile = Files.createTempFile(tempDir, "", "")
-    var trainingRDD = sc.parallelize(Classification.train).map(_.asML)
-    var paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "reg:linear")
-    // validate regression model
-    var xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
-      nWorkers = numWorkers, useExternalMemory = false)
-    xgBoostModel.setFeaturesCol("feature_col")
-    xgBoostModel.setLabelCol("label_col")
-    xgBoostModel.setPredictionCol("prediction_col")
-    xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
-    var loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
-    assert(loadedXGBoostModel.isInstanceOf[XGBoostRegressionModel])
-    assert(loadedXGBoostModel.getFeaturesCol == "feature_col")
-    assert(loadedXGBoostModel.getLabelCol == "label_col")
-    assert(loadedXGBoostModel.getPredictionCol == "prediction_col")
-    // classification model
-    paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "binary:logistic")
-    xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
-      nWorkers = numWorkers, useExternalMemory = false)
-    xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
-    xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(Array(0.5, 0.5))
-    xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
-    loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
-    assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
-    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
-      "raw_col")
-    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
-      Array(0.5, 0.5).deep)
-    assert(loadedXGBoostModel.getFeaturesCol == "features")
-    assert(loadedXGBoostModel.getLabelCol == "label")
-    assert(loadedXGBoostModel.getPredictionCol == "prediction")
-    // (multiclass) classification model
-    trainingRDD = sc.parallelize(MultiClassification.train).map(_.asML)
-    paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "multi:softmax", "num_class" -> "6")
-    xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
-      nWorkers = numWorkers, useExternalMemory = false)
-    xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
-    xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(
-      Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5))
-    xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
-    loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
-    assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
-    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
-      "raw_col")
-    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
-      Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5).deep)
-    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6)
-    assert(loadedXGBoostModel.getFeaturesCol == "features")
-    assert(loadedXGBoostModel.getLabelCol == "label")
-    assert(loadedXGBoostModel.getPredictionCol == "prediction")
-  }
-
   test("test use groupData") {
     import DataUtils._
     val trainingRDD = sc.parallelize(Ranking.train0, numSlices = 1).map(_.asML)


@@ -0,0 +1,133 @@ (new file; every line below is an addition)
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import java.nio.file.Files

import ml.dmlc.xgboost4j.scala.DMatrix
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.rdd.RDD
import org.scalatest.FunSuite

class XGBoostModelSuite extends FunSuite with PerTest {
  test("test model consistency after save and load") {
    import DataUtils._
    val eval = new EvalError()
    val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
    val testSetDMatrix = new DMatrix(Classification.test.iterator)
    val tempDir = Files.createTempDirectory("xgboosttest-")
    val tempFile = Files.createTempFile(tempDir, "", "")
    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
      "objective" -> "binary:logistic")
    val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
    val evalResults = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
      testSetDMatrix)
    assert(evalResults < 0.1)
    xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
    val loadedXGBooostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
    val predicts = loadedXGBooostModel.booster.predict(testSetDMatrix, outPutMargin = true)
    val loadedEvalResults = eval.eval(predicts, testSetDMatrix)
    assert(loadedEvalResults == evalResults)
  }

  test("test save and load of different types of models") {
    import DataUtils._
    val tempDir = Files.createTempDirectory("xgboosttest-")
    val tempFile = Files.createTempFile(tempDir, "", "")
    var trainingRDD = sc.parallelize(Classification.train).map(_.asML)
    var paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "reg:linear")
    // validate regression model
    var xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
      nWorkers = numWorkers, useExternalMemory = false)
    xgBoostModel.setFeaturesCol("feature_col")
    xgBoostModel.setLabelCol("label_col")
    xgBoostModel.setPredictionCol("prediction_col")
    xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
    var loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
    assert(loadedXGBoostModel.isInstanceOf[XGBoostRegressionModel])
    assert(loadedXGBoostModel.getFeaturesCol == "feature_col")
    assert(loadedXGBoostModel.getLabelCol == "label_col")
    assert(loadedXGBoostModel.getPredictionCol == "prediction_col")
    // classification model
    paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "binary:logistic")
    xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
      nWorkers = numWorkers, useExternalMemory = false)
    xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
    xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(Array(0.5, 0.5))
    xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
    loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
    assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
      "raw_col")
    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
      Array(0.5, 0.5).deep)
    assert(loadedXGBoostModel.getFeaturesCol == "features")
    assert(loadedXGBoostModel.getLabelCol == "label")
    assert(loadedXGBoostModel.getPredictionCol == "prediction")
    // (multiclass) classification model
    trainingRDD = sc.parallelize(MultiClassification.train).map(_.asML)
    paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "multi:softmax", "num_class" -> "6")
    xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
      nWorkers = numWorkers, useExternalMemory = false)
    xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
    xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(
      Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5))
    xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
    loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
    assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
      "raw_col")
    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
      Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5).deep)
    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6)
    assert(loadedXGBoostModel.getFeaturesCol == "features")
    assert(loadedXGBoostModel.getLabelCol == "label")
    assert(loadedXGBoostModel.getPredictionCol == "prediction")
  }

  test("copy and predict ClassificationModel") {
    import DataUtils._
    val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
    val testRDD = sc.parallelize(Classification.test).map(_.features)
    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
      "objective" -> "binary:logistic")
    val model = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
    testCopy(model, testRDD)
  }

  test("copy and predict RegressionModel") {
    import DataUtils._
    val trainingRDD = sc.parallelize(Regression.train).map(_.asML)
    val testRDD = sc.parallelize(Regression.test).map(_.features)
    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
      "objective" -> "reg:linear")
    val model = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
    testCopy(model, testRDD)
  }

  private def testCopy(model: XGBoostModel, testRDD: RDD[Vector]): Unit = {
    val modelCopy = model.copy(ParamMap.empty)
    modelCopy.summary  // Ensure no exception.
    val expected = model.predict(testRDD).collect
    assert(modelCopy.predict(testRDD).collect === expected)
  }
}


@@ -63,6 +63,12 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
   if (jenv->CallBooleanMethod(jiter, hasNext)) {
     ret_value = 1;
     jobject batch = jenv->CallObjectMethod(jiter, next);
+    if (batch == nullptr) {
+      CHECK(jenv->ExceptionOccurred());
+      jenv->ExceptionDescribe();
+      return -1;
+    }
     jclass batchClass = jenv->GetObjectClass(batch);
     jlongArray joffset = (jlongArray)jenv->GetObjectField(
         batch, jenv->GetFieldID(batchClass, "rowOffset", "[J"));