[jvm-packages] add format option when saving a model (#7940)
This commit is contained in:
parent
cc6d57aa0d
commit
6275cdc486
@ -30,6 +30,8 @@ import org.apache.spark.sql.functions._
|
|||||||
import org.json4s.DefaultFormats
|
import org.json4s.DefaultFormats
|
||||||
import scala.collection.{Iterator, mutable}
|
import scala.collection.{Iterator, mutable}
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter
|
||||||
|
|
||||||
import org.apache.spark.sql.types.StructType
|
import org.apache.spark.sql.types.StructType
|
||||||
|
|
||||||
class XGBoostClassifier (
|
class XGBoostClassifier (
|
||||||
@ -462,7 +464,8 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
|
|||||||
override def load(path: String): XGBoostClassificationModel = super.load(path)
|
override def load(path: String): XGBoostClassificationModel = super.load(path)
|
||||||
|
|
||||||
private[XGBoostClassificationModel]
|
private[XGBoostClassificationModel]
|
||||||
class XGBoostClassificationModelWriter(instance: XGBoostClassificationModel) extends MLWriter {
|
class XGBoostClassificationModelWriter(instance: XGBoostClassificationModel)
|
||||||
|
extends XGBoostWriter {
|
||||||
|
|
||||||
override protected def saveImpl(path: String): Unit = {
|
override protected def saveImpl(path: String): Unit = {
|
||||||
// Save metadata and Params
|
// Save metadata and Params
|
||||||
@ -474,7 +477,7 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
|
|||||||
val dataPath = new Path(path, "data").toString
|
val dataPath = new Path(path, "data").toString
|
||||||
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
|
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
|
||||||
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
|
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
|
||||||
instance._booster.saveModel(outputStream)
|
instance._booster.saveModel(outputStream, getModelFormat())
|
||||||
outputStream.close()
|
outputStream.close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -19,6 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
|
|||||||
import scala.collection.{Iterator, mutable}
|
import scala.collection.{Iterator, mutable}
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
|
import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
|
||||||
|
import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter
|
||||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
||||||
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
||||||
import org.apache.hadoop.fs.Path
|
import org.apache.hadoop.fs.Path
|
||||||
@ -379,7 +380,7 @@ object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
|
|||||||
override def load(path: String): XGBoostRegressionModel = super.load(path)
|
override def load(path: String): XGBoostRegressionModel = super.load(path)
|
||||||
|
|
||||||
private[XGBoostRegressionModel]
|
private[XGBoostRegressionModel]
|
||||||
class XGBoostRegressionModelWriter(instance: XGBoostRegressionModel) extends MLWriter {
|
class XGBoostRegressionModelWriter(instance: XGBoostRegressionModel) extends XGBoostWriter {
|
||||||
|
|
||||||
override protected def saveImpl(path: String): Unit = {
|
override protected def saveImpl(path: String): Unit = {
|
||||||
// Save metadata and Params
|
// Save metadata and Params
|
||||||
@ -390,7 +391,7 @@ object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
|
|||||||
val dataPath = new Path(path, "data").toString
|
val dataPath = new Path(path, "data").toString
|
||||||
val internalPath = new Path(dataPath, "XGBoostRegressionModel")
|
val internalPath = new Path(dataPath, "XGBoostRegressionModel")
|
||||||
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
|
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
|
||||||
instance._booster.saveModel(outputStream)
|
instance._booster.saveModel(outputStream, getModelFormat())
|
||||||
outputStream.close()
|
outputStream.close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -0,0 +1,31 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2022 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.spark.utils
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
|
||||||
|
|
||||||
|
import org.apache.spark.ml.util.MLWriter
|
||||||
|
|
||||||
|
private[spark] abstract class XGBoostWriter extends MLWriter {
|
||||||
|
|
||||||
|
/** Currently it's using the "deprecated" format as
|
||||||
|
* default, which will be changed into `ubj` in future releases. */
|
||||||
|
def getModelFormat(): String = {
|
||||||
|
optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@ -16,16 +16,18 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
import java.io.File
|
import java.io.{File, FileInputStream}
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||||
|
|
||||||
import org.apache.spark.SparkContext
|
import org.apache.spark.SparkContext
|
||||||
import org.apache.spark.sql._
|
import org.apache.spark.sql._
|
||||||
import org.scalatest.{BeforeAndAfterEach, FunSuite}
|
import org.scalatest.{BeforeAndAfterEach, FunSuite}
|
||||||
|
|
||||||
import scala.math.min
|
import scala.math.min
|
||||||
import scala.util.Random
|
import scala.util.Random
|
||||||
|
|
||||||
|
import org.apache.commons.io.IOUtils
|
||||||
|
|
||||||
trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
|
trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
|
||||||
|
|
||||||
protected val numWorkers: Int = min(Runtime.getRuntime.availableProcessors(), 4)
|
protected val numWorkers: Int = min(Runtime.getRuntime.availableProcessors(), 4)
|
||||||
@ -105,4 +107,22 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
|
|||||||
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
|
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
|
||||||
.toDF("id", "label", "features", "group")
|
.toDF("id", "label", "features", "group")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
protected def compareTwoFiles(lhs: String, rhs: String): Boolean = {
|
||||||
|
withResource(new FileInputStream(lhs)) { lfis =>
|
||||||
|
withResource(new FileInputStream(rhs)) { rfis =>
|
||||||
|
IOUtils.contentEquals(lfis, rfis)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Executes the provided code block and then closes the resource */
|
||||||
|
protected def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
|
||||||
|
try {
|
||||||
|
block(r)
|
||||||
|
} finally {
|
||||||
|
r.close()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -429,30 +429,29 @@ class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuit
|
|||||||
val trainingDF = buildDataFrame(MultiClassification.train)
|
val trainingDF = buildDataFrame(MultiClassification.train)
|
||||||
val xgb = new XGBoostClassifier(paramMap)
|
val xgb = new XGBoostClassifier(paramMap)
|
||||||
val model = xgb.fit(trainingDF)
|
val model = xgb.fit(trainingDF)
|
||||||
|
|
||||||
val modelPath = new File(tempDir.toFile, "xgbc").getPath
|
val modelPath = new File(tempDir.toFile, "xgbc").getPath
|
||||||
model.write.overwrite().save(modelPath)
|
model.write.option("format", "json").save(modelPath)
|
||||||
val nativeModelPath = new File(tempDir.toFile, "nativeModel").getPath
|
val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
|
||||||
model.nativeBooster.saveModel(nativeModelPath)
|
model.nativeBooster.saveModel(nativeJsonModelPath)
|
||||||
|
|
||||||
assert(compareTwoFiles(new File(modelPath, "data/XGBoostClassificationModel").getPath,
|
assert(compareTwoFiles(new File(modelPath, "data/XGBoostClassificationModel").getPath,
|
||||||
nativeModelPath))
|
nativeJsonModelPath))
|
||||||
}
|
|
||||||
|
|
||||||
private def compareTwoFiles(lhs: String, rhs: String): Boolean = {
|
// test the default "deprecated" format
|
||||||
withResource(new FileInputStream(lhs)) { lfis =>
|
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
|
||||||
withResource(new FileInputStream(rhs)) { rfis =>
|
model.write.save(modelUbjPath)
|
||||||
IOUtils.contentEquals(lfis, rfis)
|
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
|
||||||
}
|
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
|
||||||
}
|
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostClassificationModel").getPath,
|
||||||
}
|
nativeDeprecatedModelPath))
|
||||||
|
|
||||||
/** Executes the provided code block and then closes the resource */
|
// the json file should be different from the ubj file
|
||||||
private def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
|
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
|
||||||
try {
|
model.write.option("format", "json").save(modelJsonPath)
|
||||||
block(r)
|
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
|
||||||
} finally {
|
model.nativeBooster.saveModel(nativeUbjModelPath)
|
||||||
r.close()
|
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
|
||||||
}
|
nativeUbjModelPath))
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -16,6 +16,8 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
|
import java.io.File
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||||
|
|
||||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||||
@ -25,7 +27,7 @@ import org.scalatest.FunSuite
|
|||||||
|
|
||||||
import org.apache.spark.ml.feature.VectorAssembler
|
import org.apache.spark.ml.feature.VectorAssembler
|
||||||
|
|
||||||
class XGBoostRegressorSuite extends FunSuite with PerTest {
|
class XGBoostRegressorSuite extends FunSuite with PerTest with TmpFolderPerSuite {
|
||||||
protected val treeMethod: String = "auto"
|
protected val treeMethod: String = "auto"
|
||||||
|
|
||||||
test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
|
test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
|
||||||
@ -310,4 +312,42 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
|||||||
val df1 = model.transform(vectorizedInput)
|
val df1 = model.transform(vectorizedInput)
|
||||||
df1.show()
|
df1.show()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("XGBoostRegressionModel should be compatible") {
|
||||||
|
val trainingDF = buildDataFrame(Regression.train)
|
||||||
|
val paramMap = Map(
|
||||||
|
"eta" -> "1",
|
||||||
|
"max_depth" -> "6",
|
||||||
|
"silent" -> "1",
|
||||||
|
"objective" -> "reg:squarederror",
|
||||||
|
"num_round" -> 5,
|
||||||
|
"tree_method" -> treeMethod,
|
||||||
|
"num_workers" -> numWorkers)
|
||||||
|
|
||||||
|
val model = new XGBoostRegressor(paramMap).fit(trainingDF)
|
||||||
|
|
||||||
|
val modelPath = new File(tempDir.toFile, "xgbc").getPath
|
||||||
|
model.write.option("format", "json").save(modelPath)
|
||||||
|
val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
|
||||||
|
model.nativeBooster.saveModel(nativeJsonModelPath)
|
||||||
|
assert(compareTwoFiles(new File(modelPath, "data/XGBoostRegressionModel").getPath,
|
||||||
|
nativeJsonModelPath))
|
||||||
|
|
||||||
|
// test the default "deprecated" format
|
||||||
|
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
|
||||||
|
model.write.save(modelUbjPath)
|
||||||
|
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
|
||||||
|
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
|
||||||
|
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostRegressionModel").getPath,
|
||||||
|
nativeDeprecatedModelPath))
|
||||||
|
|
||||||
|
// the json file should be different from the ubj file
|
||||||
|
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
|
||||||
|
model.write.option("format", "json").save(modelJsonPath)
|
||||||
|
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
|
||||||
|
model.nativeBooster.saveModel(nativeUbjModelPath)
|
||||||
|
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostRegressionModel").getPath,
|
||||||
|
nativeUbjModelPath))
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -34,6 +34,7 @@ import org.apache.commons.logging.LogFactory;
|
|||||||
* Booster for xgboost, this is a model API that support interactive build of a XGBoost Model
|
* Booster for xgboost, this is a model API that support interactive build of a XGBoost Model
|
||||||
*/
|
*/
|
||||||
public class Booster implements Serializable, KryoSerializable {
|
public class Booster implements Serializable, KryoSerializable {
|
||||||
|
public static final String DEFAULT_FORMAT = "deprecated";
|
||||||
private static final Log logger = LogFactory.getLog(Booster.class);
|
private static final Log logger = LogFactory.getLog(Booster.class);
|
||||||
// handle to the booster.
|
// handle to the booster.
|
||||||
private long handle = 0;
|
private long handle = 0;
|
||||||
@ -391,7 +392,22 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
* @param out The output stream
|
* @param out The output stream
|
||||||
*/
|
*/
|
||||||
public void saveModel(OutputStream out) throws XGBoostError, IOException {
|
public void saveModel(OutputStream out) throws XGBoostError, IOException {
|
||||||
out.write(this.toByteArray());
|
saveModel(out, DEFAULT_FORMAT);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Save the model to file opened as output stream.
|
||||||
|
* The model format is compatible with other xgboost bindings.
|
||||||
|
* The output stream can only save one xgboost model.
|
||||||
|
* This function will close the OutputStream after the save.
|
||||||
|
*
|
||||||
|
* @param out The output stream
|
||||||
|
* @param format The model format (ubj, json, deprecated)
|
||||||
|
* @throws XGBoostError
|
||||||
|
* @throws IOException
|
||||||
|
*/
|
||||||
|
public void saveModel(OutputStream out, String format) throws XGBoostError, IOException {
|
||||||
|
out.write(this.toByteArray(format));
|
||||||
out.close();
|
out.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -643,7 +659,7 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
public byte[] toByteArray() throws XGBoostError {
|
public byte[] toByteArray() throws XGBoostError {
|
||||||
return this.toByteArray("deprecated");
|
return this.toByteArray(DEFAULT_FORMAT);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -207,6 +207,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
|||||||
def saveModel(modelPath: String): Unit = {
|
def saveModel(modelPath: String): Unit = {
|
||||||
booster.saveModel(modelPath)
|
booster.saveModel(modelPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* save model to Output stream
|
* save model to Output stream
|
||||||
*
|
*
|
||||||
@ -216,6 +217,18 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
|||||||
def saveModel(out: java.io.OutputStream): Unit = {
|
def saveModel(out: java.io.OutputStream): Unit = {
|
||||||
booster.saveModel(out)
|
booster.saveModel(out)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* save model to Output stream
|
||||||
|
* @param out output stream
|
||||||
|
* @param format the supported model format, (json, ubj, deprecated)
|
||||||
|
* @throws ml.dmlc.xgboost4j.java.XGBoostError
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def saveModel(out: java.io.OutputStream, format: String): Unit = {
|
||||||
|
booster.saveModel(out, format)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Dump model as Array of string
|
* Dump model as Array of string
|
||||||
*
|
*
|
||||||
@ -315,7 +328,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
|||||||
*/
|
*/
|
||||||
@throws(classOf[XGBoostError])
|
@throws(classOf[XGBoostError])
|
||||||
def toByteArray: Array[Byte] = {
|
def toByteArray: Array[Byte] = {
|
||||||
booster.toByteArray
|
booster.toByteArray()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user