[jvm-packages] Exposed train-time evaluation metrics (#2836)

* [jvm-packages] Exposed train-time evaluation metrics

  They are accessible via 'XGBoostModel.summary'. The summary is not
  serialized with the model and is only available after training
  (see the usage sketch after this list).

* Addressed review comments

* Extracted model-related tests into 'XGBoostModelSuite'

* Added tests for copying the 'XGBoostModel'

* [jvm-packages] Fixed a subtle bug in the train/test split

  Iterator.partition (naturally) assumes that the predicate is
  deterministic, but this is not the case for

      r.nextDouble() <= trainTestRatio

  so the DMatrix(...) call sometimes got a NoSuchElementException and
  crashed the JVM, due to the lack of exception handling in
  XGBoost4jCallbackDataIterNext (a demonstration follows this list).

* Make sure train/test objectives are different
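A quick usage sketch of the new API, assembled from the tests added in this
commit; `trainingDf`, `paramMap`, and `numWorkers` are stand-ins prepared as
in those tests:

    val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
      nWorkers = numWorkers)

    // One evaluation value per training round.
    println(model.summary.trainObjectiveHistory.mkString(", "))

    // Present only when "trainTestRatio" < 1.0 held out a test set.
    model.summary.testObjectiveHistory.foreach(h => println(h.mkString(", ")))

Because the summary is not serialized, calling `summary` on a model loaded
back from disk throws an IllegalStateException.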
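To see the split bug in isolation, here is a self-contained demonstration of
the Iterator.partition pitfall (plain Scala with made-up sizes; not code from
this commit):

    import scala.util.Random

    object PartitionPitfall extends App {
      val r = new Random()
      // The predicate draws a fresh random number on every call, just like
      // the old `labeledPoints.partition(_ => r.nextDouble() <= trainTestRatio)`.
      val (train, test) =
        Iterator.range(0, 1000).partition(_ => r.nextDouble() <= 0.5)

      // Because the predicate may be re-evaluated as elements are pulled, an
      // element can land in both halves or in neither, and with the
      // Scala 2.11/2.12 implementation hasNext may answer true while next()
      // exhausts the underlying iterator and throws NoSuchElementException.
      println(train.length + test.length) // rarely 1000, and may throw
    }

The fix below buffers rejected points in an ArrayBuffer instead. The base
margins are still split with partition, which stays safe there: an eager
collection evaluates the predicate exactly once per element, and
r.setSeed(seed) replays the identical sequence of draws, so the margin split
matches the point split.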
Parent: 88177691b8
Commit: 8e141427aa
@@ -104,7 +104,7 @@ object XGBoost extends Serializable {
       obj: ObjectiveTrait,
       eval: EvalTrait,
       useExternalMemory: Boolean,
-      missing: Float): RDD[Booster] = {
+      missing: Float): RDD[(Booster, Map[String, Array[Float]])] = {
     val partitionedData = if (data.getNumPartitions != numWorkers) {
       logger.info(s"repartitioning training set to $numWorkers partitions")
       data.repartition(numWorkers)
@@ -136,11 +136,12 @@ object XGBoost extends Serializable {
 
       try {
         val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds")
           .map(_.toString.toInt).getOrElse(0)
+        val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
         val booster = SXGBoost.train(watches.train, params, round,
-          watches = watches.toMap, obj = obj, eval = eval,
+          watches.toMap, metrics, obj, eval,
           earlyStoppingRound = numEarlyStoppingRounds)
-        Iterator(booster)
+        Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
       } finally {
         Rabit.shutdown()
         watches.delete()
@@ -330,12 +331,12 @@ object XGBoost extends Serializable {
     val sc = trainingData.sparkContext
     val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
     val overriddenParams = overrideParamsAccordingToTaskCPUs(params, trainingData.sparkContext)
-    val boosters = buildDistributedBoosters(trainingData, overriddenParams,
+    val boostersAndMetrics = buildDistributedBoosters(trainingData, overriddenParams,
       tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
     val sparkJobThread = new Thread() {
       override def run() {
         // force the job
-        boosters.foreachPartition(() => _)
+        boostersAndMetrics.foreachPartition(() => _)
       }
     }
     sparkJobThread.setUncaughtExceptionHandler(tracker)
@@ -343,7 +344,8 @@ object XGBoost extends Serializable {
     val isClsTask = isClassificationTask(params)
     val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
     logger.info(s"Rabit returns with exit code $trackerReturnVal")
-    val model = postTrackerReturnProcessing(trackerReturnVal, boosters, sparkJobThread, isClsTask)
+    val model = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
+      sparkJobThread, isClsTask)
     if (isClsTask){
       model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
         params.getOrElse("num_class", "2").toString.toInt
@@ -356,15 +358,18 @@ object XGBoost extends Serializable {
 
   private def postTrackerReturnProcessing(
       trackerReturnVal: Int,
-      distributedBoosters: RDD[Booster],
+      distributedBoostersAndMetrics: RDD[(Booster, Map[String, Array[Float]])],
       sparkJobThread: Thread,
-      isClassificationTask: Boolean): XGBoostModel = {
+      isClassificationTask: Boolean
+  ): XGBoostModel = {
     if (trackerReturnVal == 0) {
-      // Copies of the finished model reside in each partition of the `distributedBoosters`.
-      // Any of them can be used to create the model. Here, just choose the first partition.
-      val xgboostModel = XGBoostModel(distributedBoosters.first(), isClassificationTask)
-      distributedBoosters.unpersist(false)
-      xgboostModel
+      // Copies of the final booster and the corresponding metrics
+      // reside in each partition of the `distributedBoostersAndMetrics`.
+      // Any of them can be used to create the model.
+      val (booster, metrics) = distributedBoostersAndMetrics.first()
+      val xgboostModel = XGBoostModel(booster, isClassificationTask)
+      distributedBoostersAndMetrics.unpersist(false)
+      xgboostModel.setSummary(XGBoostTrainingSummary(metrics))
     } else {
       try {
         if (sparkJobThread.isAlive) {
@@ -461,11 +466,17 @@ private object Watches {
     val trainTestRatio = params.get("trainTestRatio").map(_.toString.toDouble).getOrElse(1.0)
     val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
     val r = new Random(seed)
-    // In the worst-case this would store [[trainTestRatio]] of points
-    // buffered in memory.
-    val (trainPoints, testPoints) = labeledPoints.partition(_ => r.nextDouble() <= trainTestRatio)
+    val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
+    val trainPoints = labeledPoints.filter { labeledPoint =>
+      val accepted = r.nextDouble() <= trainTestRatio
+      if (!accepted) {
+        testPoints += labeledPoint
+      }
+
+      accepted
+    }
     val trainMatrix = new DMatrix(trainPoints, cacheFileName)
-    val testMatrix = new DMatrix(testPoints, cacheFileName)
+    val testMatrix = new DMatrix(testPoints.iterator, cacheFileName)
     r.setSeed(seed)
     for (baseMargins <- baseMarginsOpt) {
       val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)
@@ -171,9 +171,8 @@ class XGBoostClassificationModel private[spark](
   def numClasses: Int = numOfClasses
 
   override def copy(extra: ParamMap): XGBoostClassificationModel = {
-    val clsModel = defaultCopy(extra).asInstanceOf[XGBoostClassificationModel]
-    clsModel._booster = booster
-    clsModel
+    val newModel = copyValues(new XGBoostClassificationModel(booster), extra)
+    newModel.setSummary(summary)
   }
 
   override protected def predict(features: MLVector): Double = {
@@ -42,6 +42,21 @@ abstract class XGBoostModel(protected var _booster: Booster)
   extends PredictionModel[MLVector, XGBoostModel] with BoosterParams with Serializable
   with Params with MLWritable {
 
+  private var trainingSummary: Option[XGBoostTrainingSummary] = None
+
+  /**
+   * Returns summary (e.g. train/test objective history) of model on the
+   * training set. An exception is thrown if no summary is available.
+   */
+  def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
+    throw new IllegalStateException("No training summary available for this XGBoostModel")
+  }
+
+  private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = {
+    trainingSummary = Some(summary)
+    this
+  }
+
   def setLabelCol(name: String): XGBoostModel = set(labelCol, name)
 
   // scalastyle:off
@@ -55,8 +55,7 @@ class XGBoostRegressionModel private[spark](override val uid: String, booster: B
   }
 
   override def copy(extra: ParamMap): XGBoostRegressionModel = {
-    val regModel = defaultCopy(extra).asInstanceOf[XGBoostRegressionModel]
-    regModel._booster = booster
-    regModel
+    val newModel = copyValues(new XGBoostRegressionModel(booster), extra)
+    newModel.setSummary(summary)
   }
 }
@@ -0,0 +1,36 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+class XGBoostTrainingSummary private(
+    val trainObjectiveHistory: Array[Float],
+    val testObjectiveHistory: Option[Array[Float]]
+) extends Serializable {
+  override def toString: String = {
+    val train = trainObjectiveHistory.toList
+    val test = testObjectiveHistory.map(_.toList)
+    s"XGBoostTrainingSummary(trainObjectiveHistory=$train, testObjectiveHistory=$test)"
+  }
+}
+
+private[xgboost4j] object XGBoostTrainingSummary {
+  def apply(metrics: Map[String, Array[Float]]): XGBoostTrainingSummary = {
+    new XGBoostTrainingSummary(
+      trainObjectiveHistory = metrics("train"),
+      testObjectiveHistory = metrics.get("test"))
+  }
+}
@@ -236,4 +236,28 @@ class XGBoostDFSuite extends FunSuite with PerTest {
     // The predictions heavily relies on the first training instance, and thus are very close.
     predictions.foreach(pred => assert(math.abs(pred - predictions.head) <= 0.01f))
   }
+
+  test("training summary") {
+    val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "binary:logistic").toMap
+
+    val trainingDf = buildDataFrame(Classification.train)
+    val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
+      nWorkers = numWorkers)
+
+    assert(model.summary.trainObjectiveHistory.length === 5)
+    assert(model.summary.testObjectiveHistory.isEmpty)
+  }
+
+  test("train/test split") {
+    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "binary:logistic", "trainTestRatio" -> "0.5")
+
+    val trainingDf = buildDataFrame(Classification.train)
+    val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
+      nWorkers = numWorkers)
+    val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
+    assert(testObjectiveHistory.length === 5)
+    assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
+  }
 }
@@ -16,7 +16,6 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import java.nio.file.Files
 import java.util.concurrent.LinkedBlockingDeque
 
 import scala.util.Random
@@ -24,7 +23,6 @@ import scala.util.Random
 import ml.dmlc.xgboost4j.java.Rabit
 import ml.dmlc.xgboost4j.scala.DMatrix
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
-
 import org.apache.spark.SparkContext
 import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.ml.linalg.{DenseVector, Vectors, Vector => SparkVector}
@@ -262,84 +260,6 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     println(xgBoostModel.predict(testRDD).collect().length === 0)
   }
 
-  test("test model consistency after save and load") {
-    import DataUtils._
-    val eval = new EvalError()
-    val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
-    val testSetDMatrix = new DMatrix(Classification.test.iterator)
-    val tempDir = Files.createTempDirectory("xgboosttest-")
-    val tempFile = Files.createTempFile(tempDir, "", "")
-    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
-      "objective" -> "binary:logistic")
-    val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
-    val evalResults = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
-      testSetDMatrix)
-    assert(evalResults < 0.1)
-    xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
-    val loadedXGBooostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
-    val predicts = loadedXGBooostModel.booster.predict(testSetDMatrix, outPutMargin = true)
-    val loadedEvalResults = eval.eval(predicts, testSetDMatrix)
-    assert(loadedEvalResults == evalResults)
-  }
-
-  test("test save and load of different types of models") {
-    import DataUtils._
-    val tempDir = Files.createTempDirectory("xgboosttest-")
-    val tempFile = Files.createTempFile(tempDir, "", "")
-    var trainingRDD = sc.parallelize(Classification.train).map(_.asML)
-    var paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "reg:linear")
-    // validate regression model
-    var xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
-      nWorkers = numWorkers, useExternalMemory = false)
-    xgBoostModel.setFeaturesCol("feature_col")
-    xgBoostModel.setLabelCol("label_col")
-    xgBoostModel.setPredictionCol("prediction_col")
-    xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
-    var loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
-    assert(loadedXGBoostModel.isInstanceOf[XGBoostRegressionModel])
-    assert(loadedXGBoostModel.getFeaturesCol == "feature_col")
-    assert(loadedXGBoostModel.getLabelCol == "label_col")
-    assert(loadedXGBoostModel.getPredictionCol == "prediction_col")
-    // classification model
-    paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "binary:logistic")
-    xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
-      nWorkers = numWorkers, useExternalMemory = false)
-    xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
-    xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(Array(0.5, 0.5))
-    xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
-    loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
-    assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
-    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
-      "raw_col")
-    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
-      Array(0.5, 0.5).deep)
-    assert(loadedXGBoostModel.getFeaturesCol == "features")
-    assert(loadedXGBoostModel.getLabelCol == "label")
-    assert(loadedXGBoostModel.getPredictionCol == "prediction")
-    // (multiclass) classification model
-    trainingRDD = sc.parallelize(MultiClassification.train).map(_.asML)
-    paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "multi:softmax", "num_class" -> "6")
-    xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
-      nWorkers = numWorkers, useExternalMemory = false)
-    xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
-    xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(
-      Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5))
-    xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
-    loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
-    assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
-    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
-      "raw_col")
-    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
-      Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5).deep)
-    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6)
-    assert(loadedXGBoostModel.getFeaturesCol == "features")
-    assert(loadedXGBoostModel.getLabelCol == "label")
-    assert(loadedXGBoostModel.getPredictionCol == "prediction")
-  }
-
   test("test use groupData") {
     import DataUtils._
     val trainingRDD = sc.parallelize(Ranking.train0, numSlices = 1).map(_.asML)
@@ -0,0 +1,133 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import java.nio.file.Files
+
+import ml.dmlc.xgboost4j.scala.DMatrix
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.rdd.RDD
+import org.scalatest.FunSuite
+
+class XGBoostModelSuite extends FunSuite with PerTest {
+  test("test model consistency after save and load") {
+    import DataUtils._
+    val eval = new EvalError()
+    val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
+    val testSetDMatrix = new DMatrix(Classification.test.iterator)
+    val tempDir = Files.createTempDirectory("xgboosttest-")
+    val tempFile = Files.createTempFile(tempDir, "", "")
+    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
+      "objective" -> "binary:logistic")
+    val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
+    val evalResults = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
+      testSetDMatrix)
+    assert(evalResults < 0.1)
+    xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
+    val loadedXGBooostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
+    val predicts = loadedXGBooostModel.booster.predict(testSetDMatrix, outPutMargin = true)
+    val loadedEvalResults = eval.eval(predicts, testSetDMatrix)
+    assert(loadedEvalResults == evalResults)
+  }
+
+  test("test save and load of different types of models") {
+    import DataUtils._
+    val tempDir = Files.createTempDirectory("xgboosttest-")
+    val tempFile = Files.createTempFile(tempDir, "", "")
+    var trainingRDD = sc.parallelize(Classification.train).map(_.asML)
+    var paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "reg:linear")
+    // validate regression model
+    var xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
+      nWorkers = numWorkers, useExternalMemory = false)
+    xgBoostModel.setFeaturesCol("feature_col")
+    xgBoostModel.setLabelCol("label_col")
+    xgBoostModel.setPredictionCol("prediction_col")
+    xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
+    var loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
+    assert(loadedXGBoostModel.isInstanceOf[XGBoostRegressionModel])
+    assert(loadedXGBoostModel.getFeaturesCol == "feature_col")
+    assert(loadedXGBoostModel.getLabelCol == "label_col")
+    assert(loadedXGBoostModel.getPredictionCol == "prediction_col")
+    // classification model
+    paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "binary:logistic")
+    xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
+      nWorkers = numWorkers, useExternalMemory = false)
+    xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
+    xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(Array(0.5, 0.5))
+    xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
+    loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
+    assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
+    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
+      "raw_col")
+    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
+      Array(0.5, 0.5).deep)
+    assert(loadedXGBoostModel.getFeaturesCol == "features")
+    assert(loadedXGBoostModel.getLabelCol == "label")
+    assert(loadedXGBoostModel.getPredictionCol == "prediction")
+    // (multiclass) classification model
+    trainingRDD = sc.parallelize(MultiClassification.train).map(_.asML)
+    paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "multi:softmax", "num_class" -> "6")
+    xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
+      nWorkers = numWorkers, useExternalMemory = false)
+    xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
+    xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(
+      Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5))
+    xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
+    loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
+    assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
+    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
+      "raw_col")
+    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
+      Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5).deep)
+    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6)
+    assert(loadedXGBoostModel.getFeaturesCol == "features")
+    assert(loadedXGBoostModel.getLabelCol == "label")
+    assert(loadedXGBoostModel.getPredictionCol == "prediction")
+  }
+
+  test("copy and predict ClassificationModel") {
+    import DataUtils._
+    val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
+    val testRDD = sc.parallelize(Classification.test).map(_.features)
+    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
+      "objective" -> "binary:logistic")
+    val model = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
+    testCopy(model, testRDD)
+  }
+
+  test("copy and predict RegressionModel") {
+    import DataUtils._
+    val trainingRDD = sc.parallelize(Regression.train).map(_.asML)
+    val testRDD = sc.parallelize(Regression.test).map(_.features)
+    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
+      "objective" -> "reg:linear")
+    val model = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
+    testCopy(model, testRDD)
+  }
+
+  private def testCopy(model: XGBoostModel, testRDD: RDD[Vector]): Unit = {
+    val modelCopy = model.copy(ParamMap.empty)
+    modelCopy.summary // Ensure no exception.
+
+    val expected = model.predict(testRDD).collect
+    assert(modelCopy.predict(testRDD).collect === expected)
+  }
+}
@@ -63,6 +63,12 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
   if (jenv->CallBooleanMethod(jiter, hasNext)) {
     ret_value = 1;
    jobject batch = jenv->CallObjectMethod(jiter, next);
+    if (batch == nullptr) {
+      CHECK(jenv->ExceptionOccurred());
+      jenv->ExceptionDescribe();
+      return -1;
+    }
+
     jclass batchClass = jenv->GetObjectClass(batch);
     jlongArray joffset = (jlongArray)jenv->GetObjectField(
       batch, jenv->GetFieldID(batchClass, "rowOffset", "[J"));