[jvm-packages] add format option when saving a model (#7940)

This commit is contained in:
Bobby Wang 2022-05-30 15:49:59 +08:00 committed by GitHub
parent cc6d57aa0d
commit 6275cdc486
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 153 additions and 30 deletions

View File

@ -30,6 +30,8 @@ import org.apache.spark.sql.functions._
import org.json4s.DefaultFormats
import scala.collection.{Iterator, mutable}
import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter
import org.apache.spark.sql.types.StructType
class XGBoostClassifier (
@ -462,7 +464,8 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
override def load(path: String): XGBoostClassificationModel = super.load(path)
private[XGBoostClassificationModel]
class XGBoostClassificationModelWriter(instance: XGBoostClassificationModel) extends MLWriter {
class XGBoostClassificationModelWriter(instance: XGBoostClassificationModel)
extends XGBoostWriter {
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
@ -474,7 +477,7 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
instance._booster.saveModel(outputStream)
instance._booster.saveModel(outputStream, getModelFormat())
outputStream.close()
}
}

View File

@ -19,6 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
import scala.collection.{Iterator, mutable}
import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import org.apache.hadoop.fs.Path
@ -379,7 +380,7 @@ object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
override def load(path: String): XGBoostRegressionModel = super.load(path)
private[XGBoostRegressionModel]
class XGBoostRegressionModelWriter(instance: XGBoostRegressionModel) extends MLWriter {
class XGBoostRegressionModelWriter(instance: XGBoostRegressionModel) extends XGBoostWriter {
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
@ -390,7 +391,7 @@ object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostRegressionModel")
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
instance._booster.saveModel(outputStream)
instance._booster.saveModel(outputStream, getModelFormat())
outputStream.close()
}
}

View File

@ -0,0 +1,31 @@
/*
Copyright (c) 2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.utils
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
import org.apache.spark.ml.util.MLWriter
private[spark] abstract class XGBoostWriter extends MLWriter {
/** Currently it's using the "deprecated" format as
* default, which will be changed into `ubj` in future releases. */
def getModelFormat(): String = {
optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
}
}

View File

@ -16,16 +16,18 @@
package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import java.io.{File, FileInputStream}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.scalatest.{BeforeAndAfterEach, FunSuite}
import scala.math.min
import scala.util.Random
import org.apache.commons.io.IOUtils
trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
protected val numWorkers: Int = min(Runtime.getRuntime.availableProcessors(), 4)
@ -105,4 +107,22 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
.toDF("id", "label", "features", "group")
}
protected def compareTwoFiles(lhs: String, rhs: String): Boolean = {
withResource(new FileInputStream(lhs)) { lfis =>
withResource(new FileInputStream(rhs)) { rfis =>
IOUtils.contentEquals(lfis, rfis)
}
}
}
/** Executes the provided code block and then closes the resource */
protected def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
try {
block(r)
} finally {
r.close()
}
}
}

View File

@ -429,30 +429,29 @@ class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuit
val trainingDF = buildDataFrame(MultiClassification.train)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(trainingDF)
val modelPath = new File(tempDir.toFile, "xgbc").getPath
model.write.overwrite().save(modelPath)
val nativeModelPath = new File(tempDir.toFile, "nativeModel").getPath
model.nativeBooster.saveModel(nativeModelPath)
model.write.option("format", "json").save(modelPath)
val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
model.nativeBooster.saveModel(nativeJsonModelPath)
assert(compareTwoFiles(new File(modelPath, "data/XGBoostClassificationModel").getPath,
nativeModelPath))
}
nativeJsonModelPath))
private def compareTwoFiles(lhs: String, rhs: String): Boolean = {
withResource(new FileInputStream(lhs)) { lfis =>
withResource(new FileInputStream(rhs)) { rfis =>
IOUtils.contentEquals(lfis, rfis)
}
}
}
// test default "deprecated"
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
model.write.save(modelUbjPath)
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostClassificationModel").getPath,
nativeDeprecatedModelPath))
/** Executes the provided code block and then closes the resource */
private def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
try {
block(r)
} finally {
r.close()
}
// json file should be indifferent with ubj file
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
model.write.option("format", "json").save(modelJsonPath)
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath)
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
nativeUbjModelPath))
}
}

