From 545fd4548e303931dafd98d6606454fcdc2b8f2f Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Wed, 1 Jun 2022 11:38:49 +0800 Subject: [PATCH] [jvm-packages] refactor xgboost read/write (#7956) 1. Removed the duplicated Default XGBoost read/write which is copied from spark 2.3.x 2. Put some utils into util package --- .../xgboost4j/scala/spark/PreXGBoost.scala | 9 +- .../scala/spark/XGBoostClassifier.scala | 25 +-- .../scala/spark/XGBoostRegressor.scala | 9 +- .../scala/spark/params/CustomParams.scala | 4 +- .../params/DefaultXGBoostParamsReader.scala | 161 ------------------ .../params/DefaultXGBoostParamsWriter.scala | 151 ---------------- .../spark/params/XGBoostEstimatorCommon.scala | 2 +- .../scala/spark/{ => util}/DataUtils.scala | 6 +- .../scala/spark/{params => util}/Utils.scala | 4 +- .../scala/spark/utils/XGBoostReadWrite.scala | 31 ---- .../spark/ml/util/XGBoostReadWrite.scala | 150 ++++++++++++++++ .../xgboost => util}/XGBoostSchemaUtils.scala | 2 +- .../DeterministicPartitioningSuite.scala | 5 +- .../spark/FeatureSizeValidatingSuite.scala | 2 +- .../dmlc/xgboost4j/scala/spark/PerTest.scala | 4 +- .../scala/spark/XGBoostClassifierSuite.scala | 4 +- 16 files changed, 180 insertions(+), 389 deletions(-) delete mode 100644 jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsReader.scala delete mode 100644 jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsWriter.scala rename jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/{ => util}/DataUtils.scala (99%) rename jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/{params => util}/Utils.scala (96%) delete mode 100644 jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/utils/XGBoostReadWrite.scala create mode 100644 jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/util/XGBoostReadWrite.scala rename jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/{linalg/xgboost => util}/XGBoostSchemaUtils.scala (97%) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala index 13484f490..e0a365f6c 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala @@ -24,8 +24,9 @@ import scala.collection.{AbstractIterator, Iterator, mutable} import ml.dmlc.xgboost4j.java.Rabit import ml.dmlc.xgboost4j.scala.{Booster, DMatrix} -import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams +import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon +import ml.dmlc.xgboost4j.scala.spark.util.DataUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} @@ -35,10 +36,8 @@ import org.apache.commons.logging.LogFactory import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast -import org.apache.spark.ml.feature.VectorAssembler -import org.apache.spark.ml.{Estimator, Model, PipelineStage} +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg.Vector -import org.apache.spark.ml.linalg.xgboost.XGBoostSchemaUtils import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType} import org.apache.spark.storage.StorageLevel @@ 
-272,7 +271,7 @@ object PreXGBoost extends PreXGBoostProvider { val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol)) - import DataUtils._ + import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._ val cacheInfo = { if (useExternalMemory) { s"$appName-${TaskContext.get().stageId()}-dtest_cache-" + diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala index 59be8b1c7..acc9febff 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala @@ -23,15 +23,13 @@ import org.apache.hadoop.fs.Path import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.classification._ import org.apache.spark.ml.linalg._ -import org.apache.spark.ml.param._ import org.apache.spark.ml.util._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ -import org.json4s.DefaultFormats import scala.collection.{Iterator, mutable} -import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter - +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.{DefaultXGBoostParamsReader, DefaultXGBoostParamsWriter, XGBoostWriter} import org.apache.spark.sql.types.StructType class XGBoostClassifier ( @@ -274,7 +272,7 @@ class XGBoostClassificationModel private[ml]( * Note: The performance is not ideal, use it carefully! */ override def predict(features: Vector): Double = { - import DataUtils._ + import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._ val dm = new DMatrix(processMissingValues( Iterator(features.asXGB), $(missing), @@ -469,10 +467,8 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] override protected def saveImpl(path: String): Unit = { // Save metadata and Params - implicit val format = DefaultFormats - implicit val sc = super.sparkSession.sparkContext - DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc) + // Save model data val dataPath = new Path(path, "data").toString val internalPath = new Path(dataPath, "XGBoostClassificationModel") @@ -495,18 +491,7 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] val dataPath = new Path(path, "data").toString val internalPath = new Path(dataPath, "XGBoostClassificationModel") val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath) - - // The xgboostVersion in the meta can specify if the model is the old xgboost in-compatible - // or the new xgboost compatible. - val numClasses = metadata.xgboostVersion.map { _ => - implicit val format = DefaultFormats - // For binary:logistic, the numClass parameter can't be set to 2 or not be set. - // For multi:softprob or multi:softmax, the numClass parameter must be set correctly, - // or else, XGBoost will throw exception. - // So it's safe to get numClass from meta data. 
- (metadata.params \ "numClass").extractOpt[Int].getOrElse(2) - }.getOrElse(dataInStream.readInt()) - + val numClasses = DefaultXGBoostParamsReader.getNumClass(metadata, dataInStream) val booster = SXGBoost.loadModel(dataInStream) val model = new XGBoostClassificationModel(metadata.uid, numClasses, booster) DefaultXGBoostParamsReader.getAndSetParams(model, metadata) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala index fcdd347e4..77e0ac6b0 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala @@ -18,8 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark import scala.collection.{Iterator, mutable} -import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _} -import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter +import ml.dmlc.xgboost4j.scala.spark.params._ import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost} import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} import org.apache.hadoop.fs.Path @@ -30,9 +29,9 @@ import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ -import org.json4s.DefaultFormats import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.util.{DefaultXGBoostParamsReader, DefaultXGBoostParamsWriter, XGBoostWriter} import org.apache.spark.sql.types.StructType class XGBoostRegressor ( @@ -260,7 +259,7 @@ class XGBoostRegressionModel private[ml] ( * Note: The performance is not ideal, use it carefully! */ override def predict(features: Vector): Double = { - import DataUtils._ + import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._ val dm = new DMatrix(processMissingValues( Iterator(features.asXGB), $(missing), @@ -384,8 +383,6 @@ object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] { override protected def saveImpl(path: String): Unit = { // Save metadata and Params - implicit val format = DefaultFormats - implicit val sc = super.sparkSession.sparkContext DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc) // Save model data val dataPath = new Path(path, "data").toString diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala index c74560218..f838baac2 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014,2021 by Contributors + Copyright (c) 2014-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -18,6 +18,8 @@ package ml.dmlc.xgboost4j.scala.spark.params import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} import ml.dmlc.xgboost4j.scala.spark.TrackerConf +import ml.dmlc.xgboost4j.scala.spark.util.Utils + import org.apache.spark.ml.param.{Param, ParamPair, Params} import org.json4s.{DefaultFormats, Extraction, NoTypeHints} import org.json4s.jackson.JsonMethods.{compact, parse, render} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsReader.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsReader.scala deleted file mode 100644 index 9fc644b8f..000000000 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsReader.scala +++ /dev/null @@ -1,161 +0,0 @@ -/* - Copyright (c) 2014-2022 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala.spark.params - -import ml.dmlc.xgboost4j.scala.spark -import org.apache.commons.logging.LogFactory -import org.apache.hadoop.fs.Path -import org.json4s.{DefaultFormats, JValue} -import org.json4s.JsonAST.JObject -import org.json4s.jackson.JsonMethods.{compact, parse, render} - -import org.apache.spark.SparkContext -import org.apache.spark.ml.param.Params -import org.apache.spark.ml.util.MLReader - -// This originates from apache-spark DefaultPramsReader copy paste -private[spark] object DefaultXGBoostParamsReader { - - private val logger = LogFactory.getLog("XGBoostSpark") - - private val paramNameCompatibilityMap: Map[String, String] = Map("silent" -> "verbosity") - - private val paramValueCompatibilityMap: Map[String, Map[Any, Any]] = - Map("objective" -> Map("reg:linear" -> "reg:squarederror")) - - /** - * All info from metadata file. - * - * @param params paramMap, as a `JValue` - * @param metadata All metadata, including the other fields - * @param metadataJson Full metadata file String (for debugging) - */ - case class Metadata( - className: String, - uid: String, - timestamp: Long, - sparkVersion: String, - params: JValue, - metadata: JValue, - metadataJson: String, - xgboostVersion: Option[String] = None) { - - /** - * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name. - * This can be useful for getting a Param value before an instance of `Params` - * is available. 
- */ - def getParamValue(paramName: String): JValue = { - implicit val format = DefaultFormats - params match { - case JObject(pairs) => - val values = pairs.filter { case (pName, jsonValue) => - pName == paramName - }.map(_._2) - assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" + - s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) - values.head - case _ => - throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: $metadataJson.") - } - } - } - - /** - * Load metadata saved using [[DefaultXGBoostParamsWriter.saveMetadata()]] - * - * @param expectedClassName If non empty, this is checked against the loaded metadata. - * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata - */ - def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = { - val metadataPath = new Path(path, "metadata").toString - val metadataStr = sc.textFile(metadataPath, 1).first() - parseMetadata(metadataStr, expectedClassName) - } - - /** - * Parse metadata JSON string produced by [[DefaultXGBoostParamsWriter.getMetadataToSave()]]. - * This is a helper function for [[loadMetadata()]]. - * - * @param metadataStr JSON string of metadata - * @param expectedClassName If non empty, this is checked against the loaded metadata. - * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata - */ - def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = { - val metadata = parse(metadataStr) - - implicit val format = DefaultFormats - val className = (metadata \ "class").extract[String] - val uid = (metadata \ "uid").extract[String] - val timestamp = (metadata \ "timestamp").extract[Long] - val sparkVersion = (metadata \ "sparkVersion").extract[String] - val params = metadata \ "paramMap" - if (expectedClassName.nonEmpty) { - require(className == expectedClassName, s"Error loading metadata: Expected class name" + - s" $expectedClassName but found class name $className") - } - val xgboostVersion = (metadata \ "xgboostVersion").extractOpt[String] - Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr, xgboostVersion) - } - - private def handleBrokenlyChangedValue[T](paramName: String, value: T): T = { - paramValueCompatibilityMap.getOrElse(paramName, Map()).getOrElse(value, value).asInstanceOf[T] - } - - private def handleBrokenlyChangedName(paramName: String): String = { - paramNameCompatibilityMap.getOrElse(paramName, paramName) - } - - /** - * Extract Params from metadata, and set them in the instance. - * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]]. - * TODO: Move to [[Metadata]] method - */ - def getAndSetParams(instance: Params, metadata: Metadata): Unit = { - implicit val format = DefaultFormats - metadata.params match { - case JObject(pairs) => - pairs.foreach { case (paramName, jsonValue) => - val finalName = handleBrokenlyChangedName(paramName) - // For the deleted parameters, we'd better to remove it instead of throwing an exception. - // So we need to check if the parameter exists instead of blindly setting it. 
- if (instance.hasParam(finalName)) { - val param = instance.getParam(finalName) - val value = param.jsonDecode(compact(render(jsonValue))) - instance.set(param, handleBrokenlyChangedValue(paramName, value)) - } else { - logger.warn(s"$finalName is no longer used in ${spark.VERSION}") - } - } - case _ => - throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") - } - } - - /** - * Load a `Params` instance from the given path, and return it. - * This assumes the instance implements [[org.apache.spark.ml.util.MLReadable]]. - */ - def loadParamsInstance[T](path: String, sc: SparkContext): T = { - val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc) - val cls = Utils.classForName(metadata.className) - cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path) - } -} - diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsWriter.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsWriter.scala deleted file mode 100644 index 38aa814c2..000000000 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsWriter.scala +++ /dev/null @@ -1,151 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala.spark.params - -import org.apache.hadoop.fs.Path - -import org.apache.spark.SparkContext -import org.apache.spark.ml.param.{ParamPair, Params} -import org.json4s.jackson.JsonMethods._ -import org.json4s.{JArray, JBool, JDouble, JField, JInt, JNothing, JObject, JString, JValue} -import JsonDSLXGBoost._ -import ml.dmlc.xgboost4j.scala.spark - -// This originates from apache-spark DefaultPramsWriter copy paste -private[spark] object DefaultXGBoostParamsWriter { - - /** - * Saves metadata + Params to: path + "/metadata" - * - class - * - timestamp - * - sparkVersion - * - uid - * - paramMap - * - (optionally, extra metadata) - * - * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc. - * @param paramMap If given, this is saved in the "paramMap" field. - * Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using - * [[org.apache.spark.ml.param.Param.jsonEncode()]]. - */ - def saveMetadata( - instance: Params, - path: String, - sc: SparkContext, - extraMetadata: Option[JObject] = None, - paramMap: Option[JValue] = None): Unit = { - - val metadataPath = new Path(path, "metadata").toString - val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap) - sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) - } - - /** - * Helper for [[saveMetadata()]] which extracts the JSON to save. - * This is useful for ensemble models which need to save metadata for many sub-models. - * - * @see [[saveMetadata()]] for details on what this includes. 
- */ - def getMetadataToSave( - instance: Params, - sc: SparkContext, - extraMetadata: Option[JObject] = None, - paramMap: Option[JValue] = None): String = { - val uid = instance.uid - val cls = instance.getClass.getName - val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] - val jsonParams = paramMap.getOrElse(render(params.filter{ - case ParamPair(p, _) => p != null - }.map { - case ParamPair(p, v) => - p.name -> parse(p.jsonEncode(v)) - }.toList)) - val basicMetadata = ("class" -> cls) ~ - ("timestamp" -> System.currentTimeMillis()) ~ - ("sparkVersion" -> sc.version) ~ - ("uid" -> uid) ~ - ("xgboostVersion" -> spark.VERSION) ~ - ("paramMap" -> jsonParams) - val metadata = extraMetadata match { - case Some(jObject) => - basicMetadata ~ jObject - case None => - basicMetadata - } - val metadataJson: String = compact(render(metadata)) - metadataJson - } -} - -// Fix json4s bin-incompatible issue. -// This originates from org.json4s.JsonDSL of 3.6.6 -object JsonDSLXGBoost { - - implicit def seq2jvalue[A](s: Iterable[A])(implicit ev: A => JValue): JArray = - JArray(s.toList.map(ev)) - - implicit def map2jvalue[A](m: Map[String, A])(implicit ev: A => JValue): JObject = - JObject(m.toList.map { case (k, v) => JField(k, ev(v)) }) - - implicit def option2jvalue[A](opt: Option[A])(implicit ev: A => JValue): JValue = opt match { - case Some(x) => ev(x) - case None => JNothing - } - - implicit def short2jvalue(x: Short): JValue = JInt(x) - implicit def byte2jvalue(x: Byte): JValue = JInt(x) - implicit def char2jvalue(x: Char): JValue = JInt(x) - implicit def int2jvalue(x: Int): JValue = JInt(x) - implicit def long2jvalue(x: Long): JValue = JInt(x) - implicit def bigint2jvalue(x: BigInt): JValue = JInt(x) - implicit def double2jvalue(x: Double): JValue = JDouble(x) - implicit def float2jvalue(x: Float): JValue = JDouble(x.toDouble) - implicit def bigdecimal2jvalue(x: BigDecimal): JValue = JDouble(x.doubleValue) - implicit def boolean2jvalue(x: Boolean): JValue = JBool(x) - implicit def string2jvalue(x: String): JValue = JString(x) - - implicit def symbol2jvalue(x: Symbol): JString = JString(x.name) - implicit def pair2jvalue[A](t: (String, A))(implicit ev: A => JValue): JObject = - JObject(List(JField(t._1, ev(t._2)))) - implicit def list2jvalue(l: List[JField]): JObject = JObject(l) - implicit def jobject2assoc(o: JObject): JsonListAssoc = new JsonListAssoc(o.obj) - implicit def pair2Assoc[A](t: (String, A))(implicit ev: A => JValue): JsonAssoc[A] = - new JsonAssoc(t) -} - -final class JsonAssoc[A](private val left: (String, A)) extends AnyVal { - def ~[B](right: (String, B))(implicit ev1: A => JValue, ev2: B => JValue): JObject = { - val l: JValue = ev1(left._2) - val r: JValue = ev2(right._2) - JObject(JField(left._1, l) :: JField(right._1, r) :: Nil) - } - - def ~(right: JObject)(implicit ev: A => JValue): JObject = { - val l: JValue = ev(left._2) - JObject(JField(left._1, l) :: right.obj) - } - def ~~[B](right: (String, B))(implicit ev1: A => JValue, ev2: B => JValue): JObject = - this.~(right) - def ~~(right: JObject)(implicit ev: A => JValue): JObject = this.~(right) -} - -final class JsonListAssoc(private val left: List[JField]) extends AnyVal { - def ~(right: (String, JValue)): JObject = JObject(left ::: List(JField(right._1, right._2))) - def ~(right: JObject): JObject = JObject(left ::: right.obj) - def ~~(right: (String, JValue)): JObject = this.~(right) - def ~~(right: JObject): JObject = this.~(right) -} diff --git 
a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala index 2153c5846..9581ea0f2 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala @@ -17,9 +17,9 @@ package ml.dmlc.xgboost4j.scala.spark.params import org.apache.spark.ml.feature.VectorAssembler -import org.apache.spark.ml.linalg.xgboost.XGBoostSchemaUtils import org.apache.spark.ml.param.{Param, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol, HasWeightCol} +import org.apache.spark.ml.util.XGBoostSchemaUtils import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/DataUtils.scala similarity index 99% rename from jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala rename to jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/DataUtils.scala index a34c49daf..acc605b1f 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/DataUtils.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014,2021 by Contributors + Copyright (c) 2014-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ limitations under the License. */ -package ml.dmlc.xgboost4j.scala.spark +package ml.dmlc.xgboost4j.scala.spark.util import scala.collection.mutable @@ -24,8 +24,8 @@ import org.apache.spark.HashPartitioner import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint} import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Column, DataFrame, Row} import org.apache.spark.sql.types.{FloatType, IntegerType} +import org.apache.spark.sql.{Column, DataFrame, Row} object DataUtils extends Serializable { private[spark] implicit class XGBLabeledPointFeatures( diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/Utils.scala similarity index 96% rename from jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala rename to jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/Utils.scala index fb84ad6d6..d5e133b4c 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/Utils.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014,2021 by Contributors + Copyright (c) 2014-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ limitations under the License. 
*/ -package ml.dmlc.xgboost4j.scala.spark.params +package ml.dmlc.xgboost4j.scala.spark.util import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/utils/XGBoostReadWrite.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/utils/XGBoostReadWrite.scala deleted file mode 100644 index 0fbd36bf3..000000000 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/utils/XGBoostReadWrite.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - Copyright (c) 2022 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala.spark.utils - -import ml.dmlc.xgboost4j.java.{Booster => JBooster} - -import org.apache.spark.ml.util.MLWriter - -private[spark] abstract class XGBoostWriter extends MLWriter { - - /** Currently it's using the "deprecated" format as - * default, which will be changed into `ubj` in future releases. */ - def getModelFormat(): String = { - optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT) - } - -} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/util/XGBoostReadWrite.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/util/XGBoostReadWrite.scala new file mode 100644 index 000000000..672241be1 --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/util/XGBoostReadWrite.scala @@ -0,0 +1,150 @@ +/* + Copyright (c) 2022 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package org.apache.spark.ml.util + +import ml.dmlc.xgboost4j.java.{Booster => JBooster} +import ml.dmlc.xgboost4j.scala.spark +import org.apache.commons.logging.LogFactory +import org.apache.hadoop.fs.FSDataInputStream +import org.json4s.DefaultFormats +import org.json4s.JsonAST.JObject +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + +import org.apache.spark.SparkContext +import org.apache.spark.ml.param.Params +import org.apache.spark.ml.util.DefaultParamsReader.Metadata + +abstract class XGBoostWriter extends MLWriter { + + /** Currently it's using the "deprecated" format as + * default, which will be changed into `ubj` in future releases. 
*/ + def getModelFormat(): String = { + optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT) + } +} + +object DefaultXGBoostParamsWriter { + + val XGBOOST_VERSION_TAG = "xgboostVersion" + + /** + * Saves metadata + Params to: path + "/metadata" using [[DefaultParamsWriter.saveMetadata]] + */ + def saveMetadata( + instance: Params, + path: String, + sc: SparkContext): Unit = { + // Save the XGBoost version so that old models can be distinguished from new ones. + val extraMetadata: JObject = Map(XGBOOST_VERSION_TAG -> ml.dmlc.xgboost4j.scala.spark.VERSION) + DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + } +} + +object DefaultXGBoostParamsReader { + + private val logger = LogFactory.getLog("XGBoostSpark") + + /** + * Load metadata saved using [[DefaultXGBoostParamsWriter.saveMetadata()]] + * + * @param expectedClassName If non-empty, this is checked against the loaded metadata. + * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata + */ + def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = { + DefaultParamsReader.loadMetadata(path, sc, expectedClassName) + } + + /** + * Extract Params from metadata, and set them in the instance. + * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]]. + * + * Parameters that are no longer defined on the instance are skipped automatically. + * + * This API is mainly copied from DefaultParamsReader. + */ + def getAndSetParams(instance: Params, metadata: Metadata): Unit = { + + // XGBoost has never restored default parameters, because the save/load code copied + // from Spark 2.3.x simply relied on the defaults of the running XGBoost version + // instead of the defaults stored in the model. + // For compatibility, we still don't set the default parameters here. + // setParams(instance, metadata, isDefault = true) + + setParams(instance, metadata, isDefault = false) + } + + /** This API is only for XGBoostClassificationModel */ + def getNumClass(metadata: Metadata, dataInStream: FSDataInputStream): Int = { + implicit val format = DefaultFormats + + // The presence of xgboostVersion in the metadata tells whether the model was saved by an + // old, incompatible XGBoost or by a new, compatible one. + val xgbVerOpt = (metadata.metadata \ DefaultXGBoostParamsWriter.XGBOOST_VERSION_TAG) + .extractOpt[String] + + // For binary:logistic, the numClass parameter can't be set (not even to 2). + // For multi:softprob or multi:softmax, the numClass parameter must be set correctly, + // otherwise XGBoost will throw an exception. + // So it's safe to get numClass from the metadata. + xgbVerOpt + .map { _ => (metadata.params \ "numClass").extractOpt[Int].getOrElse(2) } + .getOrElse(dataInStream.readInt()) + + } + + private def setParams( + instance: Params, + metadata: Metadata, + isDefault: Boolean): Unit = { + val paramsToSet = if (isDefault) metadata.defaultParams else metadata.params + paramsToSet match { + case JObject(pairs) => + pairs.foreach { case (paramName, jsonValue) => + val finalName = handleBrokenlyChangedName(paramName) + // For deleted parameters, it's better to skip them than to throw an exception, + // so we check whether the parameter exists instead of blindly setting it. 
+ if (instance.hasParam(finalName)) { + val param = instance.getParam(finalName) + val value = param.jsonDecode(compact(render(jsonValue))) + instance.set(param, handleBrokenlyChangedValue(paramName, value)) + } else { + logger.warn(s"$finalName is no longer used in ${spark.VERSION}") + } + } + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") + } + } + + private val paramNameCompatibilityMap: Map[String, String] = Map("silent" -> "verbosity") + + /** This transformation is not ideal, but it is needed because some tests rely on a model + * saved with XGBoost 0.82 in which the objective is "reg:linear". */ + private val paramValueCompatibilityMap: Map[String, Map[Any, Any]] = + Map("objective" -> Map("reg:linear" -> "reg:squarederror")) + + private def handleBrokenlyChangedName(paramName: String): String = { + paramNameCompatibilityMap.getOrElse(paramName, paramName) + } + + private def handleBrokenlyChangedValue[T](paramName: String, value: T): T = { + paramValueCompatibilityMap.getOrElse(paramName, Map()).getOrElse(value, value).asInstanceOf[T] + } + +} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/linalg/xgboost/XGBoostSchemaUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/util/XGBoostSchemaUtils.scala similarity index 97% rename from jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/linalg/xgboost/XGBoostSchemaUtils.scala rename to jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/util/XGBoostSchemaUtils.scala index 0976067ec..8765d39f3 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/linalg/xgboost/XGBoostSchemaUtils.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/util/XGBoostSchemaUtils.scala @@ -14,7 +14,7 @@ limitations under the License. */ -package org.apache.spark.ml.linalg.xgboost +package org.apache.spark.ml.util import org.apache.spark.sql.types.{BooleanType, DataType, NumericType, StructType} import org.apache.spark.ml.linalg.VectorUDT diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala index 67b2ff0c8..61766b755 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -18,7 +18,8 @@ package ml.dmlc.xgboost4j.scala.spark import org.apache.spark.ml.linalg.Vectors import org.scalatest.FunSuite -import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams +import ml.dmlc.xgboost4j.scala.spark.util.DataUtils +import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams import org.apache.spark.sql.functions._ diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala index f96140555..e0151dde3 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala @@ -50,7 +50,7 @@ class FeatureSizeValidatingSuite extends FunSuite with PerTest { val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0) - import DataUtils._ + import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._ val sparkSession = ss import sparkSession.implicits._ val repartitioned = sc.parallelize(Synthetic.trainWithDiffFeatureSize, 2) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala index 512dbdb71..e96618c51 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala @@ -73,7 +73,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite => protected def buildDataFrame( labeledPoints: Seq[XGBLabeledPoint], numPartitions: Int = numWorkers): DataFrame = { - import DataUtils._ + import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._ val it = labeledPoints.iterator.zipWithIndex .map { case (labeledPoint: XGBLabeledPoint, id: Int) => (id, labeledPoint.label, labeledPoint.features) @@ -98,7 +98,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite => protected def buildDataFrameWithGroup( labeledPoints: Seq[XGBLabeledPoint], numPartitions: Int = numWorkers): DataFrame = { - import DataUtils._ + import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._ val it = labeledPoints.iterator.zipWithIndex .map { case (labeledPoint: XGBLabeledPoint, id: Int) => (id, labeledPoint.label, labeledPoint.features, labeledPoint.group) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala index e8c29fd52..00cc4d575 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala @@ -310,7 +310,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuit val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> 2, "missing" -> 0) - import DataUtils._ + import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._ val sparkSession = SparkSession.builder().getOrCreate() import sparkSession.implicits._ val repartitioned = sc.parallelize(Synthetic.train, 3).map(lp 
=> (lp.label, lp)).partitionBy( @@ -331,7 +331,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuit val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0) - import DataUtils._ + import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._ val sparkSession = SparkSession.builder().getOrCreate() import sparkSession.implicits._ val repartitioned = sc.parallelize(Synthetic.train, 3).map(lp => (lp.label, lp)).partitionBy(