[jvm-packages] add format option when saving a model (#7940)

This commit is contained in:
Bobby Wang
2022-05-30 15:49:59 +08:00
committed by GitHub
parent cc6d57aa0d
commit 6275cdc486
8 changed files with 153 additions and 30 deletions

View File

@@ -30,6 +30,8 @@ import org.apache.spark.sql.functions._
import org.json4s.DefaultFormats
import scala.collection.{Iterator, mutable}
import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter
import org.apache.spark.sql.types.StructType
class XGBoostClassifier (
@@ -462,7 +464,8 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
override def load(path: String): XGBoostClassificationModel = super.load(path)
private[XGBoostClassificationModel]
class XGBoostClassificationModelWriter(instance: XGBoostClassificationModel) extends MLWriter {
class XGBoostClassificationModelWriter(instance: XGBoostClassificationModel)
extends XGBoostWriter {
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
@@ -474,7 +477,7 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
instance._booster.saveModel(outputStream)
instance._booster.saveModel(outputStream, getModelFormat())
outputStream.close()
}
}

View File

@@ -19,6 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
import scala.collection.{Iterator, mutable}
import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import org.apache.hadoop.fs.Path
@@ -379,7 +380,7 @@ object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
override def load(path: String): XGBoostRegressionModel = super.load(path)
private[XGBoostRegressionModel]
class XGBoostRegressionModelWriter(instance: XGBoostRegressionModel) extends MLWriter {
class XGBoostRegressionModelWriter(instance: XGBoostRegressionModel) extends XGBoostWriter {
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
@@ -390,7 +391,7 @@ object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostRegressionModel")
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
instance._booster.saveModel(outputStream)
instance._booster.saveModel(outputStream, getModelFormat())
outputStream.close()
}
}

View File

@@ -0,0 +1,31 @@
/*
Copyright (c) 2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.utils
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
import org.apache.spark.ml.util.MLWriter
private[spark] abstract class XGBoostWriter extends MLWriter {
/** Currently it's using the "deprecated" format as
* default, which will be changed into `ubj` in future releases. */
def getModelFormat(): String = {
optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
}
}