[jvm-packages] (xgboost-spark) preserving num_class across save & load (#2742)
* [bugfix] (xgboost-spark) preserving num_class across save & load * add testcase for save & load of multiclass model
This commit is contained in:
@@ -331,8 +331,13 @@ object XGBoost extends Serializable {
|
||||
val isClsTask = isClassificationTask(params)
|
||||
val trackerReturnVal = tracker.waitFor(0L)
|
||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||
postTrackerReturnProcessing(trackerReturnVal, boosters, overriddenParams, sparkJobThread,
|
||||
isClsTask)
|
||||
val model = postTrackerReturnProcessing(trackerReturnVal, boosters, overriddenParams,
|
||||
sparkJobThread, isClsTask)
|
||||
if (isClsTask){
|
||||
model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
|
||||
params.getOrElse("num_class", "2").toString.toInt
|
||||
}
|
||||
model
|
||||
} finally {
|
||||
tracker.stop()
|
||||
}
|
||||
@@ -389,6 +394,7 @@ object XGBoost extends Serializable {
|
||||
modelType match {
|
||||
case "_cls_" =>
|
||||
val rawPredictionCol = dataInStream.readUTF()
|
||||
val numClasses = dataInStream.readInt()
|
||||
val thresholdLength = dataInStream.readInt()
|
||||
var thresholds: Array[Double] = null
|
||||
if (thresholdLength != -1) {
|
||||
@@ -403,6 +409,7 @@ object XGBoost extends Serializable {
|
||||
if (thresholdLength != -1) {
|
||||
xgBoostModel.setThresholds(thresholds)
|
||||
}
|
||||
xgBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = numClasses
|
||||
xgBoostModel
|
||||
case "_reg_" =>
|
||||
val xgBoostModel = new XGBoostRegressionModel(SXGBoost.loadModel(dataInStream))
|
||||
|
||||
@@ -305,6 +305,7 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
||||
outputStream.writeUTF("_cls_")
|
||||
saveGeneralModelParam(outputStream)
|
||||
outputStream.writeUTF(model.getRawPredictionCol)
|
||||
outputStream.writeInt(model.numClasses)
|
||||
// threshold
|
||||
// threshold length
|
||||
if (!isDefined(model.thresholds)) {
|
||||
|
||||
Reference in New Issue
Block a user