diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index ad4dc10e5..9187e2e41 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -331,8 +331,13 @@ object XGBoost extends Serializable { val isClsTask = isClassificationTask(params) val trackerReturnVal = tracker.waitFor(0L) logger.info(s"Rabit returns with exit code $trackerReturnVal") - postTrackerReturnProcessing(trackerReturnVal, boosters, overriddenParams, sparkJobThread, - isClsTask) + val model = postTrackerReturnProcessing(trackerReturnVal, boosters, overriddenParams, + sparkJobThread, isClsTask) + if (isClsTask){ + model.asInstanceOf[XGBoostClassificationModel].numOfClasses = + params.getOrElse("num_class", "2").toString.toInt + } + model } finally { tracker.stop() } @@ -389,6 +394,7 @@ object XGBoost extends Serializable { modelType match { case "_cls_" => val rawPredictionCol = dataInStream.readUTF() + val numClasses = dataInStream.readInt() val thresholdLength = dataInStream.readInt() var thresholds: Array[Double] = null if (thresholdLength != -1) { @@ -403,6 +409,7 @@ object XGBoost extends Serializable { if (thresholdLength != -1) { xgBoostModel.setThresholds(thresholds) } + xgBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = numClasses xgBoostModel case "_reg_" => val xgBoostModel = new XGBoostRegressionModel(SXGBoost.loadModel(dataInStream)) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala index 8f209eebf..8be830764 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala @@ -305,6 +305,7 @@ abstract class XGBoostModel(protected var _booster: Booster) outputStream.writeUTF("_cls_") saveGeneralModelParam(outputStream) outputStream.writeUTF(model.getRawPredictionCol) + outputStream.writeInt(model.numClasses) // threshold // threshold length if (!isDefined(model.thresholds)) { diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala index 972971efd..d5ac77dbd 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala @@ -192,6 +192,7 @@ class XGBoostDFSuite extends FunSuite with PerTest { val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5, nWorkers = numWorkers) assert(model.get[Double](model.eta).get == 0.1) assert(model.get[Int](model.maxDepth).get == 6) + assert(model.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6) } test("test use base margin") { diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala index b1db09db7..0b96a5f2c 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala @@ -286,7 +286,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest { import DataUtils._ val tempDir = Files.createTempDirectory("xgboosttest-") val tempFile = Files.createTempFile(tempDir, "", "") - val trainingRDD = sc.parallelize(Classification.train).map(_.asML) + var trainingRDD = sc.parallelize(Classification.train).map(_.asML) var paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", "objective" -> "reg:linear") // validate regression model @@ -318,6 +318,26 @@ class XGBoostGeneralSuite extends FunSuite with PerTest { assert(loadedXGBoostModel.getFeaturesCol == "features") assert(loadedXGBoostModel.getLabelCol == "label") assert(loadedXGBoostModel.getPredictionCol == "prediction") + // (multiclass) classification model + trainingRDD = sc.parallelize(MultiClassification.train).map(_.asML) + paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "multi:softmax", "num_class" -> "6") + xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, + nWorkers = numWorkers, useExternalMemory = false) + xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col") + xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds( + Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5)) + xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath) + loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath) + assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel]) + assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol == + "raw_col") + assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep == + Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5).deep) + assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6) + assert(loadedXGBoostModel.getFeaturesCol == "features") + assert(loadedXGBoostModel.getLabelCol == "label") + assert(loadedXGBoostModel.getPredictionCol == "prediction") } test("test use groupData") {