[jvm-packages] refactor xgboost read/write (#7956)

1. Removed the duplicated DefaultXGBoostParams reader/writer implementation that was
   copied from Spark 2.3.x.
2. Moved some utility classes into the `util` package.
Bobby Wang 2022-06-01 11:38:49 +08:00 committed by GitHub
parent 27c66f12d1
commit 545fd4548e
16 changed files with 180 additions and 389 deletions
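
The read/write path this refactor touches is the one exercised when a fitted model is persisted
and reloaded through the standard Spark ML interfaces. A minimal sketch of that flow (the
parameter values and path are illustrative only, not taken from this PR):

import org.apache.spark.sql.DataFrame
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier}

object XGBoostPersistExample {
  // Train on a DataFrame with "features" (Vector) and "label" columns, save the model,
  // and read it back through the writer/reader code changed below.
  def saveAndReload(trainingDF: DataFrame, path: String): XGBoostClassificationModel = {
    val classifier = new XGBoostClassifier(Map(
      "objective" -> "binary:logistic",
      "num_round" -> 10,
      "num_workers" -> 2))
    val model = classifier.fit(trainingDF)
    model.write.overwrite().save(path)      // MLWritable side
    XGBoostClassificationModel.load(path)   // MLReadable side
  }
}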


@@ -24,8 +24,9 @@ import scala.collection.{AbstractIterator, Iterator, mutable}
import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -35,10 +36,8 @@ import org.apache.commons.logging.LogFactory
import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.{Estimator, Model, PipelineStage}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.linalg.xgboost.XGBoostSchemaUtils
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
import org.apache.spark.storage.StorageLevel
@@ -272,7 +271,7 @@ object PreXGBoost extends PreXGBoostProvider {
val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
import DataUtils._
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val cacheInfo = {
if (useExternalMemory) {
s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +


@@ -23,15 +23,13 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.classification._
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.json4s.DefaultFormats
import scala.collection.{Iterator, mutable}
import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultXGBoostParamsReader, DefaultXGBoostParamsWriter, XGBoostWriter}
import org.apache.spark.sql.types.StructType
class XGBoostClassifier (
@@ -274,7 +272,7 @@ class XGBoostClassificationModel private[ml](
* Note: The performance is not ideal, use it carefully!
*/
override def predict(features: Vector): Double = {
import DataUtils._
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val dm = new DMatrix(processMissingValues(
Iterator(features.asXGB),
$(missing),
@@ -469,10 +467,8 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
implicit val format = DefaultFormats
implicit val sc = super.sparkSession.sparkContext
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
// Save model data
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
@@ -495,18 +491,7 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)
// The xgboostVersion in the meta can specify if the model is the old xgboost in-compatible
// or the new xgboost compatible.
val numClasses = metadata.xgboostVersion.map { _ =>
implicit val format = DefaultFormats
// For binary:logistic, the numClass parameter can't be set to 2 or not be set.
// For multi:softprob or multi:softmax, the numClass parameter must be set correctly,
// or else, XGBoost will throw exception.
// So it's safe to get numClass from meta data.
(metadata.params \ "numClass").extractOpt[Int].getOrElse(2)
}.getOrElse(dataInStream.readInt())
val numClasses = DefaultXGBoostParamsReader.getNumClass(metadata, dataInStream)
val booster = SXGBoost.loadModel(dataInStream)
val model = new XGBoostClassificationModel(metadata.uid, numClasses, booster)
DefaultXGBoostParamsReader.getAndSetParams(model, metadata)
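
Reassembled from the truncated hunks above, the classification model's load path now looks
roughly as follows. This is a condensed sketch rather than the verbatim file contents;
`className` and `sparkSession` are assumed to be the usual fields of the surrounding MLReader,
and the imports are the ones already shown in the file (org.apache.hadoop.fs.Path, SXGBoost, etc.):

override def load(path: String): XGBoostClassificationModel = {
  val sc = sparkSession.sparkContext
  val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)
  val internalPath = new Path(new Path(path, "data").toString, "XGBoostClassificationModel")
  val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)
  // numClass is read from the metadata for new models and from the data stream for old ones.
  val numClasses = DefaultXGBoostParamsReader.getNumClass(metadata, dataInStream)
  val booster = SXGBoost.loadModel(dataInStream)
  val model = new XGBoostClassificationModel(metadata.uid, numClasses, booster)
  DefaultXGBoostParamsReader.getAndSetParams(model, metadata)
  model
}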


@@ -18,8 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
import scala.collection.{Iterator, mutable}
import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter
import ml.dmlc.xgboost4j.scala.spark.params._
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import org.apache.hadoop.fs.Path
@@ -30,9 +29,9 @@ import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.json4s.DefaultFormats
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.util.{DefaultXGBoostParamsReader, DefaultXGBoostParamsWriter, XGBoostWriter}
import org.apache.spark.sql.types.StructType
class XGBoostRegressor (
@@ -260,7 +259,7 @@ class XGBoostRegressionModel private[ml] (
* Note: The performance is not ideal, use it carefully!
*/
override def predict(features: Vector): Double = {
import DataUtils._
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val dm = new DMatrix(processMissingValues(
Iterator(features.asXGB),
$(missing),
@@ -384,8 +383,6 @@ object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
implicit val format = DefaultFormats
implicit val sc = super.sparkSession.sparkContext
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
// Save model data
val dataPath = new Path(path, "data").toString


@@ -1,5 +1,5 @@
/*
Copyright (c) 2014,2021 by Contributors
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@ package ml.dmlc.xgboost4j.scala.spark.params
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
import ml.dmlc.xgboost4j.scala.spark.util.Utils
import org.apache.spark.ml.param.{Param, ParamPair, Params}
import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
import org.json4s.jackson.JsonMethods.{compact, parse, render}


@@ -1,161 +0,0 @@
/*
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
import ml.dmlc.xgboost4j.scala.spark
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.Path
import org.json4s.{DefaultFormats, JValue}
import org.json4s.JsonAST.JObject
import org.json4s.jackson.JsonMethods.{compact, parse, render}
import org.apache.spark.SparkContext
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.MLReader
// This originates from a copy-paste of apache-spark's DefaultParamsReader
private[spark] object DefaultXGBoostParamsReader {
private val logger = LogFactory.getLog("XGBoostSpark")
private val paramNameCompatibilityMap: Map[String, String] = Map("silent" -> "verbosity")
private val paramValueCompatibilityMap: Map[String, Map[Any, Any]] =
Map("objective" -> Map("reg:linear" -> "reg:squarederror"))
/**
* All info from metadata file.
*
* @param params paramMap, as a `JValue`
* @param metadata All metadata, including the other fields
* @param metadataJson Full metadata file String (for debugging)
*/
case class Metadata(
className: String,
uid: String,
timestamp: Long,
sparkVersion: String,
params: JValue,
metadata: JValue,
metadataJson: String,
xgboostVersion: Option[String] = None) {
/**
* Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name.
* This can be useful for getting a Param value before an instance of `Params`
* is available.
*/
def getParamValue(paramName: String): JValue = {
implicit val format = DefaultFormats
params match {
case JObject(pairs) =>
val values = pairs.filter { case (pName, jsonValue) =>
pName == paramName
}.map(_._2)
assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" +
s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", "))
values.head
case _ =>
throw new IllegalArgumentException(
s"Cannot recognize JSON metadata: $metadataJson.")
}
}
}
/**
* Load metadata saved using [[DefaultXGBoostParamsWriter.saveMetadata()]]
*
* @param expectedClassName If non empty, this is checked against the loaded metadata.
* @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
*/
def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
val metadataPath = new Path(path, "metadata").toString
val metadataStr = sc.textFile(metadataPath, 1).first()
parseMetadata(metadataStr, expectedClassName)
}
/**
* Parse metadata JSON string produced by [[DefaultXGBoostParamsWriter.getMetadataToSave()]].
* This is a helper function for [[loadMetadata()]].
*
* @param metadataStr JSON string of metadata
* @param expectedClassName If non empty, this is checked against the loaded metadata.
* @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
*/
def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = {
val metadata = parse(metadataStr)
implicit val format = DefaultFormats
val className = (metadata \ "class").extract[String]
val uid = (metadata \ "uid").extract[String]
val timestamp = (metadata \ "timestamp").extract[Long]
val sparkVersion = (metadata \ "sparkVersion").extract[String]
val params = metadata \ "paramMap"
if (expectedClassName.nonEmpty) {
require(className == expectedClassName, s"Error loading metadata: Expected class name" +
s" $expectedClassName but found class name $className")
}
val xgboostVersion = (metadata \ "xgboostVersion").extractOpt[String]
Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr, xgboostVersion)
}
private def handleBrokenlyChangedValue[T](paramName: String, value: T): T = {
paramValueCompatibilityMap.getOrElse(paramName, Map()).getOrElse(value, value).asInstanceOf[T]
}
private def handleBrokenlyChangedName(paramName: String): String = {
paramNameCompatibilityMap.getOrElse(paramName, paramName)
}
/**
* Extract Params from metadata, and set them in the instance.
* This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
* TODO: Move to [[Metadata]] method
*/
def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
implicit val format = DefaultFormats
metadata.params match {
case JObject(pairs) =>
pairs.foreach { case (paramName, jsonValue) =>
val finalName = handleBrokenlyChangedName(paramName)
// For the deleted parameters, we'd better to remove it instead of throwing an exception.
// So we need to check if the parameter exists instead of blindly setting it.
if (instance.hasParam(finalName)) {
val param = instance.getParam(finalName)
val value = param.jsonDecode(compact(render(jsonValue)))
instance.set(param, handleBrokenlyChangedValue(paramName, value))
} else {
logger.warn(s"$finalName is no longer used in ${spark.VERSION}")
}
}
case _ =>
throw new IllegalArgumentException(
s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
}
}
/**
* Load a `Params` instance from the given path, and return it.
* This assumes the instance implements [[org.apache.spark.ml.util.MLReadable]].
*/
def loadParamsInstance[T](path: String, sc: SparkContext): T = {
val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc)
val cls = Utils.classForName(metadata.className)
cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
}
}
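
For orientation, the metadata file that loadMetadata/parseMetadata above consume is a single
JSON line. A hedged illustration of its shape, with placeholder values, wrapped across lines
for readability:

// Placeholder values only; a real metadata file is written by saveMetadata as one line.
val exampleMetadataJson: String =
  """{"class":"ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel",
    |"timestamp":1654051129000,"sparkVersion":"3.0.1","uid":"xgbc_a1b2c3d4e5f6",
    |"xgboostVersion":"1.6.1",
    |"paramMap":{"objective":"multi:softprob","numClass":3,"numRound":5}}""".stripMargin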


@@ -1,151 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
import org.apache.spark.ml.param.{ParamPair, Params}
import org.json4s.jackson.JsonMethods._
import org.json4s.{JArray, JBool, JDouble, JField, JInt, JNothing, JObject, JString, JValue}
import JsonDSLXGBoost._
import ml.dmlc.xgboost4j.scala.spark
// This originates from a copy-paste of apache-spark's DefaultParamsWriter
private[spark] object DefaultXGBoostParamsWriter {
/**
* Saves metadata + Params to: path + "/metadata"
* - class
* - timestamp
* - sparkVersion
* - uid
* - paramMap
* - (optionally, extra metadata)
*
* @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
* @param paramMap If given, this is saved in the "paramMap" field.
* Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
* [[org.apache.spark.ml.param.Param.jsonEncode()]].
*/
def saveMetadata(
instance: Params,
path: String,
sc: SparkContext,
extraMetadata: Option[JObject] = None,
paramMap: Option[JValue] = None): Unit = {
val metadataPath = new Path(path, "metadata").toString
val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
}
/**
* Helper for [[saveMetadata()]] which extracts the JSON to save.
* This is useful for ensemble models which need to save metadata for many sub-models.
*
* @see [[saveMetadata()]] for details on what this includes.
*/
def getMetadataToSave(
instance: Params,
sc: SparkContext,
extraMetadata: Option[JObject] = None,
paramMap: Option[JValue] = None): String = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
val jsonParams = paramMap.getOrElse(render(params.filter{
case ParamPair(p, _) => p != null
}.map {
case ParamPair(p, v) =>
p.name -> parse(p.jsonEncode(v))
}.toList))
val basicMetadata = ("class" -> cls) ~
("timestamp" -> System.currentTimeMillis()) ~
("sparkVersion" -> sc.version) ~
("uid" -> uid) ~
("xgboostVersion" -> spark.VERSION) ~
("paramMap" -> jsonParams)
val metadata = extraMetadata match {
case Some(jObject) =>
basicMetadata ~ jObject
case None =>
basicMetadata
}
val metadataJson: String = compact(render(metadata))
metadataJson
}
}
// Fix json4s bin-incompatible issue.
// This originates from org.json4s.JsonDSL of 3.6.6
object JsonDSLXGBoost {
implicit def seq2jvalue[A](s: Iterable[A])(implicit ev: A => JValue): JArray =
JArray(s.toList.map(ev))
implicit def map2jvalue[A](m: Map[String, A])(implicit ev: A => JValue): JObject =
JObject(m.toList.map { case (k, v) => JField(k, ev(v)) })
implicit def option2jvalue[A](opt: Option[A])(implicit ev: A => JValue): JValue = opt match {
case Some(x) => ev(x)
case None => JNothing
}
implicit def short2jvalue(x: Short): JValue = JInt(x)
implicit def byte2jvalue(x: Byte): JValue = JInt(x)
implicit def char2jvalue(x: Char): JValue = JInt(x)
implicit def int2jvalue(x: Int): JValue = JInt(x)
implicit def long2jvalue(x: Long): JValue = JInt(x)
implicit def bigint2jvalue(x: BigInt): JValue = JInt(x)
implicit def double2jvalue(x: Double): JValue = JDouble(x)
implicit def float2jvalue(x: Float): JValue = JDouble(x.toDouble)
implicit def bigdecimal2jvalue(x: BigDecimal): JValue = JDouble(x.doubleValue)
implicit def boolean2jvalue(x: Boolean): JValue = JBool(x)
implicit def string2jvalue(x: String): JValue = JString(x)
implicit def symbol2jvalue(x: Symbol): JString = JString(x.name)
implicit def pair2jvalue[A](t: (String, A))(implicit ev: A => JValue): JObject =
JObject(List(JField(t._1, ev(t._2))))
implicit def list2jvalue(l: List[JField]): JObject = JObject(l)
implicit def jobject2assoc(o: JObject): JsonListAssoc = new JsonListAssoc(o.obj)
implicit def pair2Assoc[A](t: (String, A))(implicit ev: A => JValue): JsonAssoc[A] =
new JsonAssoc(t)
}
final class JsonAssoc[A](private val left: (String, A)) extends AnyVal {
def ~[B](right: (String, B))(implicit ev1: A => JValue, ev2: B => JValue): JObject = {
val l: JValue = ev1(left._2)
val r: JValue = ev2(right._2)
JObject(JField(left._1, l) :: JField(right._1, r) :: Nil)
}
def ~(right: JObject)(implicit ev: A => JValue): JObject = {
val l: JValue = ev(left._2)
JObject(JField(left._1, l) :: right.obj)
}
def ~~[B](right: (String, B))(implicit ev1: A => JValue, ev2: B => JValue): JObject =
this.~(right)
def ~~(right: JObject)(implicit ev: A => JValue): JObject = this.~(right)
}
final class JsonListAssoc(private val left: List[JField]) extends AnyVal {
def ~(right: (String, JValue)): JObject = JObject(left ::: List(JField(right._1, right._2)))
def ~(right: JObject): JObject = JObject(left ::: right.obj)
def ~~(right: (String, JValue)): JObject = this.~(right)
def ~~(right: JObject): JObject = this.~(right)
}
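
The JsonDSLXGBoost shim above exists only to pin down json4s's `~` DSL across binary-incompatible
json4s versions. A minimal sketch of the DSL usage it preserves, written here against the
standard org.json4s.JsonDSL (which provides the same implicits):

import org.json4s.JsonAST.JObject
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

// Chaining `~` on (key -> value) pairs builds up a JObject, as getMetadataToSave does above.
val meta: JObject =
  ("class" -> "ml.dmlc.xgboost4j.scala.spark.XGBoostRegressionModel") ~
    ("timestamp" -> System.currentTimeMillis()) ~
    ("xgboostVersion" -> "1.6.1")

// e.g. {"class":"...XGBoostRegressionModel","timestamp":1654051129000,"xgboostVersion":"1.6.1"}
println(compact(render(meta)))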


@@ -17,9 +17,9 @@
package ml.dmlc.xgboost4j.scala.spark.params
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.xgboost.XGBoostSchemaUtils
import org.apache.spark.ml.param.{Param, ParamValidators}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol, HasWeightCol}
import org.apache.spark.ml.util.XGBoostSchemaUtils
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.StructType


@@ -1,5 +1,5 @@
/*
Copyright (c) 2014,2021 by Contributors
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
package ml.dmlc.xgboost4j.scala.spark.util
import scala.collection.mutable
@@ -24,8 +24,8 @@ import org.apache.spark.HashPartitioner
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.types.{FloatType, IntegerType}
import org.apache.spark.sql.{Column, DataFrame, Row}
object DataUtils extends Serializable {
private[spark] implicit class XGBLabeledPointFeatures(


@@ -1,5 +1,5 @@
/*
Copyright (c) 2014,2021 by Contributors
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
package ml.dmlc.xgboost4j.scala.spark.util
import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints}


@@ -1,31 +0,0 @@
/*
Copyright (c) 2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.utils
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
import org.apache.spark.ml.util.MLWriter
private[spark] abstract class XGBoostWriter extends MLWriter {
/** Currently it's using the "deprecated" format as
* default, which will be changed into `ubj` in future releases. */
def getModelFormat(): String = {
optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
}
}
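
getModelFormat only consults the writer's option map, so the on-disk booster format can be
selected at save time through the standard MLWriter.option call. A small sketch (the helper
name and path handling are illustrative, not part of this PR):

import org.apache.spark.ml.util.MLWritable

// Persist a fitted XGBoost model using the UBJSON format instead of the current
// default ("deprecated"); `model` is assumed to be e.g. an XGBoostClassificationModel.
def saveAsUbj(model: MLWritable, path: String): Unit =
  model.write.option("format", "ubj").overwrite().save(path)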


@@ -0,0 +1,150 @@
/*
Copyright (c) 2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.apache.spark.ml.util
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
import ml.dmlc.xgboost4j.scala.spark
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.FSDataInputStream
import org.json4s.DefaultFormats
import org.json4s.JsonAST.JObject
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}
import org.apache.spark.SparkContext
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
abstract class XGBoostWriter extends MLWriter {
/** Currently it's using the "deprecated" format as
* default, which will be changed into `ubj` in future releases. */
def getModelFormat(): String = {
optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
}
}
object DefaultXGBoostParamsWriter {
val XGBOOST_VERSION_TAG = "xgboostVersion"
/**
* Saves metadata + Params to: path + "/metadata" using [[DefaultParamsWriter.saveMetadata]]
*/
def saveMetadata(
instance: Params,
path: String,
sc: SparkContext): Unit = {
// save xgboost version to distinguish the old model.
val extraMetadata: JObject = Map(XGBOOST_VERSION_TAG -> ml.dmlc.xgboost4j.scala.spark.VERSION)
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
}
}
object DefaultXGBoostParamsReader {
private val logger = LogFactory.getLog("XGBoostSpark")
/**
* Load metadata saved using [[DefaultXGBoostParamsWriter.saveMetadata()]]
*
* @param expectedClassName If non empty, this is checked against the loaded metadata.
* @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
*/
def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
}
/**
* Extract Params from metadata, and set them in the instance.
* This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
*
* Parameters that are not defined on the instance are skipped automatically.
*
* This API is mostly copied from Spark's DefaultParamsReader.
*/
def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
// XGBoost has never restored default parameter values, because its save/load code was
// copied from Spark 2.3.x: on load, the defaults of the running XGBoost version were used
// instead of the defaults stored with the model.
// For compatibility we still do not restore the default parameters here.
// setParams(instance, metadata, isDefault = true)
setParams(instance, metadata, isDefault = false)
}
/** This API is only for XGBoostClassificationModel */
def getNumClass(metadata: Metadata, dataInStream: FSDataInputStream): Int = {
implicit val format = DefaultFormats
// The presence of xgboostVersion in the metadata tells whether the model was saved by an
// old, incompatible XGBoost version or by a new, compatible one.
val xgbVerOpt = (metadata.metadata \ DefaultXGBoostParamsWriter.XGBOOST_VERSION_TAG)
.extractOpt[String]
// For binary:logistic, the numClass parameter cannot be set (not even to 2), so it is
// absent from the metadata and defaults to 2 here.
// For multi:softprob or multi:softmax, the numClass parameter must be set correctly,
// otherwise XGBoost throws an exception.
// So it is safe to get numClass from the metadata.
xgbVerOpt
.map { _ => (metadata.params \ "numClass").extractOpt[Int].getOrElse(2) }
.getOrElse(dataInStream.readInt())
}
private def setParams(
instance: Params,
metadata: Metadata,
isDefault: Boolean): Unit = {
val paramsToSet = if (isDefault) metadata.defaultParams else metadata.params
paramsToSet match {
case JObject(pairs) =>
pairs.foreach { case (paramName, jsonValue) =>
val finalName = handleBrokenlyChangedName(paramName)
// Parameters that have since been removed should be skipped rather than cause an exception,
// so check that the parameter still exists instead of setting it blindly.
if (instance.hasParam(finalName)) {
val param = instance.getParam(finalName)
val value = param.jsonDecode(compact(render(jsonValue)))
instance.set(param, handleBrokenlyChangedValue(paramName, value))
} else {
logger.warn(s"$finalName is no longer used in ${spark.VERSION}")
}
}
case _ =>
throw new IllegalArgumentException(
s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
}
}
private val paramNameCompatibilityMap: Map[String, String] = Map("silent" -> "verbosity")
/** Performing this value transformation is not ideal, but it is needed because some tests
* rely on a model saved with 0.82 whose objective is "reg:linear". */
private val paramValueCompatibilityMap: Map[String, Map[Any, Any]] =
Map("objective" -> Map("reg:linear" -> "reg:squarederror"))
private def handleBrokenlyChangedName(paramName: String): String = {
paramNameCompatibilityMap.getOrElse(paramName, paramName)
}
private def handleBrokenlyChangedValue[T](paramName: String, value: T): T = {
paramValueCompatibilityMap.getOrElse(paramName, Map()).getOrElse(value, value).asInstanceOf[T]
}
}
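
The two compatibility maps at the bottom are consulted while params are restored: a renamed
parameter is set under its new name, and a remapped value is substituted before `instance.set`
is called. A standalone sketch of that lookup logic, mirroring (not calling) the private
helpers above:

val nameCompat: Map[String, String] = Map("silent" -> "verbosity")
val valueCompat: Map[String, Map[Any, Any]] =
  Map("objective" -> Map("reg:linear" -> "reg:squarederror"))

def compatName(name: String): String = nameCompat.getOrElse(name, name)
def compatValue[T](name: String, value: T): T =
  valueCompat.getOrElse(name, Map.empty[Any, Any]).getOrElse(value, value).asInstanceOf[T]

assert(compatName("silent") == "verbosity")                           // renamed parameter
assert(compatValue("objective", "reg:linear") == "reg:squarederror")  // remapped value
assert(compatValue("eta", 0.3) == 0.3)                                // everything else passes through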


@@ -14,7 +14,7 @@
limitations under the License.
*/
package org.apache.spark.ml.linalg.xgboost
package org.apache.spark.ml.util
import org.apache.spark.sql.types.{BooleanType, DataType, NumericType, StructType}
import org.apache.spark.ml.linalg.VectorUDT


@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -18,7 +18,8 @@ package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.ml.linalg.Vectors
import org.scalatest.FunSuite
import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
import org.apache.spark.sql.functions._


@@ -50,7 +50,7 @@ class FeatureSizeValidatingSuite extends FunSuite with PerTest {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic",
"num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0)
import DataUtils._
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val sparkSession = ss
import sparkSession.implicits._
val repartitioned = sc.parallelize(Synthetic.trainWithDiffFeatureSize, 2)


@@ -73,7 +73,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
protected def buildDataFrame(
labeledPoints: Seq[XGBLabeledPoint],
numPartitions: Int = numWorkers): DataFrame = {
import DataUtils._
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val it = labeledPoints.iterator.zipWithIndex
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
(id, labeledPoint.label, labeledPoint.features)
@@ -98,7 +98,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
protected def buildDataFrameWithGroup(
labeledPoints: Seq[XGBLabeledPoint],
numPartitions: Int = numWorkers): DataFrame = {
import DataUtils._
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val it = labeledPoints.iterator.zipWithIndex
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
(id, labeledPoint.label, labeledPoint.features, labeledPoint.group)


@@ -310,7 +310,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuit
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic",
"num_round" -> 5, "num_workers" -> 2, "missing" -> 0)
import DataUtils._
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val sparkSession = SparkSession.builder().getOrCreate()
import sparkSession.implicits._
val repartitioned = sc.parallelize(Synthetic.train, 3).map(lp => (lp.label, lp)).partitionBy(
@@ -331,7 +331,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuit
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic",
"num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0)
import DataUtils._
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val sparkSession = SparkSession.builder().getOrCreate()
import sparkSession.implicits._
val repartitioned = sc.parallelize(Synthetic.train, 3).map(lp => (lp.label, lp)).partitionBy(