View File

@ -16,6 +16,8 @@
package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg.{Vector, Vectors}
@ -25,7 +27,7 @@ import org.scalatest.FunSuite
import org.apache.spark.ml.feature.VectorAssembler
class XGBoostRegressorSuite extends FunSuite with PerTest {
class XGBoostRegressorSuite extends FunSuite with PerTest with TmpFolderPerSuite {
protected val treeMethod: String = "auto"
test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
@ -310,4 +312,42 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
val df1 = model.transform(vectorizedInput)
df1.show()
}
test("XGBoostRegressionModel should be compatible") {
val trainingDF = buildDataFrame(Regression.train)
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "reg:squarederror",
"num_round" -> 5,
"tree_method" -> treeMethod,
"num_workers" -> numWorkers)
val model = new XGBoostRegressor(paramMap).fit(trainingDF)
val modelPath = new File(tempDir.toFile, "xgbc").getPath
model.write.option("format", "json").save(modelPath)
val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
model.nativeBooster.saveModel(nativeJsonModelPath)
assert(compareTwoFiles(new File(modelPath, "data/XGBoostRegressionModel").getPath,
nativeJsonModelPath))
// test default "deprecated"
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
model.write.save(modelUbjPath)
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostRegressionModel").getPath,
nativeDeprecatedModelPath))
// json file should be indifferent with ubj file
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
model.write.option("format", "json").save(modelJsonPath)
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath)
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostRegressionModel").getPath,
nativeUbjModelPath))
}
}

View File

@ -34,6 +34,7 @@ import org.apache.commons.logging.LogFactory;
* Booster for xgboost, this is a model API that support interactive build of a XGBoost Model
*/
public class Booster implements Serializable, KryoSerializable {
public static final String DEFAULT_FORMAT = "deprecated";
private static final Log logger = LogFactory.getLog(Booster.class);
// handle to the booster.
private long handle = 0;
@ -391,7 +392,22 @@ public class Booster implements Serializable, KryoSerializable {
* @param out The output stream
*/
public void saveModel(OutputStream out) throws XGBoostError, IOException {
out.write(this.toByteArray());
saveModel(out, DEFAULT_FORMAT);
}
/**
* Save the model to file opened as output stream.
* The model format is compatible with other xgboost bindings.
* The output stream can only save one xgboost model.
* This function will close the OutputStream after the save.
*
* @param out The output stream
* @param format The model format (ubj, json, deprecated)
* @throws XGBoostError
* @throws IOException
*/
public void saveModel(OutputStream out, String format) throws XGBoostError, IOException {
out.write(this.toByteArray(format));
out.close();
}
@ -643,7 +659,7 @@ public class Booster implements Serializable, KryoSerializable {
* @throws XGBoostError native error
*/
public byte[] toByteArray() throws XGBoostError {
return this.toByteArray("deprecated");
return this.toByteArray(DEFAULT_FORMAT);
}
/**

View File

@ -207,6 +207,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
def saveModel(modelPath: String): Unit = {
booster.saveModel(modelPath)
}
/**
* save model to Output stream
*
@ -216,6 +217,18 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
def saveModel(out: java.io.OutputStream): Unit = {
booster.saveModel(out)
}
/**
* save model to Output stream
* @param out output stream
* @param format the supported model format, (json, ubj, deprecated)
* @throws ml.dmlc.xgboost4j.java.XGBoostError
*/
@throws(classOf[XGBoostError])
def saveModel(out: java.io.OutputStream, format: String): Unit = {
booster.saveModel(out, format)
}
/**
* Dump model as Array of string
*
@ -315,7 +328,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
*/
@throws(classOf[XGBoostError])
def toByteArray: Array[Byte] = {
booster.toByteArray
booster.toByteArray()
}
/**