[jvm-packages] (xgboost-spark) preserving num_class across save & load (#2742)
* [bugfix] (xgboost-spark) preserving num_class across save & load
* add test case for save & load of multiclass model
This commit is contained in:
parent
c09204fa70
commit
d570337262
@ -331,8 +331,13 @@ object XGBoost extends Serializable {
|
|||||||
val isClsTask = isClassificationTask(params)
|
val isClsTask = isClassificationTask(params)
|
||||||
val trackerReturnVal = tracker.waitFor(0L)
|
val trackerReturnVal = tracker.waitFor(0L)
|
||||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||||
postTrackerReturnProcessing(trackerReturnVal, boosters, overriddenParams, sparkJobThread,
|
val model = postTrackerReturnProcessing(trackerReturnVal, boosters, overriddenParams,
|
||||||
isClsTask)
|
sparkJobThread, isClsTask)
|
||||||
|
if (isClsTask){
|
||||||
|
model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
|
||||||
|
params.getOrElse("num_class", "2").toString.toInt
|
||||||
|
}
|
||||||
|
model
|
||||||
} finally {
|
} finally {
|
||||||
tracker.stop()
|
tracker.stop()
|
||||||
}
|
}
|
||||||
@ -389,6 +394,7 @@ object XGBoost extends Serializable {
|
|||||||
modelType match {
|
modelType match {
|
||||||
case "_cls_" =>
|
case "_cls_" =>
|
||||||
val rawPredictionCol = dataInStream.readUTF()
|
val rawPredictionCol = dataInStream.readUTF()
|
||||||
|
val numClasses = dataInStream.readInt()
|
||||||
val thresholdLength = dataInStream.readInt()
|
val thresholdLength = dataInStream.readInt()
|
||||||
var thresholds: Array[Double] = null
|
var thresholds: Array[Double] = null
|
||||||
if (thresholdLength != -1) {
|
if (thresholdLength != -1) {
|
||||||
@ -403,6 +409,7 @@ object XGBoost extends Serializable {
|
|||||||
if (thresholdLength != -1) {
|
if (thresholdLength != -1) {
|
||||||
xgBoostModel.setThresholds(thresholds)
|
xgBoostModel.setThresholds(thresholds)
|
||||||
}
|
}
|
||||||
|
xgBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = numClasses
|
||||||
xgBoostModel
|
xgBoostModel
|
||||||
case "_reg_" =>
|
case "_reg_" =>
|
||||||
val xgBoostModel = new XGBoostRegressionModel(SXGBoost.loadModel(dataInStream))
|
val xgBoostModel = new XGBoostRegressionModel(SXGBoost.loadModel(dataInStream))
|
||||||
|
|||||||
@ -305,6 +305,7 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
|||||||
outputStream.writeUTF("_cls_")
|
outputStream.writeUTF("_cls_")
|
||||||
saveGeneralModelParam(outputStream)
|
saveGeneralModelParam(outputStream)
|
||||||
outputStream.writeUTF(model.getRawPredictionCol)
|
outputStream.writeUTF(model.getRawPredictionCol)
|
||||||
|
outputStream.writeInt(model.numClasses)
|
||||||
// threshold
|
// threshold
|
||||||
// threshold length
|
// threshold length
|
||||||
if (!isDefined(model.thresholds)) {
|
if (!isDefined(model.thresholds)) {
|
||||||
|
|||||||
@ -192,6 +192,7 @@ class XGBoostDFSuite extends FunSuite with PerTest {
|
|||||||
val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5, nWorkers = numWorkers)
|
val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5, nWorkers = numWorkers)
|
||||||
assert(model.get[Double](model.eta).get == 0.1)
|
assert(model.get[Double](model.eta).get == 0.1)
|
||||||
assert(model.get[Int](model.maxDepth).get == 6)
|
assert(model.get[Int](model.maxDepth).get == 6)
|
||||||
|
assert(model.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("test use base margin") {
|
test("test use base margin") {
|
||||||
|
|||||||
@ -286,7 +286,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
|||||||
import DataUtils._
|
import DataUtils._
|
||||||
val tempDir = Files.createTempDirectory("xgboosttest-")
|
val tempDir = Files.createTempDirectory("xgboosttest-")
|
||||||
val tempFile = Files.createTempFile(tempDir, "", "")
|
val tempFile = Files.createTempFile(tempDir, "", "")
|
||||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
var trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||||
var paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
var paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "reg:linear")
|
"objective" -> "reg:linear")
|
||||||
// validate regression model
|
// validate regression model
|
||||||
@ -318,6 +318,26 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
|||||||
assert(loadedXGBoostModel.getFeaturesCol == "features")
|
assert(loadedXGBoostModel.getFeaturesCol == "features")
|
||||||
assert(loadedXGBoostModel.getLabelCol == "label")
|
assert(loadedXGBoostModel.getLabelCol == "label")
|
||||||
assert(loadedXGBoostModel.getPredictionCol == "prediction")
|
assert(loadedXGBoostModel.getPredictionCol == "prediction")
|
||||||
|
// (multiclass) classification model
|
||||||
|
trainingRDD = sc.parallelize(MultiClassification.train).map(_.asML)
|
||||||
|
paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "multi:softmax", "num_class" -> "6")
|
||||||
|
xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||||
|
nWorkers = numWorkers, useExternalMemory = false)
|
||||||
|
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
|
||||||
|
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(
|
||||||
|
Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5))
|
||||||
|
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
|
||||||
|
loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
|
||||||
|
assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
|
||||||
|
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
|
||||||
|
"raw_col")
|
||||||
|
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
|
||||||
|
Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5).deep)
|
||||||
|
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6)
|
||||||
|
assert(loadedXGBoostModel.getFeaturesCol == "features")
|
||||||
|
assert(loadedXGBoostModel.getLabelCol == "label")
|
||||||
|
assert(loadedXGBoostModel.getPredictionCol == "prediction")
|
||||||
}
|
}
|
||||||
|
|
||||||
test("test use groupData") {
|
test("test use groupData") {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user