From 31a307cf6bca0f0145be241bf12e2fe558ac42d1 Mon Sep 17 00:00:00 2001 From: nicovdijk <43035522+nicovdijk@users.noreply.github.com> Date: Thu, 21 Oct 2021 10:22:23 +0200 Subject: [PATCH] [XGBoost4J-Spark] Serialization for custom objective and eval (#7274) * added type hints to custom_obj and custom_eval for Spark persistence Co-authored-by: Bobby Wang --- .../scala/spark/params/CustomParams.scala | 63 +++++++------- .../xgboost4j/scala/spark/params/Utils.scala | 40 ++++++++- .../xgboost4j/scala/spark/CustomObj.scala | 84 +++++++++++++++++++ .../scala/spark/PersistenceSuite.scala | 42 +++++++++- 4 files changed, 193 insertions(+), 36 deletions(-) create mode 100644 jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala index 784be2aa0..c74560218 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala @@ -18,50 +18,47 @@ package ml.dmlc.xgboost4j.scala.spark.params import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} import ml.dmlc.xgboost4j.scala.spark.TrackerConf +import org.apache.spark.ml.param.{Param, ParamPair, Params} import org.json4s.{DefaultFormats, Extraction, NoTypeHints} import org.json4s.jackson.JsonMethods.{compact, parse, render} +import org.json4s.jackson.Serialization -import org.apache.spark.ml.param.{Param, ParamPair, Params} +/** + * General spark parameter that includes TypeHints for (de)serialization using json4s. + */ +class CustomGeneralParam[T: Manifest]( + parent: Params, + name: String, + doc: String) extends Param[T](parent, name, doc) { + + /** Creates a param pair with the given value (for Java). */ + override def w(value: T): ParamPair[T] = super.w(value) + + override def jsonEncode(value: T): String = { + implicit val format = Serialization.formats(Utils.getTypeHintsFromClass(value)) + compact(render(Extraction.decompose(value))) + } + + override def jsonDecode(json: String): T = { + jsonDecodeT(json) + } + + private def jsonDecodeT[T](jsonString: String)(implicit m: Manifest[T]): T = { + val json = parse(jsonString) + implicit val formats = DefaultFormats.withHints(Utils.getTypeHintsFromJsonClass(json)) + json.extract[T] + } +} class CustomEvalParam( parent: Params, name: String, - doc: String) extends Param[EvalTrait](parent, name, doc) { - - /** Creates a param pair with the given value (for Java). */ - override def w(value: EvalTrait): ParamPair[EvalTrait] = super.w(value) - - override def jsonEncode(value: EvalTrait): String = { - import org.json4s.jackson.Serialization - implicit val formats = Serialization.formats(NoTypeHints) - compact(render(Extraction.decompose(value))) - } - - override def jsonDecode(json: String): EvalTrait = { - implicit val formats = DefaultFormats - parse(json).extract[EvalTrait] - } -} + doc: String) extends CustomGeneralParam[EvalTrait](parent, name, doc) class CustomObjParam( parent: Params, name: String, - doc: String) extends Param[ObjectiveTrait](parent, name, doc) { - - /** Creates a param pair with the given value (for Java). */ - override def w(value: ObjectiveTrait): ParamPair[ObjectiveTrait] = super.w(value) - - override def jsonEncode(value: ObjectiveTrait): String = { - import org.json4s.jackson.Serialization - implicit val formats = Serialization.formats(NoTypeHints) - compact(render(Extraction.decompose(value))) - } - - override def jsonDecode(json: String): ObjectiveTrait = { - implicit val formats = DefaultFormats - parse(json).extract[ObjectiveTrait] - } -} + doc: String) extends CustomGeneralParam[ObjectiveTrait](parent, name, doc) class TrackerConfParam( parent: Params, diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala index 7d6e7b9ed..fb84ad6d6 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014,2021 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ package ml.dmlc.xgboost4j.scala.spark.params +import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints} + // based on org.apache.spark.util copy /paste private[spark] object Utils { @@ -30,4 +32,40 @@ private[spark] object Utils { Class.forName(className, true, getContextOrSparkClassLoader) // scalastyle:on classforname } + + /** + * Get the TypeHints according to the value + * @param value the instance of class to be serialized + * @return if value is null, + * return NoTypeHints + * else return the FullTypeHints. + * + * The FullTypeHints will save the full class name into the "jsonClass" of the json, + * so we can find the jsonClass and turn it to FullTypeHints when deserializing. + */ + def getTypeHintsFromClass(value: Any): TypeHints = { + if (value == null) { // XGBoost will save the default value (null) + NoTypeHints + } else { + FullTypeHints(List(value.getClass)) + } + } + + /** + * Get the TypeHints according to the saved jsonClass field + * @param json + * @return TypeHints + */ + def getTypeHintsFromJsonClass(json: JValue): TypeHints = { + val jsonClassField = json findField { + case JField("jsonClass", _) => true + case _ => false + } + + jsonClassField.map { field => + implicit val formats = DefaultFormats + val className = field._2.extract[String] + FullTypeHints(List(Utils.classForName(className))) + }.getOrElse(NoTypeHints) + } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala new file mode 100644 index 000000000..b9a39a14d --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala @@ -0,0 +1,84 @@ +/* + Copyright (c) 2021 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.spark + +import ml.dmlc.xgboost4j.java.XGBoostError +import ml.dmlc.xgboost4j.scala.{DMatrix, ObjectiveTrait} +import org.apache.commons.logging.LogFactory +import scala.collection.mutable.ListBuffer + + +/** + * loglikelihood loss obj function + */ +class CustomObj(val customParameter: Int = 0) extends ObjectiveTrait { + + val logger = LogFactory.getLog(classOf[CustomObj]) + + /** + * user define objective function, return gradient and second order gradient + * + * @param predicts untransformed margin predicts + * @param dtrain training data + * @return List with two float array, correspond to first order grad and second order grad + */ + override def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix) + : List[Array[Float]] = { + val nrow = predicts.length + val gradients = new ListBuffer[Array[Float]] + var labels: Array[Float] = null + try { + labels = dtrain.getLabel + } catch { + case e: XGBoostError => + logger.error(e) + throw e + case e: Throwable => throw e + } + val grad = new Array[Float](nrow) + val hess = new Array[Float](nrow) + val transPredicts = transform(predicts) + + for (i <- 0 until nrow) { + val predict = transPredicts(i)(0) + grad(i) = predict - labels(i) + hess(i) = predict * (1 - predict) + } + gradients += grad + gradients += hess + gradients.toList + } + + /** + * simple sigmoid func + * + * @param input + * @return Note: this func is not concern about numerical stability, only used as example + */ + def sigmoid(input: Float): Float = { + (1 / (1 + Math.exp(-input))).toFloat + } + + def transform(predicts: Array[Array[Float]]): Array[Array[Float]] = { + val nrow = predicts.length + val transPredicts = Array.fill[Float](nrow, 1)(0) + for (i <- 0 until nrow) { + transPredicts(i)(0) = sigmoid(predicts(i)(0)) + } + transPredicts + } +} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala index ebe1d8546..a1732c7f7 100755 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014,2021 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -92,7 +92,6 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest { } test("test persistence of MLlib pipeline with XGBoostClassificationModel") { - val r = new Random(0) // maybe move to shared context, but requires session to import implicits val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))). @@ -133,6 +132,45 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest { assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol) } + test("test persistence of XGBoostClassifier and XGBoostClassificationModel " + + "using custom Eval and Obj") { + val trainingDF = buildDataFrame(Classification.train) + val testDM = new DMatrix(Classification.test.iterator) + val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", + "custom_eval" -> new EvalError, "custom_obj" -> new CustomObj(1), + "num_round" -> "10", "num_workers" -> numWorkers) + + val xgbc = new XGBoostClassifier(paramMap) + val xgbcPath = new File(tempDir.toFile, "xgbc").getPath + xgbc.write.overwrite().save(xgbcPath) + val xgbc2 = XGBoostClassifier.load(xgbcPath) + val paramMap2 = xgbc2.MLlib2XGBoostParams + paramMap.foreach { + case ("custom_eval", v) => assert(v.isInstanceOf[EvalError]) + case ("custom_obj", v) => + assert(v.isInstanceOf[CustomObj]) + assert(v.asInstanceOf[CustomObj].customParameter == + paramMap2("custom_obj").asInstanceOf[CustomObj].customParameter) + case (_, _) => + } + + val eval = new EvalError() + + val model = xgbc.fit(trainingDF) + val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) + assert(evalResults < 0.1) + val xgbcModelPath = new File(tempDir.toFile, "xgbcModel").getPath + model.write.overwrite.save(xgbcModelPath) + val model2 = XGBoostClassificationModel.load(xgbcModelPath) + assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray)) + + assert(model.getEta === model2.getEta) + assert(model.getNumRound === model2.getNumRound) + assert(model.getRawPredictionCol === model2.getRawPredictionCol) + val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM) + assert(evalResults === evalResults2) + } + test("cross-version model loading (0.82)") { val modelPath = getClass.getResource("/model/0.82/model").getPath val model = XGBoostClassificationModel.read.load(modelPath)