[XGBoost4J-Spark] Serialization for custom objective and eval (#7274)
* added type hints to custom_obj and custom_eval for Spark persistence

Co-authored-by: Bobby Wang <wbo4958@gmail.com>
parent 7593fa9982
commit 31a307cf6b
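For context, a minimal sketch (not part of the commit) of the user-facing flow this change fixes: an `XGBoostClassifier` configured with the `custom_obj` / `custom_eval` objects exercised by the new test below can now be saved and reloaded, because those objects are JSON-encoded with full type hints instead of `NoTypeHints`. The parameter values, `trainingDF`, and `modelDir` are placeholders, and the snippet assumes it is compiled alongside the test helpers `CustomObj` and `EvalError`.

```scala
import org.apache.spark.sql.DataFrame
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier}

object CustomObjEvalPersistenceSketch {
  // Round-trip an estimator and a fitted model that carry user-defined
  // objective/eval objects through MLlib persistence.
  def roundTrip(trainingDF: DataFrame, modelDir: String): XGBoostClassificationModel = {
    val xgbc = new XGBoostClassifier(Map(
      "num_round" -> "10",
      "num_workers" -> 1,
      "custom_obj" -> new CustomObj(1),  // user-defined objective (test helper added here)
      "custom_eval" -> new EvalError))   // user-defined eval metric (existing test helper)

    // Estimator persistence: custom_obj/custom_eval are serialized with a
    // "jsonClass" type hint, so load() can re-instantiate the concrete classes.
    xgbc.write.overwrite().save(s"$modelDir/xgbc")
    val restored = XGBoostClassifier.load(s"$modelDir/xgbc")

    // Model persistence follows the same path.
    val model = restored.fit(trainingDF)
    model.write.overwrite().save(s"$modelDir/xgbcModel")
    XGBoostClassificationModel.load(s"$modelDir/xgbcModel")
  }
}
```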
@@ -18,50 +18,47 @@ package ml.dmlc.xgboost4j.scala.spark.params

import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
import org.apache.spark.ml.param.{Param, ParamPair, Params}
import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
import org.json4s.jackson.JsonMethods.{compact, parse, render}
import org.json4s.jackson.Serialization

import org.apache.spark.ml.param.{Param, ParamPair, Params}
/**
 * General Spark parameter that includes TypeHints for (de)serialization using json4s.
 */
class CustomGeneralParam[T: Manifest](
    parent: Params,
    name: String,
    doc: String) extends Param[T](parent, name, doc) {

  /** Creates a param pair with the given value (for Java). */
  override def w(value: T): ParamPair[T] = super.w(value)

  override def jsonEncode(value: T): String = {
    implicit val format = Serialization.formats(Utils.getTypeHintsFromClass(value))
    compact(render(Extraction.decompose(value)))
  }

  override def jsonDecode(json: String): T = {
    jsonDecodeT(json)
  }

  private def jsonDecodeT[T](jsonString: String)(implicit m: Manifest[T]): T = {
    val json = parse(jsonString)
    implicit val formats = DefaultFormats.withHints(Utils.getTypeHintsFromJsonClass(json))
    json.extract[T]
  }
}

class CustomEvalParam(
    parent: Params,
    name: String,
    doc: String) extends Param[EvalTrait](parent, name, doc) {

  /** Creates a param pair with the given value (for Java). */
  override def w(value: EvalTrait): ParamPair[EvalTrait] = super.w(value)

  override def jsonEncode(value: EvalTrait): String = {
    import org.json4s.jackson.Serialization
    implicit val formats = Serialization.formats(NoTypeHints)
    compact(render(Extraction.decompose(value)))
  }

  override def jsonDecode(json: String): EvalTrait = {
    implicit val formats = DefaultFormats
    parse(json).extract[EvalTrait]
  }
}
    doc: String) extends CustomGeneralParam[EvalTrait](parent, name, doc)

class CustomObjParam(
    parent: Params,
    name: String,
    doc: String) extends Param[ObjectiveTrait](parent, name, doc) {

  /** Creates a param pair with the given value (for Java). */
  override def w(value: ObjectiveTrait): ParamPair[ObjectiveTrait] = super.w(value)

  override def jsonEncode(value: ObjectiveTrait): String = {
    import org.json4s.jackson.Serialization
    implicit val formats = Serialization.formats(NoTypeHints)
    compact(render(Extraction.decompose(value)))
  }

  override def jsonDecode(json: String): ObjectiveTrait = {
    implicit val formats = DefaultFormats
    parse(json).extract[ObjectiveTrait]
  }
}
    doc: String) extends CustomGeneralParam[ObjectiveTrait](parent, name, doc)

class TrackerConfParam(
    parent: Params,

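In the hunk above, the full `jsonEncode`/`jsonDecode` bodies shown for `CustomEvalParam` and `CustomObjParam` are the old code being removed; each class is reduced to the single trailing `doc: String) extends CustomGeneralParam[...]` line. With the generic base class, a persistable param for any other user-supplied type becomes a one-line subclass. A minimal sketch, assuming it is compiled alongside `CustomGeneralParam`; `MyTrait` and `MyTraitParam` are hypothetical names, not part of this change:

```scala
import org.apache.spark.ml.param.Params

// Hypothetical user-supplied type, standing in for EvalTrait/ObjectiveTrait.
trait MyTrait extends Serializable

// CustomGeneralParam supplies jsonEncode/jsonDecode (via the Utils type-hint
// helpers below), so the subclass only fixes the type parameter.
class MyTraitParam(
    parent: Params,
    name: String,
    doc: String) extends CustomGeneralParam[MyTrait](parent, name, doc)
```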
@@ -1,5 +1,5 @@
/*
 Copyright (c) 2014 by Contributors
 Copyright (c) 2014,2021 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@

package ml.dmlc.xgboost4j.scala.spark.params

import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints}

// based on org.apache.spark.util copy/paste
private[spark] object Utils {

@@ -30,4 +32,40 @@ private[spark] object Utils {
    Class.forName(className, true, getContextOrSparkClassLoader)
    // scalastyle:on classforname
  }

  /**
   * Get the TypeHints according to the value.
   * @param value the instance of the class to be serialized
   * @return NoTypeHints if value is null,
   *         otherwise FullTypeHints for the value's class.
   *
   *         FullTypeHints saves the full class name into the "jsonClass" field of the JSON,
   *         so we can read jsonClass back and turn it into FullTypeHints when deserializing.
   */
  def getTypeHintsFromClass(value: Any): TypeHints = {
    if (value == null) { // XGBoost will save the default value (null)
      NoTypeHints
    } else {
      FullTypeHints(List(value.getClass))
    }
  }

  /**
   * Get the TypeHints according to the saved jsonClass field.
   * @param json the JSON value that may carry a "jsonClass" field
   * @return TypeHints
   */
  def getTypeHintsFromJsonClass(json: JValue): TypeHints = {
    val jsonClassField = json findField {
      case JField("jsonClass", _) => true
      case _ => false
    }

    jsonClassField.map { field =>
      implicit val formats = DefaultFormats
      val className = field._2.extract[String]
      FullTypeHints(List(Utils.classForName(className)))
    }.getOrElse(NoTypeHints)
  }
}

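A self-contained sketch (not part of the diff) of the json4s behaviour these helpers rely on: encoding with `FullTypeHints` stores the concrete class name under json4s's default hint field `"jsonClass"`, and reading that field back is enough to rebuild the hints before extraction. `Accuracy` and `JsonClassRoundTrip` are made-up names used only for illustration.

```scala
import org.json4s._
import org.json4s.jackson.JsonMethods.{compact, parse, render}
import org.json4s.jackson.Serialization

// Made-up serializable type, standing in for a custom objective/eval object.
case class Accuracy(threshold: Double)

object JsonClassRoundTrip {
  def main(args: Array[String]): Unit = {
    val value = Accuracy(0.5)

    // Encode (what CustomGeneralParam.jsonEncode does): the concrete class name
    // is written under json4s's default hint field, "jsonClass".
    val jsonString = {
      implicit val formats = Serialization.formats(FullTypeHints(List(value.getClass)))
      compact(render(Extraction.decompose(value)))
    }
    // jsonString looks like {"jsonClass":"<fully.qualified.Accuracy>","threshold":0.5}

    // Decode (what jsonDecode does): recover the class from "jsonClass",
    // rebuild the hints, then extract the concrete instance.
    val json = parse(jsonString)
    val className = {
      implicit val formats = DefaultFormats
      (json \ "jsonClass").extract[String]
    }
    val restored = {
      implicit val formats = DefaultFormats.withHints(FullTypeHints(List(Class.forName(className))))
      json.extract[Accuracy]
    }
    assert(restored == value)
  }
}
```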
@@ -0,0 +1,84 @@
/*
 Copyright (c) 2021 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.{DMatrix, ObjectiveTrait}
import org.apache.commons.logging.LogFactory
import scala.collection.mutable.ListBuffer


/**
 * Log-likelihood loss objective function.
 */
class CustomObj(val customParameter: Int = 0) extends ObjectiveTrait {

  val logger = LogFactory.getLog(classOf[CustomObj])

  /**
   * User-defined objective function; returns the gradient and second-order gradient.
   *
   * @param predicts untransformed margin predictions
   * @param dtrain   training data
   * @return a list of two float arrays: the first-order and second-order gradients
   */
  override def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix)
    : List[Array[Float]] = {
    val nrow = predicts.length
    val gradients = new ListBuffer[Array[Float]]
    var labels: Array[Float] = null
    try {
      labels = dtrain.getLabel
    } catch {
      case e: XGBoostError =>
        logger.error(e)
        throw e
      case e: Throwable => throw e
    }
    val grad = new Array[Float](nrow)
    val hess = new Array[Float](nrow)
    val transPredicts = transform(predicts)

    for (i <- 0 until nrow) {
      val predict = transPredicts(i)(0)
      grad(i) = predict - labels(i)
      hess(i) = predict * (1 - predict)
    }
    gradients += grad
    gradients += hess
    gradients.toList
  }

  /**
   * Simple sigmoid function.
   *
   * @param input raw margin value
   * @return Note: this function does not address numerical stability; it is only an example.
   */
  def sigmoid(input: Float): Float = {
    (1 / (1 + Math.exp(-input))).toFloat
  }

  def transform(predicts: Array[Array[Float]]): Array[Array[Float]] = {
    val nrow = predicts.length
    val transPredicts = Array.fill[Float](nrow, 1)(0)
    for (i <- 0 until nrow) {
      transPredicts(i)(0) = sigmoid(predicts(i)(0))
    }
    transPredicts
  }
}
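As a quick sanity check on the arithmetic in `getGradient` (not part of the diff): for one row with raw margin 0.0 and label 1.0, the sigmoid gives p = 0.5, so the first-order gradient is p - y = -0.5 and the second-order gradient is p * (1 - p) = 0.25. A tiny standalone illustration:

```scala
object CustomObjMathCheck {
  // Same transformation CustomObj.sigmoid applies to the raw margin.
  def sigmoid(x: Float): Float = (1 / (1 + math.exp(-x))).toFloat

  def main(args: Array[String]): Unit = {
    val margin = 0.0f // untransformed prediction for one row
    val label = 1.0f  // its ground-truth label

    val p = sigmoid(margin) // 0.5
    val grad = p - label    // -0.5: first-order gradient, as in getGradient
    val hess = p * (1 - p)  // 0.25: second-order gradient

    println(s"p=$p grad=$grad hess=$hess")
  }
}
```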
@@ -1,5 +1,5 @@
/*
 Copyright (c) 2014 by Contributors
 Copyright (c) 2014,2021 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -92,7 +92,6 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {
  }

  test("test persistence of MLlib pipeline with XGBoostClassificationModel") {

    val r = new Random(0)
    // maybe move to shared context, but requires session to import implicits
    val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
@@ -133,6 +132,45 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {
    assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol)
  }

  test("test persistence of XGBoostClassifier and XGBoostClassificationModel " +
    "using custom Eval and Obj") {
    val trainingDF = buildDataFrame(Classification.train)
    val testDM = new DMatrix(Classification.test.iterator)
    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
      "custom_eval" -> new EvalError, "custom_obj" -> new CustomObj(1),
      "num_round" -> "10", "num_workers" -> numWorkers)

    val xgbc = new XGBoostClassifier(paramMap)
    val xgbcPath = new File(tempDir.toFile, "xgbc").getPath
    xgbc.write.overwrite().save(xgbcPath)
    val xgbc2 = XGBoostClassifier.load(xgbcPath)
    val paramMap2 = xgbc2.MLlib2XGBoostParams
    paramMap.foreach {
      case ("custom_eval", v) => assert(v.isInstanceOf[EvalError])
      case ("custom_obj", v) =>
        assert(v.isInstanceOf[CustomObj])
        assert(v.asInstanceOf[CustomObj].customParameter ==
          paramMap2("custom_obj").asInstanceOf[CustomObj].customParameter)
      case (_, _) =>
    }

    val eval = new EvalError()

    val model = xgbc.fit(trainingDF)
    val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
    assert(evalResults < 0.1)
    val xgbcModelPath = new File(tempDir.toFile, "xgbcModel").getPath
    model.write.overwrite.save(xgbcModelPath)
    val model2 = XGBoostClassificationModel.load(xgbcModelPath)
    assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))

    assert(model.getEta === model2.getEta)
    assert(model.getNumRound === model2.getNumRound)
    assert(model.getRawPredictionCol === model2.getRawPredictionCol)
    val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
    assert(evalResults === evalResults2)
  }

  test("cross-version model loading (0.82)") {
    val modelPath = getClass.getResource("/model/0.82/model").getPath
    val model = XGBoostClassificationModel.read.load(modelPath)