[XGBoost4J-Spark] Serialization for custom objective and eval (#7274)
* Added type hints to custom_obj and custom_eval for Spark persistence.
Co-authored-by: Bobby Wang <wbo4958@gmail.com>
This commit is contained in:
parent
7593fa9982
commit
31a307cf6b
@ -18,50 +18,47 @@ package ml.dmlc.xgboost4j.scala.spark.params
|
|||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
||||||
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
|
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
|
||||||
|
import org.apache.spark.ml.param.{Param, ParamPair, Params}
|
||||||
import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
|
import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
|
||||||
import org.json4s.jackson.JsonMethods.{compact, parse, render}
|
import org.json4s.jackson.JsonMethods.{compact, parse, render}
|
||||||
|
import org.json4s.jackson.Serialization
|
||||||
|
|
||||||
import org.apache.spark.ml.param.{Param, ParamPair, Params}
|
/**
 * General Spark parameter that includes TypeHints for (de)serialization using json4s.
 *
 * When encoding, the type hints are derived from the runtime class of the value
 * (see `Utils.getTypeHintsFromClass`); when decoding, they are recovered from the
 * parsed JSON itself (see `Utils.getTypeHintsFromJsonClass`), so the JSON produced
 * by [[jsonEncode]] is what [[jsonDecode]] expects to read back.
 *
 * @param parent the `Params` instance that owns this parameter
 * @param name   the parameter name
 * @param doc    the parameter documentation string
 * @tparam T     the value type; a `Manifest` is required so json4s can extract it
 */
class CustomGeneralParam[T: Manifest](
    parent: Params,
    name: String,
    doc: String) extends Param[T](parent, name, doc) {

  /** Creates a param pair with the given value (for Java). */
  override def w(value: T): ParamPair[T] = super.w(value)

  /**
   * Serializes `value` to a compact JSON string.
   *
   * @param value the value to encode; the type hints are chosen by
   *              `Utils.getTypeHintsFromClass(value)`
   * @return the compact JSON rendering of `value`
   */
  override def jsonEncode(value: T): String = {
    implicit val formats: org.json4s.Formats =
      Serialization.formats(Utils.getTypeHintsFromClass(value))
    compact(render(Extraction.decompose(value)))
  }

  /**
   * Parses `json` and extracts a `T` from it.
   *
   * The type hints are looked up in the parsed document before extraction, and the
   * class-level `Manifest[T]` drives the extraction, so no separate generic helper
   * (with its own, shadowing type parameter) is needed.
   *
   * @param json the JSON string to decode
   * @return the decoded value
   */
  override def jsonDecode(json: String): T = {
    val parsed = parse(json)
    implicit val formats: org.json4s.Formats =
      DefaultFormats.withHints(Utils.getTypeHintsFromJsonClass(parsed))
    parsed.extract[T]
  }
}
|
||||||
|
|
||||||
class CustomEvalParam(
|
class CustomEvalParam(
|
||||||
parent: Params,
|
parent: Params,
|
||||||
name: String,
|
name: String,
|
||||||
doc: String) extends Param[EvalTrait](parent, name, doc) {
|
doc: String) extends CustomGeneralParam[EvalTrait](parent, name, doc)
|
||||||
|
|
||||||
/** Creates a param pair with the given value (for Java). */
|
|
||||||
override def w(value: EvalTrait): ParamPair[EvalTrait] = super.w(value)
|
|
||||||
|
|
||||||
override def jsonEncode(value: EvalTrait): String = {
|
|
||||||
import org.json4s.jackson.Serialization
|
|
||||||
implicit val formats = Serialization.formats(NoTypeHints)
|
|
||||||
compact(render(Extraction.decompose(value)))
|
|
||||||
}
|
|
||||||
|
|
||||||
override def jsonDecode(json: String): EvalTrait = {
|
|
||||||
implicit val formats = DefaultFormats
|
|
||||||
parse(json).extract[EvalTrait]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class CustomObjParam(
|
class CustomObjParam(
|
||||||
parent: Params,
|
parent: Params,
|
||||||
name: String,
|
name: String,
|
||||||
doc: String) extends Param[ObjectiveTrait](parent, name, doc) {
|
doc: String) extends CustomGeneralParam[ObjectiveTrait](parent, name, doc)
|
||||||
|
|
||||||
/** Creates a param pair with the given value (for Java). */
|
|
||||||
override def w(value: ObjectiveTrait): ParamPair[ObjectiveTrait] = super.w(value)
|
|
||||||
|
|
||||||
override def jsonEncode(value: ObjectiveTrait): String = {
|
|
||||||
import org.json4s.jackson.Serialization
|
|
||||||
implicit val formats = Serialization.formats(NoTypeHints)
|
|
||||||
compact(render(Extraction.decompose(value)))
|
|
||||||
}
|
|
||||||
|
|
||||||
override def jsonDecode(json: String): ObjectiveTrait = {
|
|
||||||
implicit val formats = DefaultFormats
|
|
||||||
parse(json).extract[ObjectiveTrait]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class TrackerConfParam(
|
class TrackerConfParam(
|
||||||
parent: Params,
|
parent: Params,
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014,2021 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -16,6 +16,8 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark.params
|
package ml.dmlc.xgboost4j.scala.spark.params
|
||||||
|
|
||||||
|
import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints}
|
||||||
|
|
||||||
// based on org.apache.spark.util copy /paste
|
// based on org.apache.spark.util copy /paste
|
||||||
private[spark] object Utils {
|
private[spark] object Utils {
|
||||||
|
|
||||||
@ -30,4 +32,40 @@ private[spark] object Utils {
|
|||||||
Class.forName(className, true, getContextOrSparkClassLoader)
|
Class.forName(className, true, getContextOrSparkClassLoader)
|
||||||
// scalastyle:on classforname
|
// scalastyle:on classforname
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Chooses the json4s TypeHints to use when serializing `value`.
 *
 * A non-null value yields FullTypeHints for its runtime class. FullTypeHints saves
 * the full class name into the "jsonClass" field of the JSON, so the class can be
 * found again and turned back into FullTypeHints when deserializing.
 *
 * @param value the instance to be serialized; may be null, because XGBoost saves
 *              the default value (null) as well
 * @return NoTypeHints when `value` is null, otherwise FullTypeHints for its class
 */
def getTypeHintsFromClass(value: Any): TypeHints =
  Option(value).fold[TypeHints](NoTypeHints) { instance =>
    FullTypeHints(List(instance.getClass))
  }
|
||||||
|
|
||||||
|
/**
 * Derives the json4s TypeHints from the "jsonClass" field saved in the JSON.
 *
 * @param json the parsed JSON document to search for a "jsonClass" field
 * @return FullTypeHints for the class named by the first "jsonClass" field found
 *         (loaded through `Utils.classForName`), or NoTypeHints when the document
 *         has no such field
 */
def getTypeHintsFromJsonClass(json: JValue): TypeHints = {
  implicit val formats = DefaultFormats

  val located = json.findField {
    case JField("jsonClass", _) => true
    case _ => false
  }

  located match {
    case Some((_, classNameJson)) =>
      val className = classNameJson.extract[String]
      FullTypeHints(List(Utils.classForName(className)))
    case None =>
      NoTypeHints
  }
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -0,0 +1,84 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2021 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||||
|
import ml.dmlc.xgboost4j.scala.{DMatrix, ObjectiveTrait}
|
||||||
|
import org.apache.commons.logging.LogFactory
|
||||||
|
import scala.collection.mutable.ListBuffer
|
||||||
|
|
||||||
|
|
||||||
|
/**
 * Log-likelihood loss objective function.
 *
 * @param customParameter an integer carried by the objective instance; it is not
 *                        read by the gradient computation in this class
 */
class CustomObj(val customParameter: Int = 0) extends ObjectiveTrait {

  val logger = LogFactory.getLog(classOf[CustomObj])

  /**
   * User-defined objective function; returns the gradient and second-order gradient.
   *
   * @param predicts untransformed margin predictions, one row per instance
   * @param dtrain   training data; its labels are read via `getLabel`
   * @return a two-element list: the first-order gradient followed by the
   *         second-order gradient, each with one entry per row of `predicts`
   * @throws XGBoostError if the labels cannot be read (logged before being rethrown)
   */
  override def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix)
      : List[Array[Float]] = {
    val nrow = predicts.length

    // Only XGBoost errors are logged here; any other throwable propagates unchanged.
    val labels: Array[Float] =
      try {
        dtrain.getLabel
      } catch {
        case e: XGBoostError =>
          logger.error(e)
          throw e
      }

    val transPredicts = transform(predicts)

    val grad = Array.tabulate(nrow)(i => transPredicts(i)(0) - labels(i))
    val hess = Array.tabulate(nrow) { i =>
      val predict = transPredicts(i)(0)
      predict * (1 - predict)
    }

    List(grad, hess)
  }

  /**
   * Simple sigmoid function.
   *
   * Note: this function does not take numerical stability into account; it is
   * only meant as an example.
   *
   * @param input the raw margin value
   * @return `1 / (1 + exp(-input))` as a Float
   */
  def sigmoid(input: Float): Float = {
    (1 / (1 + Math.exp(-input))).toFloat
  }

  /**
   * Applies [[sigmoid]] to the first column of every prediction row.
   *
   * @param predicts margin predictions; only element 0 of each row is read
   * @return one single-element row per input row holding the transformed value
   */
  def transform(predicts: Array[Array[Float]]): Array[Array[Float]] =
    predicts.map(row => Array(sigmoid(row(0))))
}
|
||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014,2021 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -92,7 +92,6 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test("test persistence of MLlib pipeline with XGBoostClassificationModel") {
|
test("test persistence of MLlib pipeline with XGBoostClassificationModel") {
|
||||||
|
|
||||||
val r = new Random(0)
|
val r = new Random(0)
|
||||||
// maybe move to shared context, but requires session to import implicits
|
// maybe move to shared context, but requires session to import implicits
|
||||||
val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
|
val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
|
||||||
@ -133,6 +132,45 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {
|
|||||||
assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol)
|
assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("test persistence of XGBoostClassifier and XGBoostClassificationModel " +
  "using custom Eval and Obj") {
  val trainingDF = buildDataFrame(Classification.train)
  val testDM = new DMatrix(Classification.test.iterator)
  val customObj = new CustomObj(1)
  val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
    "custom_eval" -> new EvalError, "custom_obj" -> customObj,
    "num_round" -> "10", "num_workers" -> numWorkers)

  // Round-trip the estimator through disk.
  val xgbc = new XGBoostClassifier(paramMap)
  val xgbcPath = new File(tempDir.toFile, "xgbc").getPath
  xgbc.write.overwrite().save(xgbcPath)
  val xgbc2 = XGBoostClassifier.load(xgbcPath)
  val paramMap2 = xgbc2.MLlib2XGBoostParams

  // Inspect the *reloaded* params: asserting on the original map's values would
  // pass even if the custom eval/obj were lost or changed type during persistence.
  assert(paramMap2("custom_eval").isInstanceOf[EvalError])
  paramMap2("custom_obj") match {
    case loaded: CustomObj =>
      assert(loaded.customParameter === customObj.customParameter)
    case other =>
      fail(s"expected a CustomObj after reload but found: $other")
  }

  val eval = new EvalError()

  // Train with the custom objective and check the model is reasonable.
  val model = xgbc.fit(trainingDF)
  val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
  assert(evalResults < 0.1)

  // Round-trip the fitted model through disk and compare it with the original.
  val xgbcModelPath = new File(tempDir.toFile, "xgbcModel").getPath
  model.write.overwrite().save(xgbcModelPath)
  val model2 = XGBoostClassificationModel.load(xgbcModelPath)
  assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))

  assert(model.getEta === model2.getEta)
  assert(model.getNumRound === model2.getNumRound)
  assert(model.getRawPredictionCol === model2.getRawPredictionCol)
  val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
  assert(evalResults === evalResults2)
}
|
||||||
|
|
||||||
test("cross-version model loading (0.82)") {
|
test("cross-version model loading (0.82)") {
|
||||||
val modelPath = getClass.getResource("/model/0.82/model").getPath
|
val modelPath = getClass.getResource("/model/0.82/model").getPath
|
||||||
val model = XGBoostClassificationModel.read.load(modelPath)
|
val model = XGBoostClassificationModel.read.load(modelPath)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user