[jvm-packages] refactor xgboost read/write (#7956)
1. Removed the duplicated default XGBoost read/write implementation that had been copied from Spark 2.3.x.
2. Moved several utilities into the `util` package.
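For downstream code, the visible change is the import path of the moved helpers. A minimal before/after sketch (all paths are taken from this diff; nothing else changes for callers):

```scala
// Before this commit:
// import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams
// import ml.dmlc.xgboost4j.scala.spark.params.Utils

// After this commit, the helpers live in the util package:
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
import ml.dmlc.xgboost4j.scala.spark.util.Utils
```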
parent 27c66f12d1
commit 545fd4548e
@@ -24,8 +24,9 @@ import scala.collection.{AbstractIterator, Iterator, mutable}

 import ml.dmlc.xgboost4j.java.Rabit
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
-import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams
+import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
 import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
+import ml.dmlc.xgboost4j.scala.spark.util.DataUtils

 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -35,10 +36,8 @@ import org.apache.commons.logging.LogFactory

 import org.apache.spark.TaskContext
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.ml.feature.VectorAssembler
-import org.apache.spark.ml.{Estimator, Model, PipelineStage}
+import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.linalg.Vector
-import org.apache.spark.ml.linalg.xgboost.XGBoostSchemaUtils
 import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
 import org.apache.spark.storage.StorageLevel

@@ -272,7 +271,7 @@ object PreXGBoost extends PreXGBoostProvider {

       val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))

-      import DataUtils._
+      import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
       val cacheInfo = {
         if (useExternalMemory) {
           s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
@@ -23,15 +23,13 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.classification._
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util._
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.json4s.DefaultFormats
 import scala.collection.{Iterator, mutable}

-import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter
-
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.{DefaultXGBoostParamsReader, DefaultXGBoostParamsWriter, XGBoostWriter}
 import org.apache.spark.sql.types.StructType

 class XGBoostClassifier (
@@ -274,7 +272,7 @@ class XGBoostClassificationModel private[ml](
    * Note: The performance is not ideal, use it carefully!
    */
   override def predict(features: Vector): Double = {
-    import DataUtils._
+    import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
     val dm = new DMatrix(processMissingValues(
       Iterator(features.asXGB),
       $(missing),
@@ -469,10 +467,8 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]

     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
-      implicit val format = DefaultFormats
-      implicit val sc = super.sparkSession.sparkContext

       DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)

       // Save model data
       val dataPath = new Path(path, "data").toString
       val internalPath = new Path(dataPath, "XGBoostClassificationModel")
@@ -495,18 +491,7 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
       val dataPath = new Path(path, "data").toString
       val internalPath = new Path(dataPath, "XGBoostClassificationModel")
       val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)
-
-      // The xgboostVersion in the meta can specify if the model is the old xgboost in-compatible
-      // or the new xgboost compatible.
-      val numClasses = metadata.xgboostVersion.map { _ =>
-        implicit val format = DefaultFormats
-        // For binary:logistic, the numClass parameter can't be set to 2 or not be set.
-        // For multi:softprob or multi:softmax, the numClass parameter must be set correctly,
-        // or else, XGBoost will throw exception.
-        // So it's safe to get numClass from meta data.
-        (metadata.params \ "numClass").extractOpt[Int].getOrElse(2)
-      }.getOrElse(dataInStream.readInt())
-
+      val numClasses = DefaultXGBoostParamsReader.getNumClass(metadata, dataInStream)
       val booster = SXGBoost.loadModel(dataInStream)
       val model = new XGBoostClassificationModel(metadata.uid, numClasses, booster)
       DefaultXGBoostParamsReader.getAndSetParams(model, metadata)

@@ -18,8 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark

 import scala.collection.{Iterator, mutable}

-import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
-import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter
+import ml.dmlc.xgboost4j.scala.spark.params._
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
 import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
 import org.apache.hadoop.fs.Path
@@ -30,9 +29,9 @@ import org.apache.spark.ml._
 import org.apache.spark.ml.param._
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import org.json4s.DefaultFormats

 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.util.{DefaultXGBoostParamsReader, DefaultXGBoostParamsWriter, XGBoostWriter}
 import org.apache.spark.sql.types.StructType

 class XGBoostRegressor (
@@ -260,7 +259,7 @@ class XGBoostRegressionModel private[ml] (
    * Note: The performance is not ideal, use it carefully!
    */
   override def predict(features: Vector): Double = {
-    import DataUtils._
+    import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
     val dm = new DMatrix(processMissingValues(
       Iterator(features.asXGB),
       $(missing),
@@ -384,8 +383,6 @@ object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {

     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
-      implicit val format = DefaultFormats
-      implicit val sc = super.sparkSession.sparkContext
       DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
       // Save model data
       val dataPath = new Path(path, "data").toString

@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014,2021 by Contributors
+ Copyright (c) 2014-2022 by Contributors

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@ package ml.dmlc.xgboost4j.scala.spark.params

 import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
+import ml.dmlc.xgboost4j.scala.spark.TrackerConf
+import ml.dmlc.xgboost4j.scala.spark.util.Utils

 import org.apache.spark.ml.param.{Param, ParamPair, Params}
 import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
 import org.json4s.jackson.JsonMethods.{compact, parse, render}
@@ -1,161 +0,0 @@
-/*
- Copyright (c) 2014-2022 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark.params
-
-import ml.dmlc.xgboost4j.scala.spark
-import org.apache.commons.logging.LogFactory
-import org.apache.hadoop.fs.Path
-import org.json4s.{DefaultFormats, JValue}
-import org.json4s.JsonAST.JObject
-import org.json4s.jackson.JsonMethods.{compact, parse, render}
-
-import org.apache.spark.SparkContext
-import org.apache.spark.ml.param.Params
-import org.apache.spark.ml.util.MLReader
-
-// This originates from apache-spark DefaultPramsReader copy paste
-private[spark] object DefaultXGBoostParamsReader {
-
-  private val logger = LogFactory.getLog("XGBoostSpark")
-
-  private val paramNameCompatibilityMap: Map[String, String] = Map("silent" -> "verbosity")
-
-  private val paramValueCompatibilityMap: Map[String, Map[Any, Any]] =
-    Map("objective" -> Map("reg:linear" -> "reg:squarederror"))
-
-  /**
-   * All info from metadata file.
-   *
-   * @param params paramMap, as a `JValue`
-   * @param metadata All metadata, including the other fields
-   * @param metadataJson Full metadata file String (for debugging)
-   */
-  case class Metadata(
-      className: String,
-      uid: String,
-      timestamp: Long,
-      sparkVersion: String,
-      params: JValue,
-      metadata: JValue,
-      metadataJson: String,
-      xgboostVersion: Option[String] = None) {
-
-    /**
-     * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name.
-     * This can be useful for getting a Param value before an instance of `Params`
-     * is available.
-     */
-    def getParamValue(paramName: String): JValue = {
-      implicit val format = DefaultFormats
-      params match {
-        case JObject(pairs) =>
-          val values = pairs.filter { case (pName, jsonValue) =>
-            pName == paramName
-          }.map(_._2)
-          assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" +
-            s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", "))
-          values.head
-        case _ =>
-          throw new IllegalArgumentException(
-            s"Cannot recognize JSON metadata: $metadataJson.")
-      }
-    }
-  }
-
-  /**
-   * Load metadata saved using [[DefaultXGBoostParamsWriter.saveMetadata()]]
-   *
-   * @param expectedClassName If non empty, this is checked against the loaded metadata.
-   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
-   */
-  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
-    val metadataPath = new Path(path, "metadata").toString
-    val metadataStr = sc.textFile(metadataPath, 1).first()
-    parseMetadata(metadataStr, expectedClassName)
-  }
-
-  /**
-   * Parse metadata JSON string produced by [[DefaultXGBoostParamsWriter.getMetadataToSave()]].
-   * This is a helper function for [[loadMetadata()]].
-   *
-   * @param metadataStr JSON string of metadata
-   * @param expectedClassName If non empty, this is checked against the loaded metadata.
-   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
-   */
-  def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = {
-    val metadata = parse(metadataStr)
-
-    implicit val format = DefaultFormats
-    val className = (metadata \ "class").extract[String]
-    val uid = (metadata \ "uid").extract[String]
-    val timestamp = (metadata \ "timestamp").extract[Long]
-    val sparkVersion = (metadata \ "sparkVersion").extract[String]
-    val params = metadata \ "paramMap"
-    if (expectedClassName.nonEmpty) {
-      require(className == expectedClassName, s"Error loading metadata: Expected class name" +
-        s" $expectedClassName but found class name $className")
-    }
-    val xgboostVersion = (metadata \ "xgboostVersion").extractOpt[String]
-    Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr, xgboostVersion)
-  }
-
-  private def handleBrokenlyChangedValue[T](paramName: String, value: T): T = {
-    paramValueCompatibilityMap.getOrElse(paramName, Map()).getOrElse(value, value).asInstanceOf[T]
-  }
-
-  private def handleBrokenlyChangedName(paramName: String): String = {
-    paramNameCompatibilityMap.getOrElse(paramName, paramName)
-  }
-
-  /**
-   * Extract Params from metadata, and set them in the instance.
-   * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
-   * TODO: Move to [[Metadata]] method
-   */
-  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
-    implicit val format = DefaultFormats
-    metadata.params match {
-      case JObject(pairs) =>
-        pairs.foreach { case (paramName, jsonValue) =>
-          val finalName = handleBrokenlyChangedName(paramName)
-          // For the deleted parameters, we'd better to remove it instead of throwing an exception.
-          // So we need to check if the parameter exists instead of blindly setting it.
-          if (instance.hasParam(finalName)) {
-            val param = instance.getParam(finalName)
-            val value = param.jsonDecode(compact(render(jsonValue)))
-            instance.set(param, handleBrokenlyChangedValue(paramName, value))
-          } else {
-            logger.warn(s"$finalName is no longer used in ${spark.VERSION}")
-          }
-        }
-      case _ =>
-        throw new IllegalArgumentException(
-          s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
-    }
-  }
-
-  /**
-   * Load a `Params` instance from the given path, and return it.
-   * This assumes the instance implements [[org.apache.spark.ml.util.MLReadable]].
-   */
-  def loadParamsInstance[T](path: String, sc: SparkContext): T = {
-    val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc)
-    val cls = Utils.classForName(metadata.className)
-    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
-  }
-}
@@ -1,151 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark.params
-
-import org.apache.hadoop.fs.Path
-
-import org.apache.spark.SparkContext
-import org.apache.spark.ml.param.{ParamPair, Params}
-import org.json4s.jackson.JsonMethods._
-import org.json4s.{JArray, JBool, JDouble, JField, JInt, JNothing, JObject, JString, JValue}
-import JsonDSLXGBoost._
-import ml.dmlc.xgboost4j.scala.spark
-
-// This originates from apache-spark DefaultPramsWriter copy paste
-private[spark] object DefaultXGBoostParamsWriter {
-
-  /**
-   * Saves metadata + Params to: path + "/metadata"
-   *  - class
-   *  - timestamp
-   *  - sparkVersion
-   *  - uid
-   *  - paramMap
-   *  - (optionally, extra metadata)
-   *
-   * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
-   * @param paramMap If given, this is saved in the "paramMap" field.
-   *                 Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
-   *                 [[org.apache.spark.ml.param.Param.jsonEncode()]].
-   */
-  def saveMetadata(
-      instance: Params,
-      path: String,
-      sc: SparkContext,
-      extraMetadata: Option[JObject] = None,
-      paramMap: Option[JValue] = None): Unit = {
-
-    val metadataPath = new Path(path, "metadata").toString
-    val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
-    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
-  }
-
-  /**
-   * Helper for [[saveMetadata()]] which extracts the JSON to save.
-   * This is useful for ensemble models which need to save metadata for many sub-models.
-   *
-   * @see [[saveMetadata()]] for details on what this includes.
-   */
-  def getMetadataToSave(
-      instance: Params,
-      sc: SparkContext,
-      extraMetadata: Option[JObject] = None,
-      paramMap: Option[JValue] = None): String = {
-    val uid = instance.uid
-    val cls = instance.getClass.getName
-    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
-    val jsonParams = paramMap.getOrElse(render(params.filter{
-      case ParamPair(p, _) => p != null
-    }.map {
-      case ParamPair(p, v) =>
-        p.name -> parse(p.jsonEncode(v))
-    }.toList))
-    val basicMetadata = ("class" -> cls) ~
-      ("timestamp" -> System.currentTimeMillis()) ~
-      ("sparkVersion" -> sc.version) ~
-      ("uid" -> uid) ~
-      ("xgboostVersion" -> spark.VERSION) ~
-      ("paramMap" -> jsonParams)
-    val metadata = extraMetadata match {
-      case Some(jObject) =>
-        basicMetadata ~ jObject
-      case None =>
-        basicMetadata
-    }
-    val metadataJson: String = compact(render(metadata))
-    metadataJson
-  }
-}
-
-// Fix json4s bin-incompatible issue.
-// This originates from org.json4s.JsonDSL of 3.6.6
-object JsonDSLXGBoost {
-
-  implicit def seq2jvalue[A](s: Iterable[A])(implicit ev: A => JValue): JArray =
-    JArray(s.toList.map(ev))
-
-  implicit def map2jvalue[A](m: Map[String, A])(implicit ev: A => JValue): JObject =
-    JObject(m.toList.map { case (k, v) => JField(k, ev(v)) })
-
-  implicit def option2jvalue[A](opt: Option[A])(implicit ev: A => JValue): JValue = opt match {
-    case Some(x) => ev(x)
-    case None => JNothing
-  }
-
-  implicit def short2jvalue(x: Short): JValue = JInt(x)
-  implicit def byte2jvalue(x: Byte): JValue = JInt(x)
-  implicit def char2jvalue(x: Char): JValue = JInt(x)
-  implicit def int2jvalue(x: Int): JValue = JInt(x)
-  implicit def long2jvalue(x: Long): JValue = JInt(x)
-  implicit def bigint2jvalue(x: BigInt): JValue = JInt(x)
-  implicit def double2jvalue(x: Double): JValue = JDouble(x)
-  implicit def float2jvalue(x: Float): JValue = JDouble(x.toDouble)
-  implicit def bigdecimal2jvalue(x: BigDecimal): JValue = JDouble(x.doubleValue)
-  implicit def boolean2jvalue(x: Boolean): JValue = JBool(x)
-  implicit def string2jvalue(x: String): JValue = JString(x)
-
-  implicit def symbol2jvalue(x: Symbol): JString = JString(x.name)
-  implicit def pair2jvalue[A](t: (String, A))(implicit ev: A => JValue): JObject =
-    JObject(List(JField(t._1, ev(t._2))))
-  implicit def list2jvalue(l: List[JField]): JObject = JObject(l)
-  implicit def jobject2assoc(o: JObject): JsonListAssoc = new JsonListAssoc(o.obj)
-  implicit def pair2Assoc[A](t: (String, A))(implicit ev: A => JValue): JsonAssoc[A] =
-    new JsonAssoc(t)
-}
-
-final class JsonAssoc[A](private val left: (String, A)) extends AnyVal {
-  def ~[B](right: (String, B))(implicit ev1: A => JValue, ev2: B => JValue): JObject = {
-    val l: JValue = ev1(left._2)
-    val r: JValue = ev2(right._2)
-    JObject(JField(left._1, l) :: JField(right._1, r) :: Nil)
-  }
-
-  def ~(right: JObject)(implicit ev: A => JValue): JObject = {
-    val l: JValue = ev(left._2)
-    JObject(JField(left._1, l) :: right.obj)
-  }
-  def ~~[B](right: (String, B))(implicit ev1: A => JValue, ev2: B => JValue): JObject =
-    this.~(right)
-  def ~~(right: JObject)(implicit ev: A => JValue): JObject = this.~(right)
-}
-
-final class JsonListAssoc(private val left: List[JField]) extends AnyVal {
-  def ~(right: (String, JValue)): JObject = JObject(left ::: List(JField(right._1, right._2)))
-  def ~(right: JObject): JObject = JObject(left ::: right.obj)
-  def ~~(right: (String, JValue)): JObject = this.~(right)
-  def ~~(right: JObject): JObject = this.~(right)
-}
@@ -17,9 +17,9 @@
 package ml.dmlc.xgboost4j.scala.spark.params

 import org.apache.spark.ml.feature.VectorAssembler
-import org.apache.spark.ml.linalg.xgboost.XGBoostSchemaUtils
 import org.apache.spark.ml.param.{Param, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol, HasWeightCol}
+import org.apache.spark.ml.util.XGBoostSchemaUtils
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types.StructType

@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014,2021 by Contributors
+ Copyright (c) 2014-2022 by Contributors

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
  limitations under the License.
  */

-package ml.dmlc.xgboost4j.scala.spark
+package ml.dmlc.xgboost4j.scala.spark.util

 import scala.collection.mutable

@@ -24,8 +24,8 @@ import org.apache.spark.HashPartitioner
 import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Column, DataFrame, Row}
 import org.apache.spark.sql.types.{FloatType, IntegerType}
+import org.apache.spark.sql.{Column, DataFrame, Row}

 object DataUtils extends Serializable {
   private[spark] implicit class XGBLabeledPointFeatures(
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014,2021 by Contributors
+ Copyright (c) 2014-2022 by Contributors

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
  limitations under the License.
  */

-package ml.dmlc.xgboost4j.scala.spark.params
+package ml.dmlc.xgboost4j.scala.spark.util

 import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints}

@@ -1,31 +0,0 @@
-/*
- Copyright (c) 2022 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark.utils
-
-import ml.dmlc.xgboost4j.java.{Booster => JBooster}
-
-import org.apache.spark.ml.util.MLWriter
-
-private[spark] abstract class XGBoostWriter extends MLWriter {
-
-  /** Currently it's using the "deprecated" format as
-   * default, which will be changed into `ubj` in future releases. */
-  def getModelFormat(): String = {
-    optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
-  }
-
-}
@@ -0,0 +1,150 @@
+/*
+ Copyright (c) 2022 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import ml.dmlc.xgboost4j.java.{Booster => JBooster}
+import ml.dmlc.xgboost4j.scala.spark
+import org.apache.commons.logging.LogFactory
+import org.apache.hadoop.fs.FSDataInputStream
+import org.json4s.DefaultFormats
+import org.json4s.JsonAST.JObject
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.param.Params
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
+
+abstract class XGBoostWriter extends MLWriter {
+
+  /** Currently it's using the "deprecated" format as
+   * default, which will be changed into `ubj` in future releases. */
+  def getModelFormat(): String = {
+    optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
+  }
+}
+
+object DefaultXGBoostParamsWriter {
+
+  val XGBOOST_VERSION_TAG = "xgboostVersion"
+
+  /**
+   * Saves metadata + Params to: path + "/metadata" using [[DefaultParamsWriter.saveMetadata]]
+   */
+  def saveMetadata(
+      instance: Params,
+      path: String,
+      sc: SparkContext): Unit = {
+    // save xgboost version to distinguish the old model.
+    val extraMetadata: JObject = Map(XGBOOST_VERSION_TAG -> ml.dmlc.xgboost4j.scala.spark.VERSION)
+    DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
+  }
+}
+
+object DefaultXGBoostParamsReader {
+
+  private val logger = LogFactory.getLog("XGBoostSpark")
+
+  /**
+   * Load metadata saved using [[DefaultParamsReader.loadMetadata()]]
+   *
+   * @param expectedClassName If non empty, this is checked against the loaded metadata.
+   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
+   */
+  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
+    DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+  }
+
+  /**
+   * Extract Params from metadata, and set them in the instance.
+   * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
+   *
+   * And it will auto-skip the parameter not defined.
+   *
+   * This API is mainly copied from DefaultParamsReader
+   */
+  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
+
+    // XGBoost didn't set the default parameters since the save/load code is copied
+    // from spark 2.3.x, which means it just used the default values
+    // as the same with XGBoost version instead of them in model.
+    // For the compatibility, here we still don't set the default parameters.
+    // setParams(instance, metadata, isDefault = true)
+
+    setParams(instance, metadata, isDefault = false)
+  }
+
+  /** This API is only for XGBoostClassificationModel */
+  def getNumClass(metadata: Metadata, dataInStream: FSDataInputStream): Int = {
+    implicit val format = DefaultFormats
+
+    // The xgboostVersion in the meta can specify if the model is the old xgboost in-compatible
+    // or the new xgboost compatible.
+    val xgbVerOpt = (metadata.metadata \ DefaultXGBoostParamsWriter.XGBOOST_VERSION_TAG)
+      .extractOpt[String]
+
+    // For binary:logistic, the numClass parameter can't be set to 2 or not be set.
+    // For multi:softprob or multi:softmax, the numClass parameter must be set correctly,
+    // or else, XGBoost will throw exception.
+    // So it's safe to get numClass from meta data.
+    xgbVerOpt
+      .map { _ => (metadata.params \ "numClass").extractOpt[Int].getOrElse(2) }
+      .getOrElse(dataInStream.readInt())
+
+  }
+
+  private def setParams(
+      instance: Params,
+      metadata: Metadata,
+      isDefault: Boolean): Unit = {
+    val paramsToSet = if (isDefault) metadata.defaultParams else metadata.params
+    paramsToSet match {
+      case JObject(pairs) =>
+        pairs.foreach { case (paramName, jsonValue) =>
+          val finalName = handleBrokenlyChangedName(paramName)
+          // For the deleted parameters, we'd better to remove it instead of throwing an exception.
+          // So we need to check if the parameter exists instead of blindly setting it.
+          if (instance.hasParam(finalName)) {
+            val param = instance.getParam(finalName)
+            val value = param.jsonDecode(compact(render(jsonValue)))
+            instance.set(param, handleBrokenlyChangedValue(paramName, value))
+          } else {
+            logger.warn(s"$finalName is no longer used in ${spark.VERSION}")
+          }
+        }
+      case _ =>
+        throw new IllegalArgumentException(
+          s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
+    }
+  }
+
+  private val paramNameCompatibilityMap: Map[String, String] = Map("silent" -> "verbosity")
+
+  /** This is really not good to do this transformation, but it is needed since there're
+   * some tests based on 0.82 saved model in which the objective is "reg:linear" */
+  private val paramValueCompatibilityMap: Map[String, Map[Any, Any]] =
+    Map("objective" -> Map("reg:linear" -> "reg:squarederror"))
+
+  private def handleBrokenlyChangedName(paramName: String): String = {
+    paramNameCompatibilityMap.getOrElse(paramName, paramName)
+  }
+
+  private def handleBrokenlyChangedValue[T](paramName: String, value: T): T = {
+    paramValueCompatibilityMap.getOrElse(paramName, Map()).getOrElse(value, value).asInstanceOf[T]
+  }
+
+}
@@ -14,7 +14,7 @@
  limitations under the License.
  */

-package org.apache.spark.ml.linalg.xgboost
+package org.apache.spark.ml.util

 import org.apache.spark.sql.types.{BooleanType, DataType, NumericType, StructType}
 import org.apache.spark.ml.linalg.VectorUDT
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014-2022 by Contributors

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -18,7 +18,8 @@ package ml.dmlc.xgboost4j.scala.spark

 import org.apache.spark.ml.linalg.Vectors
 import org.scalatest.FunSuite
-import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams
+import ml.dmlc.xgboost4j.scala.spark.util.DataUtils
+import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams

 import org.apache.spark.sql.functions._

@@ -50,7 +50,7 @@ class FeatureSizeValidatingSuite extends FunSuite with PerTest {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic",
       "num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0)
-    import DataUtils._
+    import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
     val sparkSession = ss
     import sparkSession.implicits._
     val repartitioned = sc.parallelize(Synthetic.trainWithDiffFeatureSize, 2)
@@ -73,7 +73,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
   protected def buildDataFrame(
       labeledPoints: Seq[XGBLabeledPoint],
       numPartitions: Int = numWorkers): DataFrame = {
-    import DataUtils._
+    import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
     val it = labeledPoints.iterator.zipWithIndex
       .map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
         (id, labeledPoint.label, labeledPoint.features)
@@ -98,7 +98,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
   protected def buildDataFrameWithGroup(
       labeledPoints: Seq[XGBLabeledPoint],
       numPartitions: Int = numWorkers): DataFrame = {
-    import DataUtils._
+    import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
     val it = labeledPoints.iterator.zipWithIndex
       .map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
         (id, labeledPoint.label, labeledPoint.features, labeledPoint.group)

@@ -310,7 +310,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuit
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic",
       "num_round" -> 5, "num_workers" -> 2, "missing" -> 0)
-    import DataUtils._
+    import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
     val sparkSession = SparkSession.builder().getOrCreate()
     import sparkSession.implicits._
     val repartitioned = sc.parallelize(Synthetic.train, 3).map(lp => (lp.label, lp)).partitionBy(
@@ -331,7 +331,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuit
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic",
       "num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0)
-    import DataUtils._
+    import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
     val sparkSession = SparkSession.builder().getOrCreate()
     import sparkSession.implicits._
     val repartitioned = sc.parallelize(Synthetic.train, 3).map(lp => (lp.label, lp)).partitionBy(
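Usage note: this refactor does not change the public save/load surface; models still go through the standard MLWriter/MLReader entry points, with the booster format picked up from the writer's option map by `XGBoostWriter.getModelFormat()`. A minimal end-to-end sketch (the session setup, toy data, and output path are hypothetical; only the `format` option and the save/load calls come from the code above, and passing `"ubj"` is an assumption based on the doc comment about future defaults):

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassifier, XGBoostClassificationModel}

val spark = SparkSession.builder().master("local[2]").appName("xgb-rw").getOrCreate()
import spark.implicits._

// Toy two-row training set, just enough to exercise save/load.
val train = Seq(
  (Vectors.dense(0.0, 1.0), 0.0),
  (Vectors.dense(1.0, 0.0), 1.0)
).toDF("features", "label")

val model = new XGBoostClassifier(Map(
  "objective" -> "binary:logistic",
  "num_round" -> 2,
  "num_workers" -> 1)).fit(train)

// XGBoostWriter.getModelFormat() reads this option and falls back to
// JBooster.DEFAULT_FORMAT when it is absent.
model.write.option("format", "ubj").overwrite().save("/tmp/xgb-classifier")

// The load path runs DefaultXGBoostParamsReader.loadMetadata, getNumClass
// and getAndSetParams before restoring the booster.
val loaded = XGBoostClassificationModel.load("/tmp/xgb-classifier")
```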