[jvm-packages] add format option when saving a model (#7940)
This commit is contained in:
@@ -30,6 +30,8 @@ import org.apache.spark.sql.functions._
|
||||
import org.json4s.DefaultFormats
|
||||
import scala.collection.{Iterator, mutable}
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter
|
||||
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
class XGBoostClassifier (
|
||||
@@ -462,7 +464,8 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
|
||||
override def load(path: String): XGBoostClassificationModel = super.load(path)
|
||||
|
||||
private[XGBoostClassificationModel]
|
||||
class XGBoostClassificationModelWriter(instance: XGBoostClassificationModel) extends MLWriter {
|
||||
class XGBoostClassificationModelWriter(instance: XGBoostClassificationModel)
|
||||
extends XGBoostWriter {
|
||||
|
||||
override protected def saveImpl(path: String): Unit = {
|
||||
// Save metadata and Params
|
||||
@@ -474,7 +477,7 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
|
||||
val dataPath = new Path(path, "data").toString
|
||||
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
|
||||
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
|
||||
instance._booster.saveModel(outputStream)
|
||||
instance._booster.saveModel(outputStream, getModelFormat())
|
||||
outputStream.close()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
import scala.collection.{Iterator, mutable}
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
|
||||
import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
||||
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
||||
import org.apache.hadoop.fs.Path
|
||||
@@ -379,7 +380,7 @@ object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
|
||||
override def load(path: String): XGBoostRegressionModel = super.load(path)
|
||||
|
||||
private[XGBoostRegressionModel]
|
||||
class XGBoostRegressionModelWriter(instance: XGBoostRegressionModel) extends MLWriter {
|
||||
class XGBoostRegressionModelWriter(instance: XGBoostRegressionModel) extends XGBoostWriter {
|
||||
|
||||
override protected def saveImpl(path: String): Unit = {
|
||||
// Save metadata and Params
|
||||
@@ -390,7 +391,7 @@ object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
|
||||
val dataPath = new Path(path, "data").toString
|
||||
val internalPath = new Path(dataPath, "XGBoostRegressionModel")
|
||||
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
|
||||
instance._booster.saveModel(outputStream)
|
||||
instance._booster.saveModel(outputStream, getModelFormat())
|
||||
outputStream.close()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
/*
|
||||
Copyright (c) 2022 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark.utils
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
|
||||
|
||||
import org.apache.spark.ml.util.MLWriter
|
||||
|
||||
private[spark] abstract class XGBoostWriter extends MLWriter {
|
||||
|
||||
/** Currently it's using the "deprecated" format as
|
||||
* default, which will be changed into `ubj` in future releases. */
|
||||
def getModelFormat(): String = {
|
||||
optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user