[jvm-packages] fix the persistence of XGBoostEstimator (#2265)

* add back train method but mark as deprecated

* fix scalastyle error

* fix the persistence of XGBoostEstimator

* test persistence of a complete pipeline

* fix compilation issue

* do not allow persisting custom_eval and custom_obj

* fix the failed test
Nan Zhu 2017-05-08 21:58:06 -07:00 committed by GitHub
parent 6bf968efe6
commit 428453f7d6
12 changed files with 362 additions and 66 deletions
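
Note: the net effect of this change set is that XGBoostEstimator implements MLWritable/MLReadable, so it can be saved and reloaded on its own or as a stage of a Spark ML Pipeline. A minimal usage sketch, assuming an active SparkSession; the path and parameter values are illustrative, not taken from the diff:

import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator

val estimator = new XGBoostEstimator(Map(
  "eta" -> "0.1", "max_depth" -> "6", "objective" -> "binary:logistic"))

// persistence support added in this commit
estimator.write.overwrite().save("/tmp/xgbEstimator")
val restored = XGBoostEstimator.load("/tmp/xgbEstimator")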

.gitignore

@@ -89,3 +89,6 @@ build_tests
.DS_Store
lib/
# spark
metastore_db


@@ -18,21 +18,22 @@ package ml.dmlc.xgboost4j.scala.spark
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{FSDataInputStream, Path}
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.SparseVector
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.{SparkContext, TaskContext}
import scala.concurrent.duration.{Duration, MILLISECONDS}
import scala.concurrent.duration.{Duration, FiniteDuration, MILLISECONDS}
object TrackerConf {
def apply(): TrackerConf = TrackerConf(Duration.apply(0L, MILLISECONDS), "python")
def apply(): TrackerConf = TrackerConf(0L, "python")
}
/**
@@ -40,13 +41,14 @@ object TrackerConf {
* @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
* Set timeout length to zero to disable timeout.
* Use a finite, non-zero timeout value to prevent tracker from
* hanging indefinitely (supported by "scala" implementation only.)
* hanging indefinitely (in milliseconds)
* (supported by "scala" implementation only.)
* @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
* the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
* in Scala without Python components, and with full support of timeouts.
* The Scala implementation is currently experimental, use at your own risk.
*/
case class TrackerConf(workerConnectionTimeout: Duration, trackerImpl: String)
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)
object XGBoost extends Serializable {
private val logger = LogFactory.getLog("XGBoostSpark")
@@ -240,14 +242,7 @@ object XGBoost extends Serializable {
case _ => new PyRabitTracker(nWorkers)
}
val connectionTimeout = if (trackerConf.workerConnectionTimeout.isFinite()) {
trackerConf.workerConnectionTimeout.toMillis
} else {
// 0 == Duration.Inf
0L
}
require(tracker.start(connectionTimeout), "FAULT: Failed to start tracker")
require(tracker.start(trackerConf.workerConnectionTimeout), "FAULT: Failed to start tracker")
tracker
}
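
Note: the hunks above change TrackerConf.workerConnectionTimeout from a scala.concurrent.duration.Duration to a plain Long in milliseconds, presumably so the case class stays trivially JSON-serializable for the param persistence added below. A before/after construction sketch (values illustrative):

import ml.dmlc.xgboost4j.scala.spark.TrackerConf

// before this commit: TrackerConf(1 minute, "scala")
// after this commit:
val oneMinute = TrackerConf(60 * 1000L, "scala")
val default = TrackerConf()   // TrackerConf(0L, "python"); 0 means wait indefinitely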


@@ -18,12 +18,14 @@ package ml.dmlc.xgboost4j.scala.spark
import scala.collection.mutable
import ml.dmlc.xgboost4j.scala.spark.params.{BoosterParams, GeneralParams, LearningTaskParams}
import ml.dmlc.xgboost4j.scala.spark.params._
import org.json4s.DefaultFormats
import org.apache.spark.ml.Predictor
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector => MLVector}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.{Dataset, Row}
@@ -34,7 +36,7 @@ import org.apache.spark.sql.{Dataset, Row}
class XGBoostEstimator private[spark](
override val uid: String, xgboostParams: Map[String, Any])
extends Predictor[MLVector, XGBoostEstimator, XGBoostModel]
with LearningTaskParams with GeneralParams with BoosterParams {
with LearningTaskParams with GeneralParams with BoosterParams with MLWritable {
def this(xgboostParams: Map[String, Any]) =
this(Identifiable.randomUID("XGBoostEstimator"), xgboostParams: Map[String, Any])
@@ -129,4 +131,38 @@ class XGBoostEstimator private[spark](
override def copy(extra: ParamMap): XGBoostEstimator = {
defaultCopy(extra).asInstanceOf[XGBoostEstimator]
}
override def write: MLWriter = new XGBoostEstimator.XGBoostEstimatorWriter(this)
}
object XGBoostEstimator extends MLReadable[XGBoostEstimator] {
override def read: MLReader[XGBoostEstimator] = new XGBoostEstimatorReader
override def load(path: String): XGBoostEstimator = super.load(path)
private[XGBoostEstimator] class XGBoostEstimatorWriter(instance: XGBoostEstimator)
extends MLWriter {
override protected def saveImpl(path: String): Unit = {
require(instance.fromParamsToXGBParamMap("custom_eval") == null &&
instance.fromParamsToXGBParamMap("custom_obj") == null,
"we do not support persist XGBoostEstimator with customized evaluator and objective" +
" function for now")
implicit val format = DefaultFormats
implicit val sc = super.sparkSession.sparkContext
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
}
}
private class XGBoostEstimatorReader extends MLReader[XGBoostEstimator] {
override def load(path: String): XGBoostEstimator = {
val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc)
val cls = Utils.classForName(metadata.className)
val instance =
cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
DefaultXGBoostParamsReader.getAndSetParams(instance, metadata)
instance.asInstanceOf[XGBoostEstimator]
}
}
}
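
Note: saveImpl above deliberately rejects estimators that carry a customized objective or evaluation function, presumably because arbitrary user objects cannot be round-tripped through the JSON metadata. A sketch of the behaviour (path illustrative, active SparkSession assumed):

import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator

val est = new XGBoostEstimator(Map("eta" -> "0.1"))
est.write.overwrite().save("/tmp/xgbEst")   // fine: custom_obj / custom_eval are unset
// After est.set(est.customObj, someObjectiveTrait), the same call fails the
// require(...) above instead of writing an estimator that could not be restored.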


@@ -324,14 +324,13 @@ object XGBoostModel extends MLReadable[XGBoostModel] {
implicit val format = DefaultFormats
implicit val sc = super.sparkSession.sparkContext
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
val dataPath = new Path(path, "data").toString
instance.saveModelAsHadoopFile(dataPath)
}
}
private class XGBoostModelModelReader extends MLReader[XGBoostModel] {
private val className = classOf[XGBoostModel].getName
override def load(path: String): XGBoostModel = {
implicit val sc = super.sparkSession.sparkContext
val dataPath = new Path(path, "data").toString
@@ -340,5 +339,4 @@ object XGBoostModel extends MLReadable[XGBoostModel] {
XGBoost.loadModelFromHadoopFile(dataPath)
}
}
}
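
Note: the model writer stores Spark-ML-style metadata alongside the raw booster, and the reader rebuilds the model from both parts. A round-trip sketch, with the on-disk layout inferred from the paths built in saveImpl/load above:

import ml.dmlc.xgboost4j.scala.spark.XGBoostModel

def roundTrip(trained: XGBoostModel, path: String): XGBoostModel = {
  // writes <path>/metadata (params JSON) and <path>/data (booster bytes)
  trained.write.overwrite().save(path)
  XGBoostModel.load(path)
}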


@@ -0,0 +1,106 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
import org.json4s.jackson.JsonMethods.{compact, parse, render}
import org.apache.spark.ml.param.{Param, ParamPair, Params}
class GroupDataParam(
parent: Params,
name: String,
doc: String) extends Param[Seq[Seq[Int]]](parent, name, doc) {
/** Creates a param pair with the given value (for Java). */
override def w(value: Seq[Seq[Int]]): ParamPair[Seq[Seq[Int]]] = super.w(value)
override def jsonEncode(value: Seq[Seq[Int]]): String = {
import org.json4s.jackson.Serialization
implicit val formats = Serialization.formats(NoTypeHints)
compact(render(Extraction.decompose(value)))
}
override def jsonDecode(json: String): Seq[Seq[Int]] = {
implicit val formats = DefaultFormats
parse(json).extract[Seq[Seq[Int]]]
}
}
class CustomEvalParam(
parent: Params,
name: String,
doc: String) extends Param[EvalTrait](parent, name, doc) {
/** Creates a param pair with the given value (for Java). */
override def w(value: EvalTrait): ParamPair[EvalTrait] = super.w(value)
override def jsonEncode(value: EvalTrait): String = {
import org.json4s.jackson.Serialization
implicit val formats = Serialization.formats(NoTypeHints)
compact(render(Extraction.decompose(value)))
}
override def jsonDecode(json: String): EvalTrait = {
implicit val formats = DefaultFormats
parse(json).extract[EvalTrait]
}
}
class CustomObjParam(
parent: Params,
name: String,
doc: String) extends Param[ObjectiveTrait](parent, name, doc) {
/** Creates a param pair with the given value (for Java). */
override def w(value: ObjectiveTrait): ParamPair[ObjectiveTrait] = super.w(value)
override def jsonEncode(value: ObjectiveTrait): String = {
import org.json4s.jackson.Serialization
implicit val formats = Serialization.formats(NoTypeHints)
compact(render(Extraction.decompose(value)))
}
override def jsonDecode(json: String): ObjectiveTrait = {
implicit val formats = DefaultFormats
parse(json).extract[ObjectiveTrait]
}
}
class TrackerConfParam(
parent: Params,
name: String,
doc: String) extends Param[TrackerConf](parent, name, doc) {
/** Creates a param pair with the given value (for Java). */
override def w(value: TrackerConf): ParamPair[TrackerConf] = super.w(value)
override def jsonEncode(value: TrackerConf): String = {
import org.json4s.jackson.Serialization
implicit val formats = Serialization.formats(NoTypeHints)
compact(render(Extraction.decompose(value)))
}
override def jsonDecode(json: String): TrackerConf = {
implicit val formats = DefaultFormats
val parsedValue = parse(json)
println(parsedValue.children)
parsedValue.extract[TrackerConf]
}
}
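
Note: these Param subclasses give the non-primitive estimator parameters json4s-based jsonEncode/jsonDecode implementations so they can pass through the default metadata writer and reader. A minimal round trip for GroupDataParam (values illustrative):

import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator

val est = new XGBoostEstimator(Map("eta" -> "0.1"))
val json = est.groupData.jsonEncode(Seq(Seq(3, 2), Seq(4)))   // "[[3,2],[4]]"
val back = est.groupData.jsonDecode(json)                     // Seq(Seq(3, 2), Seq(4))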


@@ -0,0 +1,136 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
import org.apache.hadoop.fs.Path
import org.json4s.{DefaultFormats, JValue}
import org.json4s.JsonAST.JObject
import org.json4s.jackson.JsonMethods.{compact, parse, render}
import org.apache.spark.SparkContext
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.MLReader
// This originates from apache-spark DefaultParamsReader copy paste
private[spark] object DefaultXGBoostParamsReader {
/**
* All info from metadata file.
*
* @param params paramMap, as a `JValue`
* @param metadata All metadata, including the other fields
* @param metadataJson Full metadata file String (for debugging)
*/
case class Metadata(
className: String,
uid: String,
timestamp: Long,
sparkVersion: String,
params: JValue,
metadata: JValue,
metadataJson: String) {
/**
* Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name.
* This can be useful for getting a Param value before an instance of `Params`
* is available.
*/
def getParamValue(paramName: String): JValue = {
implicit val format = DefaultFormats
params match {
case JObject(pairs) =>
val values = pairs.filter { case (pName, jsonValue) =>
pName == paramName
}.map(_._2)
assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" +
s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", "))
values.head
case _ =>
throw new IllegalArgumentException(
s"Cannot recognize JSON metadata: $metadataJson.")
}
}
}
/**
* Load metadata saved using [[DefaultXGBoostParamsWriter.saveMetadata()]]
*
* @param expectedClassName If non empty, this is checked against the loaded metadata.
* @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
*/
def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
val metadataPath = new Path(path, "metadata").toString
val metadataStr = sc.textFile(metadataPath, 1).first()
parseMetadata(metadataStr, expectedClassName)
}
/**
* Parse metadata JSON string produced by [[DefaultXGBoostParamsWriter.getMetadataToSave()]].
* This is a helper function for [[loadMetadata()]].
*
* @param metadataStr JSON string of metadata
* @param expectedClassName If non empty, this is checked against the loaded metadata.
* @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
*/
def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = {
val metadata = parse(metadataStr)
implicit val format = DefaultFormats
val className = (metadata \ "class").extract[String]
val uid = (metadata \ "uid").extract[String]
val timestamp = (metadata \ "timestamp").extract[Long]
val sparkVersion = (metadata \ "sparkVersion").extract[String]
val params = metadata \ "paramMap"
if (expectedClassName.nonEmpty) {
require(className == expectedClassName, s"Error loading metadata: Expected class name" +
s" $expectedClassName but found class name $className")
}
Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
}
/**
* Extract Params from metadata, and set them in the instance.
* This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
* TODO: Move to [[Metadata]] method
*/
def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
implicit val format = DefaultFormats
metadata.params match {
case JObject(pairs) =>
pairs.foreach { case (paramName, jsonValue) =>
val param = instance.getParam(paramName)
val value = param.jsonDecode(compact(render(jsonValue)))
instance.set(param, value)
}
case _ =>
throw new IllegalArgumentException(
s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
}
}
/**
* Load a `Params` instance from the given path, and return it.
* This assumes the instance implements [[org.apache.spark.ml.util.MLReadable]].
*/
def loadParamsInstance[T](path: String, sc: SparkContext): T = {
val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc)
val cls = Utils.classForName(metadata.className)
cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
}
}
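
Note: parseMetadata above expects the JSON shape produced by DefaultXGBoostParamsWriter.getMetadataToSave. An illustrative metadata document (field names follow the extraction code above; the concrete values are made up):

// {"class":"ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator",
//  "uid":"XGBoostEstimator_0123456789ab",
//  "timestamp":1494302286000,
//  "sparkVersion":"2.1.0",
//  "paramMap":{"round":1,"nWorkers":1,"eta":0.1}}
// loadMetadata reads this from <path>/metadata, and getAndSetParams copies each
// paramMap entry back onto the instance via Param.jsonDecode.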


@@ -46,6 +46,7 @@ private[spark] object DefaultXGBoostParamsWriter {
sc: SparkContext,
extraMetadata: Option[JObject] = None,
paramMap: Option[JValue] = None): Unit = {
val metadataPath = new Path(path, "metadata").toString
val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
@@ -65,7 +66,9 @@ private[spark] object DefaultXGBoostParamsWriter {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
val jsonParams = paramMap.getOrElse(render(params.map {
val jsonParams = paramMap.getOrElse(render(params.filter{
case ParamPair(p, _) => p != null
}.map {
case ParamPair(p, v) =>
p.name -> parse(p.jsonEncode(v))
}.toList))


@@ -20,8 +20,6 @@ import ml.dmlc.xgboost4j.scala.spark.TrackerConf
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import org.apache.spark.ml.param._
import scala.concurrent.duration.{Duration, NANOSECONDS}
trait GeneralParams extends Params {
/**
@@ -58,13 +56,13 @@ trait GeneralParams extends Params {
/**
* customized objective function provided by user. default: null
*/
val customObj = new Param[ObjectiveTrait](this, "custom_obj", "customized objective function " +
val customObj = new CustomObjParam(this, "custom_obj", "customized objective function " +
"provided by user")
/**
* customized evaluation function provided by user. default: null
*/
val customEval = new Param[EvalTrait](this, "custom_eval", "customized evaluation function " +
val customEval = new CustomEvalParam(this, "custom_eval", "customized evaluation function " +
"provided by user")
/**
@@ -99,7 +97,7 @@ trait GeneralParams extends Params {
* Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
* Ignored if the tracker implementation is "python".
*/
val trackerConf = new Param[TrackerConf](this, "tracker_conf", "Rabit tracker configurations")
val trackerConf = new TrackerConfParam(this, "tracker_conf", "Rabit tracker configurations")
setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
useExternalMemory -> false, silent -> 0,


@@ -57,7 +57,7 @@ trait LearningTaskParams extends Params {
* group data specify each group sizes for ranking task. To correspond to partition of
* training data, it is nested.
*/
val groupData = new Param[Seq[Seq[Int]]](this, "groupData", "group data specify each group size" +
val groupData = new GroupDataParam(this, "groupData", "group data specify each group size" +
" for ranking task. To correspond to partition of training data, it is nested.")
setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null)
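
Note: groupData is now a GroupDataParam rather than a plain Param[Seq[Seq[Int]]], so its nested per-partition group sizes can be JSON-encoded during persistence. An illustrative way of setting it (group sizes made up; one inner Seq per partition of the training data):

import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator

val ranker = new XGBoostEstimator(Map("objective" -> "rank:pairwise"))
ranker.set(ranker.groupData, Seq(Seq(10, 20), Seq(15, 15)))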


@@ -18,17 +18,21 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.io.Source
import scala.util.Random
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.SparkContext
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.{LabeledPoint, VectorAssembler}
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql._
class XGBoostDFSuite extends SharedSparkContext with Utils {


@@ -110,7 +110,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic",
"tracker_conf" -> TrackerConf(1 minute, "scala")).toMap
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")).toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers)
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),


@@ -18,67 +18,84 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.{File, FileNotFoundException}
import scala.util.Random
import org.apache.spark.SparkConf
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.SparkSession
import scala.concurrent.duration._
case class Foobar(TARGET: Int, bar: Double, baz: Double)
class XGBoostSparkPipelinePersistence extends SharedSparkContext with Utils {
override def afterAll(): Unit = {
super.afterAll()
delete(new File("./testxgbPipe"))
delete(new File("./test2xgbPipe"))
delete(new File("./testxgbEst"))
delete(new File("./testxgbModel"))
delete(new File("./test2xgbModel"))
}
private def delete(f: File) {
if (f.isDirectory()) {
for (c <- f.listFiles()) {
delete(c)
if (f.exists()) {
if (f.isDirectory()) {
for (c <- f.listFiles()) {
delete(c)
}
}
if (!f.delete()) {
throw new FileNotFoundException("Failed to delete file: " + f)
}
}
if (!f.delete()) {
throw new FileNotFoundException("Failed to delete file: " + f)
}
}
test("test sparks pipeline persistence of dataframe-based model") {
// maybe move to shared context, but requires session to import implicits.
// what about introducing https://github.com/holdenk/spark-testing-base ?
val conf: SparkConf = new SparkConf()
.setAppName("foo")
.setMaster("local[*]")
test("test persistence of XGBoostEstimator") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6")
val xgbEstimator = new XGBoostEstimator(paramMap)
xgbEstimator.write.overwrite().save("./testxgbEst")
val loadedxgbEstimator = XGBoostEstimator.read.load("./testxgbEst")
val loadedParamMap = loadedxgbEstimator.fromParamsToXGBParamMap
paramMap.foreach {
case (k, v) => assert(v == loadedParamMap(k).toString)
}
}
val spark: SparkSession = SparkSession
.builder()
.config(conf)
.getOrCreate()
test("test persistence of a complete pipeline") {
val conf = new SparkConf().setAppName("foo").setMaster("local[*]")
val spark = SparkSession.builder().config(conf).getOrCreate()
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6")
val r = new Random(0)
val assembler = new VectorAssembler().setInputCols(Array("feature")).setOutputCol("features")
val xgbEstimator = new XGBoostEstimator(paramMap)
val pipeline = new Pipeline().setStages(Array(assembler, xgbEstimator))
pipeline.write.overwrite().save("testxgbPipe")
val loadedPipeline = Pipeline.read.load("testxgbPipe")
val loadedEstimator = loadedPipeline.getStages(1).asInstanceOf[XGBoostEstimator]
val loadedParamMap = loadedEstimator.fromParamsToXGBParamMap
paramMap.foreach {
case (k, v) => assert(v == loadedParamMap(k).toString)
}
}
import spark.implicits._
test("test persistence of XGBoostModel") {
val conf = new SparkConf().setAppName("foo").setMaster("local[*]")
val spark = SparkSession.builder().config(conf).getOrCreate()
val r = new Random(0)
// maybe move to shared context, but requires session to import implicits
val df = Seq(Foobar(0, 0.5, 1), Foobar(1, 0.01, 0.8),
Foobar(0, 0.8, 0.5), Foobar(1, 8.4, 0.04))
.toDS
val df = spark.createDataFrame(Seq.fill(10000)(r.nextInt(2)).map(i => (i, i))).
toDF("feature", "label")
val vectorAssembler = new VectorAssembler()
.setInputCols(df.columns
.filter(!_.contains("TARGET")))
.filter(!_.contains("label")))
.setOutputCol("features")
val xgbEstimator = new XGBoostEstimator(Map("num_rounds" -> 10,
"tracker_conf" -> TrackerConf(1 minute, "scala")
))
.setFeaturesCol("features")
.setLabelCol("TARGET")
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")
)).setFeaturesCol("features").setLabelCol("label")
// separate
val predModel = xgbEstimator.fit(vectorAssembler.transform(df))
predModel.write.overwrite.save("test2xgbPipe")
val same2Model = XGBoostModel.load("test2xgbPipe")
predModel.write.overwrite.save("test2xgbModel")
val same2Model = XGBoostModel.load("test2xgbModel")
assert(java.util.Arrays.equals(predModel.booster.toByteArray, same2Model.booster.toByteArray))
val predParamMap = predModel.extractParamMap()
@@ -93,8 +110,8 @@ class XGBoostSparkPipelinePersistence extends SharedSparkContext with Utils {
// chained
val predictionModel = new Pipeline().setStages(Array(vectorAssembler, xgbEstimator)).fit(df)
predictionModel.write.overwrite.save("testxgbPipe")
val sameModel = PipelineModel.load("testxgbPipe")
predictionModel.write.overwrite.save("testxgbModel")
val sameModel = PipelineModel.load("testxgbModel")
val predictionModelXGB = predictionModel.stages.collect { case xgb: XGBoostModel => xgb } head
val sameModelXGB = sameModel.stages.collect { case xgb: XGBoostModel => xgb } head