[jvm-packages] (xgboost-spark) preserving num_class across save & load (#2742)

* [bugfix] (xgboost-spark) preserving num_class across save & load

* add testcase for save & load of multiclass model
This commit is contained in:
Sergei Lebedev
2017-09-24 16:03:30 +02:00
committed by GitHub
parent c09204fa70
commit d570337262
4 changed files with 32 additions and 3 deletions

View File

@@ -331,8 +331,13 @@ object XGBoost extends Serializable {
val isClsTask = isClassificationTask(params)
val trackerReturnVal = tracker.waitFor(0L)
logger.info(s"Rabit returns with exit code $trackerReturnVal")
postTrackerReturnProcessing(trackerReturnVal, boosters, overriddenParams, sparkJobThread,
isClsTask)
val model = postTrackerReturnProcessing(trackerReturnVal, boosters, overriddenParams,
sparkJobThread, isClsTask)
if (isClsTask){
model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
params.getOrElse("num_class", "2").toString.toInt
}
model
} finally {
tracker.stop()
}
@@ -389,6 +394,7 @@ object XGBoost extends Serializable {
modelType match {
case "_cls_" =>
val rawPredictionCol = dataInStream.readUTF()
val numClasses = dataInStream.readInt()
val thresholdLength = dataInStream.readInt()
var thresholds: Array[Double] = null
if (thresholdLength != -1) {
@@ -403,6 +409,7 @@ object XGBoost extends Serializable {
if (thresholdLength != -1) {
xgBoostModel.setThresholds(thresholds)
}
xgBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = numClasses
xgBoostModel
case "_reg_" =>
val xgBoostModel = new XGBoostRegressionModel(SXGBoost.loadModel(dataInStream))

View File

@@ -305,6 +305,7 @@ abstract class XGBoostModel(protected var _booster: Booster)
outputStream.writeUTF("_cls_")
saveGeneralModelParam(outputStream)
outputStream.writeUTF(model.getRawPredictionCol)
outputStream.writeInt(model.numClasses)
// threshold
// threshold length
if (!isDefined(model.thresholds)) {