[XGBoost4J-Spark] Serialization for custom objective and eval (#7274)
* added type hints to custom_obj and custom_eval for Spark persistence Co-authored-by: Bobby Wang <wbo4958@gmail.com>
This commit is contained in:
@@ -18,50 +18,47 @@ package ml.dmlc.xgboost4j.scala.spark.params
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
||||
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
|
||||
import org.apache.spark.ml.param.{Param, ParamPair, Params}
|
||||
import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
|
||||
import org.json4s.jackson.JsonMethods.{compact, parse, render}
|
||||
import org.json4s.jackson.Serialization
|
||||
|
||||
import org.apache.spark.ml.param.{Param, ParamPair, Params}
|
||||
/**
|
||||
* General spark parameter that includes TypeHints for (de)serialization using json4s.
|
||||
*/
|
||||
class CustomGeneralParam[T: Manifest](
    parent: Params,
    name: String,
    doc: String) extends Param[T](parent, name, doc) {

  /** Creates a param pair with the given value (for Java). */
  override def w(value: T): ParamPair[T] = super.w(value)

  /**
   * Serializes `value` to a compact JSON string.
   *
   * The json4s formats are built from `Utils.getTypeHintsFromClass(value)`, so the
   * type hints written to the JSON depend on the runtime value being encoded.
   *
   * @param value the param value to encode
   * @return the compact JSON rendering of `value`
   */
  override def jsonEncode(value: T): String = {
    implicit val formats: org.json4s.Formats =
      Serialization.formats(Utils.getTypeHintsFromClass(value))
    compact(render(Extraction.decompose(value)))
  }

  /**
   * Parses `json` and extracts a `T` from it.
   *
   * The json4s formats are `DefaultFormats` extended with the type hints that
   * `Utils.getTypeHintsFromJsonClass` derives from the parsed document.
   *
   * @param json the JSON string to decode
   * @return the value extracted from `json`
   */
  override def jsonDecode(json: String): T = {
    val parsed = parse(json)
    implicit val formats: org.json4s.Formats =
      DefaultFormats.withHints(Utils.getTypeHintsFromJsonClass(parsed))
    // The Manifest[T] required by extract comes from this class's context bound, so no
    // separate generic helper (whose own T would shadow the class's T) is needed.
    parsed.extract[T]
  }
}
|
||||
|
||||
/**
 * Spark param holding a custom [[EvalTrait]] implementation.
 *
 * JSON encoding and decoding are inherited from [[CustomGeneralParam]], which builds its
 * json4s formats from type hints derived from the value being saved and from the parsed
 * JSON being loaded. Encoding with `NoTypeHints` and decoding with plain `DefaultFormats`
 * records no concrete class, so the implementation behind the trait cannot be recovered.
 *
 * @param parent the owner of this param
 * @param name the param name
 * @param doc the param documentation
 */
class CustomEvalParam(
    parent: Params,
    name: String,
    doc: String) extends CustomGeneralParam[EvalTrait](parent, name, doc)
|
||||
|
||||
/**
 * Spark param holding a custom [[ObjectiveTrait]] implementation.
 *
 * JSON encoding and decoding are inherited from [[CustomGeneralParam]], which builds its
 * json4s formats from type hints derived from the value being saved and from the parsed
 * JSON being loaded. Encoding with `NoTypeHints` and decoding with plain `DefaultFormats`
 * records no concrete class, so the implementation behind the trait cannot be recovered.
 *
 * @param parent the owner of this param
 * @param name the param name
 * @param doc the param documentation
 */
class CustomObjParam(
    parent: Params,
    name: String,
    doc: String) extends CustomGeneralParam[ObjectiveTrait](parent, name, doc)
|
||||
|
||||
class TrackerConfParam(
|
||||
parent: Params,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014,2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark.params
|
||||
|
||||
import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints}
|
||||
|
||||
// Based on org.apache.spark.util (copy/paste).
|
||||
private[spark] object Utils {
|
||||
|
||||
@@ -30,4 +32,40 @@ private[spark] object Utils {
|
||||
Class.forName(className, true, getContextOrSparkClassLoader)
|
||||
// scalastyle:on classforname
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the TypeHints according to the value
|
||||
* @param value the instance of class to be serialized
|
||||
* @return if value is null,
|
||||
* return NoTypeHints
|
||||
* else return the FullTypeHints.
|
||||
*
|
||||
* The FullTypeHints will save the full class name into the "jsonClass" of the json,
|
||||
* so we can find the jsonClass and turn it to FullTypeHints when deserializing.
|
||||
*/
|
||||
def getTypeHintsFromClass(value: Any): TypeHints = value match {
  // XGBoost will save the default value (null); there is no class to hint in that case.
  case null => NoTypeHints
  // Any other value is hinted with its own runtime class.
  case instance => FullTypeHints(List(instance.getClass))
}
|
||||
|
||||
/**
|
||||
* Get the TypeHints according to the saved jsonClass field
|
||||
* @param json the parsed JSON value that is searched for a "jsonClass" field
|
||||
* @return FullTypeHints for the class named by "jsonClass", or NoTypeHints if no such field exists
|
||||
*/
|
||||
def getTypeHintsFromJsonClass(json: JValue): TypeHints = {
  // Look for the first field named "jsonClass" in the JSON value.
  val hintField = json.findField(candidate => candidate._1 == "jsonClass")

  hintField match {
    case Some((_, classNameJson)) =>
      implicit val formats = DefaultFormats
      val hintedClassName = classNameJson.extract[String]
      // NOTE(review): the class name is read from the JSON itself and handed to
      // Utils.classForName -- confirm that only trusted payloads reach this point.
      FullTypeHints(List(Utils.classForName(hintedClassName)))
    case None =>
      // No "jsonClass" field was found, so there is nothing to hint.
      NoTypeHints
  }
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user