From 6275cdc4866d65d1e723971e8af55f174a63cfce Mon Sep 17 00:00:00 2001
From: Bobby Wang
Date: Mon, 30 May 2022 15:49:59 +0800
Subject: [PATCH] [jvm-packages] add format option when saving a model (#7940)

---
 .../scala/spark/XGBoostClassifier.scala       |  7 +++-
 .../scala/spark/XGBoostRegressor.scala        |  5 ++-
 .../scala/spark/utils/XGBoostReadWrite.scala  | 31 ++++++++++++++
 .../dmlc/xgboost4j/scala/spark/PerTest.scala  | 24 ++++++++++-
 .../scala/spark/XGBoostClassifierSuite.scala  | 39 +++++++++--------
 .../scala/spark/XGBoostRegressorSuite.scala   | 42 ++++++++++++++++++-
 .../java/ml/dmlc/xgboost4j/java/Booster.java  | 20 ++++++++-
 .../ml/dmlc/xgboost4j/scala/Booster.scala     | 15 ++++++-
 8 files changed, 153 insertions(+), 30 deletions(-)
 create mode 100644 jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/utils/XGBoostReadWrite.scala

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
index 2f6827787..59be8b1c7 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
@@ -30,6 +30,8 @@ import org.apache.spark.sql.functions._
 import org.json4s.DefaultFormats

 import scala.collection.{Iterator, mutable}
+import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter
+
 import org.apache.spark.sql.types.StructType

 class XGBoostClassifier (
@@ -462,7 +464,8 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
   override def load(path: String): XGBoostClassificationModel = super.load(path)

   private[XGBoostClassificationModel]
-  class XGBoostClassificationModelWriter(instance: XGBoostClassificationModel) extends MLWriter {
+  class XGBoostClassificationModelWriter(instance: XGBoostClassificationModel)
+    extends XGBoostWriter {

     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
@@ -474,7 +477,7 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
       val dataPath = new Path(path, "data").toString
       val internalPath = new Path(dataPath, "XGBoostClassificationModel")
       val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
-      instance._booster.saveModel(outputStream)
+      instance._booster.saveModel(outputStream, getModelFormat())
       outputStream.close()
     }
   }
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
index 0402beb62..fcdd347e4 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
@@ -19,6 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
 import scala.collection.{Iterator, mutable}

 import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
+import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
 import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
 import org.apache.hadoop.fs.Path
@@ -379,7 +380,7 @@ object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
   override def load(path: String): XGBoostRegressionModel = super.load(path)

   private[XGBoostRegressionModel]
-  class XGBoostRegressionModelWriter(instance: XGBoostRegressionModel) extends MLWriter {
+  class XGBoostRegressionModelWriter(instance: XGBoostRegressionModel) extends XGBoostWriter {

     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
@@ -390,7 +391,7 @@ object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
       val dataPath = new Path(path, "data").toString
       val internalPath = new Path(dataPath, "XGBoostRegressionModel")
       val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
-      instance._booster.saveModel(outputStream)
+      instance._booster.saveModel(outputStream, getModelFormat())
       outputStream.close()
     }
   }
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/utils/XGBoostReadWrite.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/utils/XGBoostReadWrite.scala
new file mode 100644
index 000000000..0fbd36bf3
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/utils/XGBoostReadWrite.scala
@@ -0,0 +1,31 @@
+/*
+ Copyright (c) 2022 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark.utils
+
+import ml.dmlc.xgboost4j.java.{Booster => JBooster}
+
+import org.apache.spark.ml.util.MLWriter
+
+private[spark] abstract class XGBoostWriter extends MLWriter {
+
+  /** Currently it uses the "deprecated" format as the default,
+   * which will be changed to `ubj` in future releases. */
+  def getModelFormat(): String = {
+    optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
+  }
+
+}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
index f5775bc4d..512dbdb71 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
@@ -16,16 +16,18 @@
 package ml.dmlc.xgboost4j.scala.spark

-import java.io.File
+import java.io.{File, FileInputStream}

 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
+
 import org.apache.spark.SparkContext
 import org.apache.spark.sql._
 import org.scalatest.{BeforeAndAfterEach, FunSuite}
-
 import scala.math.min
 import scala.util.Random

+import org.apache.commons.io.IOUtils
+
 trait PerTest extends BeforeAndAfterEach { self: FunSuite =>

   protected val numWorkers: Int = min(Runtime.getRuntime.availableProcessors(), 4)
@@ -105,4 +107,22 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
     ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
       .toDF("id", "label", "features", "group")
   }
+
+
+  protected def compareTwoFiles(lhs: String, rhs: String): Boolean = {
+    withResource(new FileInputStream(lhs)) { lfis =>
+      withResource(new FileInputStream(rhs)) { rfis =>
+        IOUtils.contentEquals(lfis, rfis)
+      }
+    }
+  }
+
+  /** Executes the provided code block and then closes the resource */
+  protected def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
+    try {
+      block(r)
+    } finally {
+      r.close()
+    }
+  }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
index 9fe2479e5..e8c29fd52 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
@@ -429,30 +429,29 @@ class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuit
     val trainingDF = buildDataFrame(MultiClassification.train)
     val xgb = new XGBoostClassifier(paramMap)
     val model = xgb.fit(trainingDF)
+
     val modelPath = new File(tempDir.toFile, "xgbc").getPath
-    model.write.overwrite().save(modelPath)
-    val nativeModelPath = new File(tempDir.toFile, "nativeModel").getPath
-    model.nativeBooster.saveModel(nativeModelPath)
-
+    model.write.option("format", "json").save(modelPath)
+    val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
+    model.nativeBooster.saveModel(nativeJsonModelPath)
     assert(compareTwoFiles(new File(modelPath, "data/XGBoostClassificationModel").getPath,
-      nativeModelPath))
-  }
+      nativeJsonModelPath))

-  private def compareTwoFiles(lhs: String, rhs: String): Boolean = {
-    withResource(new FileInputStream(lhs)) { lfis =>
-      withResource(new FileInputStream(rhs)) { rfis =>
-        IOUtils.contentEquals(lfis, rfis)
-      }
-    }
-  }
+    // test default "deprecated"
+    val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
+    model.write.save(modelUbjPath)
+    val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
+    model.nativeBooster.saveModel(nativeDeprecatedModelPath)
+    assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostClassificationModel").getPath,
+      nativeDeprecatedModelPath))

-  /** Executes the provided code block and then closes the resource */
-  private def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
-    try {
-      block(r)
-    } finally {
-      r.close()
-    }
+    // the json model file should differ from the ubj model file
+    val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
+    model.write.option("format", "json").save(modelJsonPath)
+    val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
+    model.nativeBooster.saveModel(nativeUbjModelPath)
+    assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
+      nativeUbjModelPath))
   }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
index a530313b9..4e3d59b25 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
@@ -16,6 +16,8 @@
 package ml.dmlc.xgboost4j.scala.spark

+import java.io.File
+
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}

 import org.apache.spark.ml.linalg.{Vector, Vectors}
@@ -25,7 +27,7 @@ import org.scalatest.FunSuite

 import org.apache.spark.ml.feature.VectorAssembler

-class XGBoostRegressorSuite extends FunSuite with PerTest {
+class XGBoostRegressorSuite extends FunSuite with PerTest with TmpFolderPerSuite {
   protected val treeMethod: String = "auto"

   test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
@@ -310,4 +312,42 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
     val df1 = model.transform(vectorizedInput)
     df1.show()
   }
+
+  test("XGBoostRegressionModel should be compatible") {
+    val trainingDF = buildDataFrame(Regression.train)
+    val paramMap = Map(
+      "eta" -> "1",
+      "max_depth" -> "6",
+      "silent" -> "1",
+      "objective" -> "reg:squarederror",
+      "num_round" -> 5,
+      "tree_method" -> treeMethod,
+      "num_workers" -> numWorkers)
+
+    val model = new XGBoostRegressor(paramMap).fit(trainingDF)
+
+    val modelPath = new File(tempDir.toFile, "xgbc").getPath
+    model.write.option("format", "json").save(modelPath)
+    val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
+    model.nativeBooster.saveModel(nativeJsonModelPath)
+    assert(compareTwoFiles(new File(modelPath, "data/XGBoostRegressionModel").getPath,
+      nativeJsonModelPath))
+
+    // test default "deprecated"
+    val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
+    model.write.save(modelUbjPath)
+    val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
+    model.nativeBooster.saveModel(nativeDeprecatedModelPath)
+    assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostRegressionModel").getPath,
+      nativeDeprecatedModelPath))
+
+    // the json model file should differ from the ubj model file
+    val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
+    model.write.option("format", "json").save(modelJsonPath)
+    val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
+    model.nativeBooster.saveModel(nativeUbjModelPath)
+    assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostRegressionModel").getPath,
+      nativeUbjModelPath))
+  }
+
 }
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java
index f08435f3a..ed1a3f5c9 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java
@@ -34,6 +34,7 @@ import org.apache.commons.logging.LogFactory;
  * Booster for xgboost, this is a model API that support interactive build of a XGBoost Model
  */
 public class Booster implements Serializable, KryoSerializable {
+  public static final String DEFAULT_FORMAT = "deprecated";
   private static final Log logger = LogFactory.getLog(Booster.class);
   // handle to the booster.
   private long handle = 0;
@@ -391,7 +392,22 @@ public class Booster implements Serializable, KryoSerializable {
   * @param out The output stream
   */
  public void saveModel(OutputStream out) throws XGBoostError, IOException {
-    out.write(this.toByteArray());
+    saveModel(out, DEFAULT_FORMAT);
+  }
+
+  /**
+   * Save the model to a file opened as an output stream.
+   * The model format is compatible with other xgboost bindings.
+   * The output stream can only save one xgboost model.
+   * This function will close the OutputStream after the save.
+   *
+   * @param out The output stream
+   * @param format The model format (ubj, json, deprecated)
+   * @throws XGBoostError
+   * @throws IOException
+   */
+  public void saveModel(OutputStream out, String format) throws XGBoostError, IOException {
+    out.write(this.toByteArray(format));
    out.close();
  }

@@ -643,7 +659,7 @@ public class Booster implements Serializable, KryoSerializable {
   * @throws XGBoostError native error
   */
  public byte[] toByteArray() throws XGBoostError {
-    return this.toByteArray("deprecated");
+    return this.toByteArray(DEFAULT_FORMAT);
  }

  /**
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala
index 88f5607d3..a1d122679 100644
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala
@@ -207,6 +207,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
   def saveModel(modelPath: String): Unit = {
     booster.saveModel(modelPath)
   }
+
   /**
    * save model to Output stream
    *
@@ -216,6 +217,18 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
   def saveModel(out: java.io.OutputStream): Unit = {
     booster.saveModel(out)
   }
+
+  /**
+   * save model to Output stream
+   * @param out output stream
+   * @param format the supported model format (json, ubj, deprecated)
+   * @throws ml.dmlc.xgboost4j.java.XGBoostError
+   */
+  @throws(classOf[XGBoostError])
+  def saveModel(out: java.io.OutputStream, format: String): Unit = {
+    booster.saveModel(out, format)
+  }
+
   /**
    * Dump model as Array of string
    *
@@ -315,7 +328,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
    */
   @throws(classOf[XGBoostError])
   def toByteArray: Array[Byte] = {
-    booster.toByteArray
+    booster.toByteArray()
   }

   /**
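
Below is a minimal usage sketch of the Spark-level option added by this patch, assuming a fitted XGBoostClassificationModel bound to a val named `model`; the paths are illustrative only:

    // Persist the Spark model with an explicit native format: "json", "ubj" or "deprecated".
    model.write.option("format", "json").overwrite().save("/tmp/xgbc-json")

    // Without the option, the writer falls back to Booster.DEFAULT_FORMAT, i.e. "deprecated".
    model.write.overwrite().save("/tmp/xgbc-default")

The same option is honoured by the XGBoostRegressionModel writer, since both writers now extend the shared XGBoostWriter helper and read the format through getModelFormat().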
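
A similar sketch for the new low-level overload on the Scala Booster, assuming an already-trained Booster bound to `booster`; the output path is again illustrative:

    import java.io.FileOutputStream

    // Stream the model out as UBJSON; the new overload closes the stream after writing,
    // matching the behaviour of the existing saveModel(out).
    booster.saveModel(new FileOutputStream("/tmp/model.ubj"), "ubj")

Passing "json" or "deprecated" instead selects the other formats accepted by the underlying toByteArray(format) call.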