[XGBoost4J-Spark] Serialization for custom objective and eval (#7274)

* added type hints to custom_obj and custom_eval for Spark persistence


Co-authored-by: Bobby Wang <wbo4958@gmail.com>
This commit is contained in:
nicovdijk
2021-10-21 10:22:23 +02:00
committed by GitHub
parent 7593fa9982
commit 31a307cf6b
4 changed files with 193 additions and 36 deletions

View File

@@ -18,50 +18,47 @@ package ml.dmlc.xgboost4j.scala.spark.params
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
import org.apache.spark.ml.param.{Param, ParamPair, Params}
import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
import org.json4s.jackson.JsonMethods.{compact, parse, render}
import org.json4s.jackson.Serialization
import org.apache.spark.ml.param.{Param, ParamPair, Params}
/**
* General spark parameter that includes TypeHints for (de)serialization using json4s.
*/
class CustomGeneralParam[T: Manifest](
parent: Params,
name: String,
doc: String) extends Param[T](parent, name, doc) {
/** Creates a param pair with the given value (for Java). */
override def w(value: T): ParamPair[T] = super.w(value)
override def jsonEncode(value: T): String = {
implicit val format = Serialization.formats(Utils.getTypeHintsFromClass(value))
compact(render(Extraction.decompose(value)))
}
override def jsonDecode(json: String): T = {
jsonDecodeT(json)
}
private def jsonDecodeT[T](jsonString: String)(implicit m: Manifest[T]): T = {
val json = parse(jsonString)
implicit val formats = DefaultFormats.withHints(Utils.getTypeHintsFromJsonClass(json))
json.extract[T]
}
}
class CustomEvalParam(
parent: Params,
name: String,
doc: String) extends Param[EvalTrait](parent, name, doc) {
/** Creates a param pair with the given value (for Java). */
override def w(value: EvalTrait): ParamPair[EvalTrait] = super.w(value)
override def jsonEncode(value: EvalTrait): String = {
import org.json4s.jackson.Serialization
implicit val formats = Serialization.formats(NoTypeHints)
compact(render(Extraction.decompose(value)))
}
override def jsonDecode(json: String): EvalTrait = {
implicit val formats = DefaultFormats
parse(json).extract[EvalTrait]
}
}
doc: String) extends CustomGeneralParam[EvalTrait](parent, name, doc)
class CustomObjParam(
parent: Params,
name: String,
doc: String) extends Param[ObjectiveTrait](parent, name, doc) {
/** Creates a param pair with the given value (for Java). */
override def w(value: ObjectiveTrait): ParamPair[ObjectiveTrait] = super.w(value)
override def jsonEncode(value: ObjectiveTrait): String = {
import org.json4s.jackson.Serialization
implicit val formats = Serialization.formats(NoTypeHints)
compact(render(Extraction.decompose(value)))
}
override def jsonDecode(json: String): ObjectiveTrait = {
implicit val formats = DefaultFormats
parse(json).extract[ObjectiveTrait]
}
}
doc: String) extends CustomGeneralParam[ObjectiveTrait](parent, name, doc)
class TrackerConfParam(
parent: Params,

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014,2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@
package ml.dmlc.xgboost4j.scala.spark.params
import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints}
// based on org.apache.spark.util copy /paste
private[spark] object Utils {
@@ -30,4 +32,40 @@ private[spark] object Utils {
Class.forName(className, true, getContextOrSparkClassLoader)
// scalastyle:on classforname
}
/**
* Get the TypeHints according to the value
* @param value the instance of class to be serialized
* @return if value is null,
* return NoTypeHints
* else return the FullTypeHints.
*
* The FullTypeHints will save the full class name into the "jsonClass" of the json,
* so we can find the jsonClass and turn it to FullTypeHints when deserializing.
*/
def getTypeHintsFromClass(value: Any): TypeHints = {
if (value == null) { // XGBoost will save the default value (null)
NoTypeHints
} else {
FullTypeHints(List(value.getClass))
}
}
/**
* Get the TypeHints according to the saved jsonClass field
* @param json
* @return TypeHints
*/
def getTypeHintsFromJsonClass(json: JValue): TypeHints = {
val jsonClassField = json findField {
case JField("jsonClass", _) => true
case _ => false
}
jsonClassField.map { field =>
implicit val formats = DefaultFormats
val className = field._2.extract[String]
FullTypeHints(List(Utils.classForName(className)))
}.getOrElse(NoTypeHints)
}
}