diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala index a37a3901f..77683e914 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala @@ -463,7 +463,6 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] val dataPath = new Path(path, "data").toString val internalPath = new Path(dataPath, "XGBoostClassificationModel") val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath) - outputStream.writeInt(instance.numClasses) instance._booster.saveModel(outputStream) outputStream.close() } @@ -477,13 +476,22 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] override def load(path: String): XGBoostClassificationModel = { implicit val sc = super.sparkSession.sparkContext - val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString val internalPath = new Path(dataPath, "XGBoostClassificationModel") val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath) - val numClasses = dataInStream.readInt() + + // The xgboostVersion in the metadata indicates whether the model was saved by an old, + // incompatible xgboost version or by a new, compatible one. + val numClasses = metadata.xgboostVersion.map { _ => + implicit val format = DefaultFormats + // For binary:logistic, the numClass parameter can't be set to 2 or not be set. + // For multi:softprob or multi:softmax, the numClass parameter must be set correctly, + // or else, XGBoost will throw an exception. + // So it's safe to get numClass from the metadata.
+ (metadata.params \ "numClass").extractOpt[Int].getOrElse(2) + }.getOrElse(dataInStream.readInt()) val booster = SXGBoost.loadModel(dataInStream) val model = new XGBoostClassificationModel(metadata.uid, numClasses, booster) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsReader.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsReader.scala index d7d4fca77..9fc644b8f 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsReader.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsReader.scala @@ -51,7 +51,8 @@ private[spark] object DefaultXGBoostParamsReader { sparkVersion: String, params: JValue, metadata: JValue, - metadataJson: String) { + metadataJson: String, + xgboostVersion: Option[String] = None) { /** * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name. 
@@ -108,8 +109,8 @@ private[spark] object DefaultXGBoostParamsReader { require(className == expectedClassName, s"Error loading metadata: Expected class name" + s" $expectedClassName but found class name $className") } - - Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr) + val xgboostVersion = (metadata \ "xgboostVersion").extractOpt[String] + Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr, xgboostVersion) } private def handleBrokenlyChangedValue[T](paramName: String, value: T): T = { diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsWriter.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsWriter.scala index 92769d010..38aa814c2 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsWriter.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsWriter.scala @@ -22,8 +22,8 @@ import org.apache.spark.SparkContext import org.apache.spark.ml.param.{ParamPair, Params} import org.json4s.jackson.JsonMethods._ import org.json4s.{JArray, JBool, JDouble, JField, JInt, JNothing, JObject, JString, JValue} - import JsonDSLXGBoost._ +import ml.dmlc.xgboost4j.scala.spark // This originates from apache-spark DefaultPramsWriter copy paste private[spark] object DefaultXGBoostParamsWriter { @@ -78,6 +78,7 @@ private[spark] object DefaultXGBoostParamsWriter { ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ + ("xgboostVersion" -> spark.VERSION) ~ ("paramMap" -> jsonParams) val metadata = extraMetadata match { case Some(jObject) => diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala 
b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala index 0fa851f57..4abd464ad 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala @@ -16,16 +16,19 @@ package ml.dmlc.xgboost4j.scala.spark +import java.io.{File, FileInputStream} + import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost} import org.apache.spark.ml.linalg._ import org.apache.spark.sql._ import org.scalatest.FunSuite +import org.apache.commons.io.IOUtils import org.apache.spark.Partitioner import org.apache.spark.ml.feature.VectorAssembler -class XGBoostClassifierSuite extends FunSuite with PerTest { +class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuite { protected val treeMethod: String = "auto" @@ -391,4 +394,37 @@ class XGBoostClassifierSuite extends FunSuite with PerTest { df1.show() } + test("XGBoostClassificationModel should be compatible") { + val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "multi:softprob", "num_class" -> "6", "num_round" -> 5, + "num_workers" -> numWorkers, "tree_method" -> treeMethod) + val trainingDF = buildDataFrame(MultiClassification.train) + val xgb = new XGBoostClassifier(paramMap) + val model = xgb.fit(trainingDF) + val modelPath = new File(tempDir.toFile, "xgbc").getPath + model.write.overwrite().save(modelPath) + val nativeModelPath = new File(tempDir.toFile, "nativeModel").getPath + model.nativeBooster.saveModel(nativeModelPath) + + assert(compareTwoFiles(new File(modelPath, "data/XGBoostClassificationModel").getPath, + nativeModelPath)) + } + + private def compareTwoFiles(lhs: String, rhs: String): Boolean = { + withResource(new FileInputStream(lhs)) { lfis => + withResource(new FileInputStream(rhs)) { rfis => + IOUtils.contentEquals(lfis, rfis) + } + } + } + + /** 
Executes the provided code block and then closes the resource */ + private def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = { + try { + block(r) + } finally { + r.close() + } + } + }